
Add scib-rapids to scverse ecosystem packages #344

Open
maarten-devries wants to merge 2 commits into scverse:main from maarten-devries:add-scib-rapids

@maarten-devries

maarten-devries commented Mar 20, 2026


Name of the tool: scib-rapids

Short description: GPU-accelerated single-cell integration benchmarking metrics using RAPIDS (cuML, CuPy) as a drop-in replacement for the JAX-based metrics in scib-metrics.

How does the package use scverse data structures: scib-rapids takes AnnData objects as input throughout its API, reading embeddings, neighbor graphs, and cluster labels from the standard AnnData slots (e.g. obsm, obsp, obs), consistent with scverse ecosystem conventions.
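For readers unfamiliar with the slots named above, here is a minimal sketch of how a metric's inputs could be gathered from them. It is not the scib-rapids API: the helper `collect_metric_inputs` and the key names (`X_pca`, `connectivities`, `batch`, `cell_type`) are assumptions chosen for illustration, and a plain stand-in object replaces a real AnnData so the snippet runs without extra dependencies.

```python
from types import SimpleNamespace


def collect_metric_inputs(adata, embedding_key="X_pca", graph_key="connectivities",
                          batch_key="batch", label_key="cell_type"):
    """Gather the pieces an integration metric typically needs from AnnData-style slots."""
    return {
        "embedding": adata.obsm[embedding_key],  # per-cell embedding, one row per cell
        "graph": adata.obsp[graph_key],          # cell-by-cell neighbor graph
        "batches": adata.obs[batch_key],         # per-cell batch annotation
        "labels": adata.obs[label_key],          # per-cell cluster or cell-type label
    }


# Stand-in with the same slot names as AnnData (plain dicts and lists, illustration only).
toy = SimpleNamespace(
    obsm={"X_pca": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]},
    obsp={"connectivities": [[0, 1, 0], [1, 0, 1], [0, 1, 0]]},
    obs={"batch": ["a", "a", "b"], "cell_type": ["T", "B", "T"]},
)
inputs = collect_metric_inputs(toy)
print(sorted(inputs))
```

With a real AnnData object the same attribute access applies; only the container types differ (arrays, sparse matrices, and a DataFrame instead of lists and dicts).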


Mandatory

  • The code is publicly available under an OSI-approved license (BSD-3-Clause)
  • The package provides versioned releases
  • The package can be installed from a standard registry (PyPI: pip install scib-rapids)
  • Automated tests cover essential functions of the package and a reasonable range of inputs and conditions
  • Continuous integration (CI) automatically executes these tests on each push or pull request
  • The package provides API documentation via a website (https://scib-rapids.readthedocs.io/)
  • The package uses scverse datastructures where appropriate (AnnData)
  • I am an author or maintainer of the tool and agree on listing the package on the scverse website

Recommended

  • Package announcement on scverse communication channels
  • The package provides tutorials (or "vignettes") that help getting users started quickly
  • The package uses the scverse cookiecutter template

@mikkelnrasmussen

Collaborator

Hi @maarten-devries,

Thanks for submitting scib-rapids. GPU-accelerated scib metrics are really cool. In general, the package and documentation look good. The only comment I have is regarding this CI point:

Continuous integration (CI) automatically executes these tests on each push or pull request

It looks to me like the tests are disabled on push/PR and are only run manually, because the package requires a GPU to run them successfully. Would it be possible to use a GitHub-hosted or self-hosted GPU runner for the tests?
@grst Do you have experience with running tests on GPUs in CI?

Best,
Mikkel

@grst

grst commented Mar 23, 2026

Contributor

We have had good experiences with https://cirun.io/ for custom GPU runners. The catch is that you need to host them somewhere (in our case AWS) and pay for them.

Running GPU tests would of course be ideal, but with jax, couldn't they at least be run on CPU?

If there's enough interest, we could consider moving the repo into the scverse org and then using our GPU runners. Or maybe there's interest in upstreaming this into scib-metrics? Ping @ori-kron-wis

@ori-kron-wis

ori-kron-wis commented Mar 23, 2026

Contributor

Thanks for this package, looking forward to testing it. I have a few questions (I also posted them in the package repo):

Is there some performance benchmark information you can share, i.e., a comparison of this implementation against the JAX-based scib-metrics implementation on the same hardware and data (CPU and GPU)? Is there a specific version of JAX/RAPIDS for which one outperforms the other, or does it hold for all versions? Does the benefit come from the NN part? I would like to see a tutorial on that as well.

Is there any aim to also provide spatial transcriptomics metrics here (seems RAPIDS can do a good job there using spatialdata)?

Re: testing on GPUs, could you use a self-hosted local CUDA server to run them (this is what we do in scvi-tools)?
Also be aware that there is a jax[cuda12/13] version that runs on GPU, while the default install runs on CPU.

@maarten-devries

maarten-devries commented Mar 23, 2026

Author

Thanks for the reviews everyone!

@mikkelnrasmussen Good point on the CI — I did indeed disable the tests on push/PR because they require a GPU. I can look into setting up a GPU runner (self-hosted or via cirun.io) to get those running automatically.
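Until a GPU runner exists, one option is to keep CI enabled on push/PR and skip only the GPU-dependent tests when no device is present. The following is a minimal sketch of that pattern using the standard library's unittest; the test classes and their bodies are hypothetical placeholders, not tests from the scib-rapids repository, and the check assumes CuPy is the GPU array library in use.

```python
import importlib.util
import unittest


def gpu_available() -> bool:
    """Return True only if CuPy is importable and reports at least one CUDA device."""
    if importlib.util.find_spec("cupy") is None:
        return False
    try:
        import cupy
        return cupy.cuda.runtime.getDeviceCount() > 0
    except Exception:
        # Driver missing, no visible device, or a broken CUDA install.
        return False


class TestCpuSafe(unittest.TestCase):
    def test_runs_everywhere(self):
        # Placeholder for checks that need no GPU (imports, argument validation, ...).
        self.assertIsInstance(gpu_available(), bool)


@unittest.skipUnless(gpu_available(), "requires a CUDA-capable GPU")
class TestGpuOnly(unittest.TestCase):
    def test_placeholder(self):
        # Hypothetical GPU test body; real metric calls would go here.
        import cupy
        self.assertEqual(int(cupy.arange(4).sum()), 6)
```

Running `python -m unittest` on a CPU-only runner reports the GPU class as skipped rather than failed, so the skip count stays visible in the CI log.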

@grst To clarify — scib-rapids does not use JAX. That's actually the whole raison d'être of the package :).

@ori-kron-wis I will do some timing benchmarks and equivalency tests between scib-metrics and scib-rapids and provide them here as well. To set expectations: the goal is not necessarily that everything is faster (though that could happen as a side effect) — it's mostly to avoid the heavy JAX install dependency. This could be especially nice for people who already use rapids-singlecell anyway and don't want to pull in JAX on top of that.

@grst

grst commented Mar 23, 2026

Contributor

@grst To clarify — scib-rapids does not use JAX. That's actually the whole raison d'être of the package :).

Thanks for clarifying. In retrospect, I don't know why I even had the idea that it was JAX-based. Maybe because I looked up scib-metrics in between.

@maarten-devries

Author

Benchmark: scib-rapids vs scib-metrics

As requested by @ori-kron-wis, here are head-to-head GPU benchmark results comparing scib-rapids (CuPy/cuML) against scib-metrics (JAX) on real scRNA-seq data.

Setup:

Equivalency

All deterministic metrics match to within 0.03% relative difference. Two metric groups, kbet_per_label (†) and the k-means-based NMI/ARI (‡), show small expected differences, explained below the table.

| Metric | n=1k scib-metrics | n=1k scib-rapids | Rel. diff | n=20k scib-metrics | n=20k scib-rapids | Rel. diff |
|---|---|---|---|---|---|---|
| silhouette_label | 0.5360 | 0.5360 | 0.00% | 0.5117 | 0.5117 | 0.00% |
| silhouette_batch | 0.8047 | 0.8047 | 0.00% | 0.7889 | 0.7889 | 0.00% |
| bras | 0.6793 | 0.6791 | 0.02% | 0.6415 | 0.6415 | 0.00% |
| ilisi_knn | 0.2083 | 0.2083 | 0.03% | 0.0921 | 0.0921 | 0.01% |
| clisi_knn | 0.9926 | 0.9926 | 0.00% | 0.9990 | 0.9990 | 0.00% |
| kbet | 0.1720 | 0.1720 | 0.00% | 0.0258 | 0.0258 | 0.00% |
| kbet_per_label† | 0.8559 | 0.8514 | 0.53% | 0.4724 | 0.4681 | 0.90% |
| nmi (leiden) | 0.6376 | 0.6376 | 0.00% | 0.7690 | 0.7690 | 0.00% |
| ari (leiden) | 0.3752 | 0.3752 | 0.00% | 0.6367 | 0.6367 | 0.00% |
| nmi (kmeans)‡ | 0.7901 | 0.7817 | 1.06% | 0.7177 | 0.7165 | 0.18% |
| ari (kmeans)‡ | 0.4447 | 0.4270 | 3.98% | 0.3006 | 0.3112 | 3.54% |
| isolated_labels | 0.5763 | 0.5763 | 0.00% | 0.5529 | 0.5529 | 0.00% |
| graph_connectivity | 0.9563 | 0.9563 | 0.00% | 0.8704 | 0.8704 | 0.00% |
| pcr_comparison | 0.8736 | 0.8736 | 0.00% | 0.9220 | 0.9220 | 0.00% |
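The "Rel diff" columns can be reproduced from the displayed values, assuming the scib-metrics value is the denominator (the thread does not state the exact formula, so this is an assumption). Because the table shows only four decimals, results recomputed this way can differ from the reported percentages in the last digit.

```python
def relative_diff_percent(reference: float, candidate: float) -> float:
    """Relative difference of candidate vs. reference, in percent."""
    return abs(candidate - reference) / abs(reference) * 100.0


# Values copied from the kbet_per_label and ari (kmeans) rows at n=1k.
kbet_per_label = relative_diff_percent(0.8559, 0.8514)
ari_kmeans = relative_diff_percent(0.4447, 0.4270)
print(f"kbet_per_label: {kbet_per_label:.2f}%")
print(f"ari (kmeans):   {ari_kmeans:.2f}%")
```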

kbet_per_label (~0.5-0.9% diff): kbet_per_label uses diffusion maps internally (eigsh decomposition → pynndescent). Small clusters with repeated eigenvalues have non-unique eigenspaces, so eigenvectors can be arbitrary rotations of each other across processes. This is a fundamental property of eigsh on small structured graphs, not an implementation difference — on identical inputs, both packages produce identical kBET test statistics and p-values.
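The non-uniqueness argument can be checked on a toy matrix. The sketch below uses a hand-picked 3×3 symmetric matrix (an illustrative assumption, not a graph from the benchmark) whose eigenvalue 1 is repeated; two different orthonormal bases of that eigenspace both satisfy the eigenvector equation, so a solver is free to return either.

```python
import math

# Symmetric matrix with eigenvalues 4, 1, 1 (the eigenvalue 1 is repeated).
A = [[2.0, 1.0, 1.0], [1.0, 2.0, 1.0], [1.0, 1.0, 2.0]]


def matvec(m, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in m]


def is_eigenvector(m, v, lam, tol=1e-12):
    """Check m @ v == lam * v component-wise within a tolerance."""
    return all(abs(a - lam * b) < tol for a, b in zip(matvec(m, v), v))


# One orthonormal basis of the eigenspace for eigenvalue 1.
u1 = [1 / math.sqrt(2), -1 / math.sqrt(2), 0.0]
u2 = [1 / math.sqrt(6), 1 / math.sqrt(6), -2 / math.sqrt(6)]

# A second basis: the first one rotated by an arbitrary angle within the eigenspace.
theta = 0.7
v1 = [math.cos(theta) * a + math.sin(theta) * b for a, b in zip(u1, u2)]
v2 = [-math.sin(theta) * a + math.cos(theta) * b for a, b in zip(u1, u2)]

checks = {name: is_eigenvector(A, vec, 1.0)
          for name, vec in [("u1", u1), ("u2", u2), ("v1", v1), ("v2", v2)]}
print(checks)
```

Any downstream step that consumes the eigenvector coordinates directly (such as a neighbor search on the diffusion-map embedding) can therefore see different inputs across runs even though both decompositions are correct.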

kmeans NMI/ARI (~1-4% diff): Both packages use seed=0, but scib-metrics uses jax.random.PRNGKey(0) while scib-rapids uses np.random.default_rng(0). Different PRNG implementations produce different k-means++ initializations. Verified that given the same initial centroids, both produce identical cluster assignments and identical NMI/ARI scores — the algorithm is correct, only the random initialization differs.
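To illustrate why equal seeds do not imply equal initializations, here is a small sketch of k-means++ seeding on 1-D points driven by two different generators, both seeded with 0. Python's `random.Random` and a toy linear congruential generator stand in for the JAX and NumPy generators mentioned above; neither is the PRNG used by either package, and `kmeanspp_init` is a simplified illustration rather than either package's implementation.

```python
import random


class ToyLCG:
    """Minimal linear congruential generator (illustrative stand-in, not a library PRNG)."""

    def __init__(self, seed: int):
        self.state = seed

    def random(self) -> float:
        self.state = (1664525 * self.state + 1013904223) % 2**32
        return self.state / 2**32


def kmeanspp_init(points, k, rng):
    """k-means++ seeding: each new center is drawn with probability proportional
    to its squared distance from the nearest center chosen so far."""
    centers = [points[int(rng.random() * len(points))]]
    while len(centers) < k:
        d2 = [min((p - c) ** 2 for c in centers) for p in points]
        threshold = rng.random() * sum(d2)
        cumulative = 0.0
        for p, w in zip(points, d2):
            cumulative += w
            if cumulative > threshold:
                centers.append(p)
                break
        else:
            centers.append(points[-1])  # guard against floating-point round-off
    return centers


points = [0.0, 0.1, 0.2, 5.0, 5.1, 5.2, 10.0, 10.1, 10.2]
init_a = kmeanspp_init(points, 3, random.Random(0))
init_b = kmeanspp_init(points, 3, ToyLCG(0))
print(init_a)
print(init_b)
```

Re-running either call with a freshly seeded generator of the same type reproduces its result exactly, which mirrors the observation that the algorithm is deterministic once the initial centroids are fixed.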

Timing (seconds)

| Metric | n=1k scib-metrics | n=1k scib-rapids | Speedup | n=20k scib-metrics | n=20k scib-rapids | Speedup |
|---|---|---|---|---|---|---|
| silhouette_label | 3.13 | 0.53 | 5.9x | 717.8 | 5.23 | 137x |
| silhouette_batch | 26.77 | 0.11 | 248x | 112.6 | 0.42 | 268x |
| bras | 18.75 | 0.21 | 89x | 77.5 | 0.49 | 157x |
| ilisi_knn | 1.61 | 0.005 | 342x | 2.13 | 0.016 | 132x |
| clisi_knn | 0.30 | 0.002 | 178x | 0.98 | 0.006 | 166x |
| kbet | 1.04 | 0.008 | 132x | 2.25 | 0.017 | 135x |
| kbet_per_label | 17.24 | 3.50 | 4.9x | 61.3 | 26.8 | 2.3x |
| nmi_ari (kmeans) | 3.09 | 0.66 | 4.7x | 8.83 | 3.09 | 2.9x |
| nmi_ari (leiden) | 1.08 | 1.03 | 1.1x | 45.5 | 48.7 | 0.9x |
| pcr_comparison | 1.14 | 0.08 | 13.8x | 1.46 | 0.07 | 19.8x |
| isolated_labels | 0.02 | 0.12 | 0.1x | 0.14 | 4.87 | 0.03x |
| graph_connectivity | 0.05 | 0.05 | 1.0x | 0.16 | 0.15 | 1.1x |
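The speedup column reads as scib-metrics time divided by scib-rapids time, so values below 1x mean scib-rapids was slower. A short sketch of that arithmetic on three n=20k rows follows; since the displayed timings are rounded, recomputed values can differ slightly from the reported ones, especially for the millisecond-scale rows.

```python
def speedup(baseline_seconds: float, candidate_seconds: float) -> float:
    """How many times faster the candidate is than the baseline (< 1 means slower)."""
    return baseline_seconds / candidate_seconds


# (scib-metrics seconds, scib-rapids seconds) at n=20k, copied from the timing table.
rows = {
    "silhouette_label": (717.8, 5.23),
    "nmi_ari (leiden)": (45.5, 48.7),
    "isolated_labels": (0.14, 4.87),
}
for name, (t_metrics, t_rapids) in rows.items():
    print(f"{name}: {speedup(t_metrics, t_rapids):.2f}x")
```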

Key takeaways:

  • Core batch-integration metrics (silhouette_batch, bras, lisi, kbet) see roughly 90–340x speedups on GPU
  • Speedups grow with dataset size for some metrics: silhouette_label goes from 6x at 1k to 137x at 20k cells
  • Not everything is faster: nmi_ari (leiden) and graph_connectivity are roughly on par, and isolated_labels is slower (0.1x at 1k, 0.03x at 20k)
  • 11 of 14 rows in the equivalency table agree to within 0.03%, supporting use as a drop-in replacement
  • The remaining rows (kmeans NMI/ARI, kbet_per_label) are verified to be algorithmically identical; the diffs come from PRNG sequences and eigenvector non-uniqueness, not implementation
  • A 400k cell run was planned but not yet completed; if the trends above hold, the metrics that scale well may show larger speedups at that scale, but this is untested

@ori-kron-wis

Contributor

Thanks @maarten-devries - did you compare the jax[cuda] or jax[cpu] and which version? It wasn't stated.

@ilan-gold

Contributor

A couple of things:

  1. We have been offering cupy-cuda{12,13}x as optional dependencies to allow users to opt in to one or the other in rapids_singlecell and annbatch, for example https://github.com/scverse/rapids-singlecell/blob/main/pyproject.toml#L35-L36
  2. While rapids_singlecell does not, I think you should use the cupy-cuda{12,13}x[ctk] extra, or at least investigate it, if the goal is to make installation easier with torch. See cupy/cupy#9827 ("cupy_backends.cuda.api.driver.CUDADriverError when torch is installed on type conversion with torch>=2.11 installed").
  3. Following up, is installing the rapids ecosystem a step up over Jax as a dependency? I assume so since people seem to use rapids with DL setups, but also I personally ran into the issue raised above and we haven't been using the ctk extra in rapids_singlecell which makes me wonder if this may not be as battle-tested as we think it is
  4. IIUC, your current "distribution" (not the source code) is GPL2 due to the dependency on igraph. This might be worth investigating.

@maarten-devries

Author

Hi all, sorry for the late reply. I have some time again and will address your questions asap.

@maarten-devries

Author

Thanks everyone for the patience. I've cleaned up scib-benchmark with a proper README (methodology, raw metric values, timing tables for n=1k and n=20k, and notes on the small numerical diffs). Install commands there use uv pip install, which is also what I'd recommend for scib-rapids.

@ori-kron-wis: yes, it's jax[cuda12]. I verified that jax.devices() returns [CudaDevice(id=0)], so there was no CPU fallback.
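For anyone repeating the comparison, a check along these lines confirms which backend JAX actually selected before timing anything. This is a generic sketch, not the script used for the benchmark; the helper name `describe_jax_backend` is made up for illustration, and it returns None when JAX is not installed.

```python
def describe_jax_backend():
    """Return (backend_name, device_list) for JAX, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    # default_backend() names the platform JAX picked ("cpu", "gpu", or "tpu");
    # devices() lists the devices it will place computations on.
    return jax.default_backend(), [str(d) for d in jax.devices()]


info = describe_jax_backend()
if info is None:
    print("JAX is not installed in this environment")
else:
    backend, devices = info
    print(f"backend={backend}, devices={devices}")
```

A CPU-only install reports the `cpu` backend here, which is the silent-fallback case the question was about.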

@ilan-gold — all four points addressed in scib-rapids:

  1. cupy-cuda12x moved out of core deps into [cu12]/[cu13] extras, e.g. uv pip install 'scib-rapids[cu12]'.
  2. ✅ The [cu12]/[cu13] extras now use cupy-cuda12x[ctk] / cupy-cuda13x[ctk], which pin cuda-toolkit to the matching major — so installing alongside torch ≥2.11 no longer mismatches CUDA majors (the exact scenario in cupy#9827).
  3. I'm not a JAX expert and there's nothing wrong with scib-metrics; I've used it successfully before. One thing I did notice is that JAX's JIT compiles each jitted function on its first call (and again when input shapes change), which adds warm-up overhead that can slow things down in practice.
  4. igraph is gone. Leiden now runs through cugraph.leiden (Apache-2), which is license-clean and ~10× faster than the igraph path at n=20k. Full numbers in the benchmark README.
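Since the CUDA-major mismatch discussed in points 1 and 2 comes down to which CuPy wheel ends up installed, a quick environment check can help when debugging. The sketch below uses only the standard library's importlib.metadata; the helper name is hypothetical, and treating more than one installed `cupy*` distribution as a likely conflict is a rule of thumb rather than something stated in this thread.

```python
from importlib import metadata


def installed_cupy_wheels():
    """List (name, version) for every installed distribution whose name starts with 'cupy'."""
    found = []
    for dist in metadata.distributions():
        name = dist.metadata["Name"] or ""
        if name.lower().startswith("cupy"):
            found.append((name, dist.version))
    return sorted(found)


wheels = installed_cupy_wheels()
if not wheels:
    print("No CuPy wheel installed")
elif len(wheels) > 1:
    # Several CuPy distributions in one environment usually indicate a conflict.
    print(f"Multiple CuPy wheels found: {wheels}")
else:
    print(f"CuPy wheel: {wheels[0][0]} {wheels[0][1]}")
```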

GPU CI runner is still on my list (cirun.io or self-hosted) — tackling next.
