diff --git a/src/rapids_singlecell/preprocessing/_hvg.py b/src/rapids_singlecell/preprocessing/_hvg.py index 6db78c74..54d6785d 100644 --- a/src/rapids_singlecell/preprocessing/_hvg.py +++ b/src/rapids_singlecell/preprocessing/_hvg.py @@ -268,10 +268,8 @@ def in_bounds( def _hvg_expm1(X): if isinstance(X, DaskArray): - if isinstance(X._meta, cp.ndarray): - X = X.map_blocks(lambda X: cp.expm1(X), meta=_meta_dense(X.dtype)) - elif isinstance(X._meta, csr_matrix): - X = X.map_blocks(lambda X: X.expm1(), meta=_meta_sparse(X.dtype)) + meta = _meta_sparse if isinstance(X._meta, csr_matrix) else _meta_dense + X = X.map_blocks(_hvg_expm1, meta=meta(X.dtype)) else: X = X.copy() if issparse(X):