(feat): _first_pass_qc single dispatch refactor #180

Open · wants to merge 13 commits into base: main
35 changes: 35 additions & 0 deletions src/rapids_singlecell/_compat.py
@@ -0,0 +1,35 @@
from __future__ import annotations

import cupy as cp
from cupyx.scipy.sparse import csr_matrix

try:
    from dask.array import Array as DaskArray
except ImportError:

    class DaskArray:
        pass


try:
    from dask.distributed import Client as DaskClient
except ImportError:

    class DaskClient:
        pass


def _get_dask_client(client=None):
    from dask.distributed import default_client

    if client is None or not isinstance(client, DaskClient):
        return default_client()
    return client


def _meta_dense(dtype):
    return cp.zeros([0], dtype=dtype)


def _meta_sparse(dtype):
    return csr_matrix(cp.array((1.0,), dtype=dtype))
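The `_compat` helpers do two jobs: the fallback classes keep `isinstance` checks working when dask is not installed (they simply never match), and the `_meta_*` functions hand Dask a zero-cost prototype block so `map_blocks` can build its task graph without computing anything. A minimal usage sketch, not part of this diff; the random data and the `cp.log1p` kernel are illustrative:

```python
import cupy as cp
import dask.array as da

from rapids_singlecell._compat import _meta_dense

# A Dask array whose blocks are CuPy ndarrays (two row chunks).
X = da.from_array(cp.random.rand(200, 50), chunks=(100, 50))

# `meta=` declares the block type/dtype up front; `_meta_dense`
# supplies an empty CuPy array as that prototype.
Y = X.map_blocks(cp.log1p, meta=_meta_dense(X.dtype))

print(type(Y.compute()))  # <class 'cupy.ndarray'>
```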
3 changes: 2 additions & 1 deletion src/rapids_singlecell/preprocessing/__init__.py
@@ -5,11 +5,12 @@
 from ._neighbors import neighbors
 from ._normalize import log1p, normalize_pearson_residuals, normalize_total
 from ._pca import pca
+from ._qc import calculate_qc_metrics
+from ._qc_refactored import calculate_qc_metrics_refactored
 from ._regress_out import regress_out
 from ._scale import scale
 from ._scrublet import scrublet, scrublet_simulate_doublets
 from ._simple import (
-    calculate_qc_metrics,
     filter_cells,
     filter_genes,
     filter_highly_variable,
47 changes: 36 additions & 11 deletions src/rapids_singlecell/preprocessing/_hvg.py
@@ -9,10 +9,12 @@
 import cupy as cp
 import numpy as np
 import pandas as pd
-from cupyx.scipy.sparse import issparse
+from cupyx.scipy.sparse import csr_matrix, issparse
 from scanpy.get import _get_obs_rep
 
-from ._simple import calculate_qc_metrics
+from rapids_singlecell._compat import DaskArray, DaskClient, _meta_dense, _meta_sparse
+
+from ._qc import calculate_qc_metrics
 from ._utils import _check_gpu_X, _check_nonnegative_integers, _get_mean_var
 
 if TYPE_CHECKING:
@@ -47,6 +49,7 @@ def highly_variable_genes(
     chunksize: int = 1000,
     n_samples: int = 10000,
     batch_key: str = None,
+    client: DaskClient | None = None,
 ) -> None:
"""\
Annotate highly variable genes.
Expand Down Expand Up @@ -116,6 +119,8 @@ def highly_variable_genes(
         of enrichment of zeros for each gene (only for `flavor='poisson_gene_selection'`).
     batch_key
         If specified, highly-variable genes are selected within each batch separately and merged.
+    client
+        Dask client to use for computation. If `None`, the default client is used. Only used if `X` is a Dask array.
 
     Returns
     -------
@@ -188,7 +193,12 @@

     if batch_key is None:
         df = _highly_variable_genes_single_batch(
-            adata, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
+            adata,
+            layer=layer,
+            cutoff=cutoff,
+            n_bins=n_bins,
+            flavor=flavor,
+            client=client,
         )
     else:
         df = _highly_variable_genes_batched(
@@ -198,6 +208,7 @@
             cutoff=cutoff,
             n_bins=n_bins,
             flavor=flavor,
+            client=client,
         )
 
     adata.uns["hvg"] = {"flavor": flavor}
@@ -267,6 +278,7 @@ def _highly_variable_genes_single_batch(
     cutoff: _Cutoffs | int,
     n_bins: int = 20,
     flavor: Literal["seurat", "cell_ranger"] = "seurat",
+    client: DaskClient | None = None,
 ) -> pd.DataFrame:
"""\
See `highly_variable_genes`.
Expand All @@ -277,18 +289,25 @@ def _highly_variable_genes_single_batch(
     `highly_variable`, `means`, `dispersions`, and `dispersions_norm`.
     """
     X = _get_obs_rep(adata, layer=layer)
+
+    _check_gpu_X(X, allow_dask=True)
     if hasattr(X, "_view_args"):  # AnnData array view
         # For compatibility with anndata<0.9
         X = X.copy()  # Doesn't actually copy memory, just removes View class wrapper
 
     if flavor == "seurat":
         X = X.copy()
         if issparse(X):
             X = X.expm1()
+        elif isinstance(X, DaskArray):
+            if isinstance(X._meta, cp.ndarray):
+                X = X.map_blocks(cp.expm1, meta=_meta_dense(X.dtype))
+            elif isinstance(X._meta, csr_matrix):
+                X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype))
         else:
             X = cp.expm1(X)
 
-    mean, var = _get_mean_var(X, axis=0)
+    mean, var = _get_mean_var(X, axis=0, client=client)
     mean[mean == 0] = 1e-12
     disp = var / mean
     if flavor == "seurat":  # logarithmized mean as in Seurat
@@ -407,6 +426,7 @@ def _highly_variable_genes_batched(
     n_bins: int,
     flavor: Literal["seurat", "cell_ranger"],
     cutoff: _Cutoffs | int,
+    client: DaskClient | None = None,
 ) -> pd.DataFrame:
     adata._sanitize()
     batches = adata.obs[batch_key].cat.categories
@@ -415,12 +435,17 @@
     for batch in batches:
         adata_subset = adata[adata.obs[batch_key] == batch]
 
-        calculate_qc_metrics(adata_subset, layer=layer)
+        calculate_qc_metrics(adata_subset, layer=layer, client=client)
         filt = adata_subset.var["n_cells_by_counts"].to_numpy() > 0
         adata_subset = adata_subset[:, filt]
 
         hvg = _highly_variable_genes_single_batch(
-            adata_subset, layer=layer, cutoff=cutoff, n_bins=n_bins, flavor=flavor
+            adata_subset,
+            layer=layer,
+            cutoff=cutoff,
+            n_bins=n_bins,
+            flavor=flavor,
+            client=client,
        )
         hvg.reset_index(drop=False, inplace=True, names=["gene"])
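From the caller's side, the batched path only needs a categorical obs column plus the same `client`: each batch gets its own QC pass, genes with `n_cells_by_counts == 0` are dropped within that batch, and the per-batch flags are merged. Continuing the illustrative `adata`/`client` from the earlier sketch; the `"sample"` column is made up here, and the merged `highly_variable_nbatches` column mirroring scanpy's batched output is an assumption:

```python
# Two illustrative batches of 250 cells each.
adata.obs["sample"] = (["a"] * 250) + (["b"] * 250)

rsc.pp.highly_variable_genes(
    adata, flavor="seurat", batch_key="sample", client=client
)

# With scanpy-style batched output, this records in how many
# batches each gene was flagged as highly variable.
print(adata.var["highly_variable_nbatches"].value_counts())
```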