diff --git a/src/anndata/_core/aligned_mapping.py b/src/anndata/_core/aligned_mapping.py index 9df5ac977..a2fda845f 100644 --- a/src/anndata/_core/aligned_mapping.py +++ b/src/anndata/_core/aligned_mapping.py @@ -5,6 +5,7 @@ from collections.abc import MutableMapping, Sequence from copy import copy from dataclasses import dataclass +from types import NoneType from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np @@ -38,12 +39,13 @@ # TODO: pd.DataFrame only allowed in AxisArrays? Value = pd.DataFrame | spmatrix | np.ndarray +K = TypeVar("K", str, str | None) P = TypeVar("P", bound="AlignedMappingBase") """Parent mapping an AlignedView is based on.""" I = TypeVar("I", OneDIdx, TwoDIdx) -class AlignedMappingBase(MutableMapping[str, Value], ABC): +class AlignedMappingBase(MutableMapping[K, Value], ABC, Generic[K]): """\ An abstract base class for Mappings containing array-like values aligned to either one or both AnnData axes. @@ -61,13 +63,13 @@ class AlignedMappingBase(MutableMapping[str, Value], ABC): _parent: AnnData | Raw """The parent object that this mapping is aligned to.""" - def __repr__(self): - return f"{type(self).__name__} with keys: {', '.join(self.keys())}" + def __repr__(self) -> str: + return f"{type(self).__name__} with keys: {', '.join(map(repr, self.keys()))}" - def _ipython_key_completions_(self) -> list[str]: + def _ipython_key_completions_(self) -> list[K]: return list(self.keys()) - def _validate_value(self, val: Value, key: str) -> Value: + def _validate_value(self, val: Value, key: K) -> Value: """Raises an error if value is invalid""" if isinstance(val, AwkArray): warn_once( @@ -117,13 +119,14 @@ def is_view(self) -> bool: ... def parent(self) -> AnnData | Raw: return self._parent - def copy(self) -> dict[str, Value]: + def copy(self) -> dict[K, Value]: # Shallow copy for awkward array since their buffers are immutable return { - k: copy(v) if isinstance(v, AwkArray) else v.copy() for k, v in self.items() + k: copy(v) if isinstance(v, AwkArray | NoneType) else v.copy() + for k, v in self.items() } - def _view(self, parent: AnnData, subset_idx: I) -> AlignedView[Self, I]: + def _view(self, parent: AnnData, subset_idx: I) -> AlignedView[K, Self, I]: """Returns a subset copy-on-write view of the object.""" return self._view_class(self, parent, subset_idx) @@ -132,7 +135,7 @@ def as_dict(self) -> dict: return dict(self) -class AlignedView(AlignedMappingBase, Generic[P, I]): +class AlignedView(AlignedMappingBase[K], Generic[K, P, I]): is_view: ClassVar[Literal[True]] = True # override docstring @@ -156,13 +159,15 @@ def __init__(self, parent_mapping: P, parent_view: AnnData, subset_idx: I): # LayersBase has no _axis, the rest does self._axis = parent_mapping._axis # type: ignore - def __getitem__(self, key: str) -> Value: + def __getitem__(self, key: K) -> Value: + if self.parent_mapping[key] is None: + return None return as_view( _subset(self.parent_mapping[key], self.subset_idx), ElementRef(self.parent, self.attrname, (key,)), ) - def __setitem__(self, key: str, value: Value) -> None: + def __setitem__(self, key: K, value: Value) -> None: value = self._validate_value(value, key) # Validate before mutating warnings.warn( f"Setting element `.{self.attrname}['{key}']` of view, " @@ -171,9 +176,12 @@ def __setitem__(self, key: str, value: Value) -> None: stacklevel=2, ) with view_update(self.parent, self.attrname, ()) as new_mapping: - new_mapping[key] = value + if value is None: + del new_mapping[key] + else: + new_mapping[key] = value - def __delitem__(self, key: str) -> None: + def __delitem__(self, key: K) -> None: if key not in self: raise KeyError( "'{key!r}' not found in view of {self.attrname}" @@ -187,49 +195,52 @@ def __delitem__(self, key: str) -> None: with view_update(self.parent, self.attrname, ()) as new_mapping: del new_mapping[key] - def __contains__(self, key: str) -> bool: + def __contains__(self, key: K) -> bool: return key in self.parent_mapping - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[K]: return iter(self.parent_mapping) def __len__(self) -> int: return len(self.parent_mapping) -class AlignedActual(AlignedMappingBase): +class AlignedActual(AlignedMappingBase[K], Generic[K]): is_view: ClassVar[Literal[False]] = False - _data: MutableMapping[str, Value] + _data: MutableMapping[K, Value] """Underlying mapping to the data""" - def __init__(self, parent: AnnData | Raw, *, store: MutableMapping[str, Value]): + def __init__(self, parent: AnnData | Raw, *, store: MutableMapping[K, Value]): self._parent = parent self._data = store for k, v in self._data.items(): + if v is None: + continue self._data[k] = self._validate_value(v, k) - def __getitem__(self, key: str) -> Value: + def __getitem__(self, key: K) -> Value: return self._data[key] - def __setitem__(self, key: str, value: Value): - value = self._validate_value(value, key) + def __setitem__(self, key: K, value: Value): + if value is not None: + value = self._validate_value(value, key) self._data[key] = value - def __contains__(self, key: str) -> bool: + def __contains__(self, key: K) -> bool: return key in self._data - def __delitem__(self, key: str): + def __delitem__(self, key: K): del self._data[key] - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[K]: return iter(self._data) def __len__(self) -> int: return len(self._data) -class AxisArraysBase(AlignedMappingBase): +class AxisArraysBase(AlignedMappingBase[str]): """\ Mapping of key→array-like, where array-like is aligned to an axis of parent AnnData. @@ -283,7 +294,7 @@ def dim_names(self) -> pd.Index: return (self.parent.obs_names, self.parent.var_names)[self._axis] -class AxisArrays(AlignedActual, AxisArraysBase): +class AxisArrays(AlignedActual[str], AxisArraysBase): def __init__( self, parent: AnnData | Raw, @@ -297,7 +308,7 @@ def __init__( super().__init__(parent, store=store) -class AxisArraysView(AlignedView[AxisArraysBase, OneDIdx], AxisArraysBase): +class AxisArraysView(AlignedView[str, AxisArraysBase, OneDIdx], AxisArraysBase): pass @@ -305,7 +316,7 @@ class AxisArraysView(AlignedView[AxisArraysBase, OneDIdx], AxisArraysBase): AxisArraysBase._actual_class = AxisArrays -class LayersBase(AlignedMappingBase): +class LayersBase(AlignedMappingBase[str | None]): """\ Mapping of key: array-like, where array-like is aligned to both axes of the parent anndata. @@ -316,11 +327,11 @@ class LayersBase(AlignedMappingBase): axes: ClassVar[tuple[Literal[0], Literal[1]]] = (0, 1) -class Layers(AlignedActual, LayersBase): +class Layers(AlignedActual[str | None], LayersBase): pass -class LayersView(AlignedView[LayersBase, TwoDIdx], LayersBase): +class LayersView(AlignedView[str | None, LayersBase, TwoDIdx], LayersBase): pass @@ -328,7 +339,7 @@ class LayersView(AlignedView[LayersBase, TwoDIdx], LayersBase): LayersBase._actual_class = Layers -class PairwiseArraysBase(AlignedMappingBase): +class PairwiseArraysBase(AlignedMappingBase[str]): """\ Mapping of key: array-like, where both axes of array-like are aligned to one axis of the parent anndata. @@ -354,7 +365,7 @@ def dim(self) -> str: return self._dimnames[self._axis] -class PairwiseArrays(AlignedActual, PairwiseArraysBase): +class PairwiseArrays(AlignedActual[str], PairwiseArraysBase): def __init__( self, parent: AnnData, @@ -368,7 +379,9 @@ def __init__( super().__init__(parent, store=store) -class PairwiseArraysView(AlignedView[PairwiseArraysBase, OneDIdx], PairwiseArraysBase): +class PairwiseArraysView( + AlignedView[str, PairwiseArraysBase, OneDIdx], PairwiseArraysBase +): pass @@ -389,7 +402,7 @@ class PairwiseArraysView(AlignedView[PairwiseArraysBase, OneDIdx], PairwiseArray @dataclass -class AlignedMappingProperty(property, Generic[T]): +class AlignedMappingProperty(property, Generic[K, T]): """A :class:`property` that creates an ephemeral AlignedMapping. The actual data is stored as `f'_{self.name}'` in the parent object. @@ -402,7 +415,7 @@ class AlignedMappingProperty(property, Generic[T]): axis: Literal[0, 1] | None = None """Axis of the parent to align to.""" - def construct(self, obj: AnnData, *, store: MutableMapping[str, Value]) -> T: + def construct(self, obj: AnnData, *, store: MutableMapping[K, Value]) -> T: if self.axis is None: return self.cls(obj, store=store) return self.cls(obj, axis=self.axis, store=store) @@ -429,7 +442,7 @@ def __get__(self, obj: None | AnnData, objtype: type | None = None) -> T: return parent._view(obj, tuple(idxs[ax] for ax in parent.axes)) def __set__( - self, obj: AnnData, value: Mapping[str, Value] | Iterable[tuple[str, Value]] + self, obj: AnnData, value: Mapping[K, Value] | Iterable[tuple[K, Value]] | None ) -> None: value = convert_to_dict(value) _ = self.construct(obj, store=value) # Validate @@ -437,5 +450,6 @@ def __set__( obj._init_as_actual(obj.copy()) setattr(obj, f"_{self.name}", value) - def __delete__(self, obj) -> None: - setattr(obj, self.name, dict()) + def __delete__(self, obj: AnnData) -> None: + new = {None: x} if (x := getattr(obj, self.name).get(None)) is not None else {} + setattr(obj, self.name, new) diff --git a/src/anndata/_core/anndata.py b/src/anndata/_core/anndata.py index e972ba23c..94c434a23 100644 --- a/src/anndata/_core/anndata.py +++ b/src/anndata/_core/anndata.py @@ -17,14 +17,13 @@ import numpy as np import pandas as pd from natsort import natsorted -from numpy import ma from pandas.api.types import infer_dtype from scipy import sparse from scipy.sparse import issparse from .. import utils from .._settings import settings -from ..compat import DaskArray, SpArray, ZarrArray, _move_adj_mtx +from ..compat import SpArray, _move_adj_mtx from ..logging import anndata_logger as logger from ..utils import ( axis_len, @@ -32,19 +31,17 @@ ensure_df_homogeneous, raise_value_error_if_multiindex_columns, ) -from .access import ElementRef from .aligned_df import _gen_dataframe from .aligned_mapping import AlignedMappingProperty, AxisArrays, Layers, PairwiseArrays from .file_backing import AnnDataFileManager, to_memory from .index import _normalize_indices, _subset, get_vector from .raw import Raw -from .sparse_dataset import BaseCompressedSparseDataset, sparse_dataset +from .sparse_dataset import BaseCompressedSparseDataset from .storage import coerce_array from .views import ( DataFrameView, DictView, _resolve_idxs, - as_view, ) if TYPE_CHECKING: @@ -73,18 +70,6 @@ def _gen_keys_from_multicol_key(key_multicol, n_keys): return keys -def _check_2d_shape(X): - """\ - Check shape of array or sparse matrix. - - Assure that X is always 2D: Unlike numpy we always deal with 2D arrays. - """ - if X.dtype.names is None and len(X.shape) != 2: - raise ValueError( - f"X needs to be 2-dimensional, not {len(X.shape)}-dimensional." - ) - - class AnnData(metaclass=utils.DeprecationMixinMeta): """\ An annotated data matrix. @@ -228,7 +213,6 @@ def __init__( varm: np.ndarray | Mapping[str, Sequence[Any]] | None = None, layers: Mapping[str, np.ndarray | sparse.spmatrix] | None = None, raw: Mapping[str, Any] | None = None, - dtype: np.dtype | type | str | None = None, shape: tuple[int, int] | None = None, filename: PathLike | None = None, filemode: Literal["r", "r+"] | None = None, @@ -257,7 +241,6 @@ def __init__( varm=varm, raw=raw, layers=layers, - dtype=dtype, shape=shape, obsp=obsp, varp=varp, @@ -305,10 +288,6 @@ def _init_as_view(self, adata_ref: AnnData, oidx: Index, vidx: Index): self._var = DataFrameView(var_sub, view_args=(self, "var")) self._uns = uns - # set data - if self.isbacked: - self._X = None - # set raw, easy, as it’s immutable anyways... if adata_ref._raw is not None: # slicing along variables axis is ignored @@ -329,7 +308,6 @@ def _init_as_actual( obsp=None, raw=None, layers=None, - dtype=None, shape=None, filename=None, filemode=None, @@ -359,8 +337,7 @@ def _init_as_actual( raise ValueError( "If `X` is a dict no further arguments must be provided." ) - X, obs, var, uns, obsm, varm, obsp, varp, layers, raw = ( - X._X, + obs, var, uns, obsm, varm, obsp, varp, layers, raw = ( X.obs, X.var, X.uns, @@ -371,6 +348,7 @@ def _init_as_actual( X.layers, X.raw, ) + X = X.layers.get(None) # init from DataFrame elif isinstance(X, pd.DataFrame): @@ -394,28 +372,10 @@ def _init_as_actual( X = coerce_array(X, name="X") if shape is not None: raise ValueError("`shape` needs to be `None` if `X` is not `None`.") - _check_2d_shape(X) - # if type doesn’t match, a copy is made, otherwise, use a view - if dtype is not None: - warnings.warn( - "The dtype argument is deprecated and will be removed in late 2024.", - FutureWarning, - ) - if issparse(X) or isinstance(X, ma.MaskedArray): - # TODO: maybe use view on data attribute of sparse matrix - # as in readwrite.read_10x_h5 - if X.dtype != np.dtype(dtype): - X = X.astype(dtype) - elif isinstance(X, ZarrArray | DaskArray): - X = X.astype(dtype) - else: # is np.ndarray or a subclass, convert to true np.ndarray - X = np.asarray(X, dtype) # data matrix and shape - self._X = X n_obs, n_vars = X.shape source = "X" else: - self._X = None n_obs, n_vars = (None, None) if shape is None else shape source = "shape" @@ -461,13 +421,14 @@ def _init_as_actual( elif isinstance(raw, Mapping): self._raw = Raw(self, **raw) else: # is a Raw from another AnnData - self._raw = Raw(self, raw._X, raw.var, raw.varm) + self._raw = Raw(self, raw.X, raw.var, raw.varm) # clean up old formats self._clean_up_old_format(uns) # layers self.layers = layers + self.X = X def __sizeof__(self, show_stratified=None, with_disk: bool = False) -> int: def get_size(X) -> int: @@ -543,124 +504,18 @@ def shape(self) -> tuple[int, int]: @property def X(self) -> ArrayDataStructureType | None: """Data matrix of shape :attr:`n_obs` × :attr:`n_vars`.""" - if self.isbacked: - if not self.file.is_open: - self.file.open() - X = self.file["X"] - if isinstance(X, h5py.Group): - X = sparse_dataset(X) - # This is so that we can index into a backed dense dataset with - # indices that aren’t strictly increasing - if self.is_view: - X = _subset(X, (self._oidx, self._vidx)) - elif self.is_view and self._adata_ref.X is None: - X = None - elif self.is_view: - X = as_view( - _subset(self._adata_ref.X, (self._oidx, self._vidx)), - ElementRef(self, "X"), - ) - else: - X = self._X - return X - # if self.n_obs == 1 and self.n_vars == 1: - # return X[0, 0] - # elif self.n_obs == 1 or self.n_vars == 1: - # if issparse(X): X = X.toarray() - # return X.flatten() - # else: - # return X + return self.layers.get(None) @X.setter def X(self, value: np.ndarray | sparse.spmatrix | SpArray | None): - if value is None: - if self.isbacked: - raise NotImplementedError( - "Cannot currently remove data matrix from backed object." - ) - if self.is_view: - self._init_as_actual(self.copy()) - self._X = None - return - value = coerce_array(value, name="X", allow_array_like=True) - - # If indices are both arrays, we need to modify them - # so we don’t set values like coordinates - # This can occur if there are successive views - if ( - self.is_view - and isinstance(self._oidx, np.ndarray) - and isinstance(self._vidx, np.ndarray) - ): - oidx, vidx = np.ix_(self._oidx, self._vidx) - else: - oidx, vidx = self._oidx, self._vidx - if ( - np.isscalar(value) - or (hasattr(value, "shape") and (self.shape == value.shape)) - or (self.n_vars == 1 and self.n_obs == len(value)) - or (self.n_obs == 1 and self.n_vars == len(value)) - ): - if not np.isscalar(value): - if self.is_view and any( - isinstance(idx, np.ndarray) - and len(np.unique(idx)) != len(idx.ravel()) - for idx in [oidx, vidx] - ): - msg = ( - "You are attempting to set `X` to a matrix on a view which has non-unique indices. " - "The resulting `adata.X` will likely not equal the value to which you set it. " - "To avoid this potential issue, please make a copy of the data first. " - "In the future, this operation will throw an error." - ) - warnings.warn(msg, FutureWarning, stacklevel=1) - if self.shape != value.shape: - # For assigning vector of values to 2d array or matrix - # Not necessary for row of 2d array - value = value.reshape(self.shape) - if self.isbacked: - if self.is_view: - X = self.file["X"] - if isinstance(X, h5py.Group): - X = sparse_dataset(X) - X[oidx, vidx] = value - else: - self._set_backed("X", value) - else: - if self.is_view: - if sparse.issparse(self._adata_ref._X) and isinstance( - value, np.ndarray - ): - if isinstance(self._adata_ref.X, SpArray): - memory_class = sparse.coo_array - else: - memory_class = sparse.coo_matrix - value = memory_class(value) - elif sparse.issparse(value) and isinstance( - self._adata_ref._X, np.ndarray - ): - warnings.warn( - "Trying to set a dense array with a sparse array on a view." - "Densifying the sparse array." - "This may incur excessive memory usage", - stacklevel=2, - ) - value = value.toarray() - self._adata_ref._X[oidx, vidx] = value - else: - self._X = value - else: - raise ValueError( - f"Data matrix has wrong shape {value.shape}, " - f"need to be {self.shape}." - ) + self.layers[None] = value @X.deleter - def X(self): - self.X = None + def X(self) -> None: + del self.layers[None] - layers: AlignedMappingProperty[Layers | LayersView] = AlignedMappingProperty( - "layers", Layers + layers: AlignedMappingProperty[str | None, Layers | LayersView] = ( + AlignedMappingProperty("layers", Layers) ) """\ Dictionary-like object with values of the same dimensions as :attr:`X`. @@ -868,8 +723,8 @@ def uns(self, value: MutableMapping): def uns(self): self.uns = OrderedDict() - obsm: AlignedMappingProperty[AxisArrays | AxisArraysView] = AlignedMappingProperty( - "obsm", AxisArrays, 0 + obsm: AlignedMappingProperty[str, AxisArrays | AxisArraysView] = ( + AlignedMappingProperty("obsm", AxisArrays, 0) ) """\ Multi-dimensional annotation of observations @@ -880,8 +735,8 @@ def uns(self): Is sliced with `data` and `obs` but behaves otherwise like a :term:`mapping`. """ - varm: AlignedMappingProperty[AxisArrays | AxisArraysView] = AlignedMappingProperty( - "varm", AxisArrays, 1 + varm: AlignedMappingProperty[str, AxisArrays | AxisArraysView] = ( + AlignedMappingProperty("varm", AxisArrays, 1) ) """\ Multi-dimensional annotation of variables/features @@ -892,7 +747,7 @@ def uns(self): Is sliced with `data` and `var` but behaves otherwise like a :term:`mapping`. """ - obsp: AlignedMappingProperty[PairwiseArrays | PairwiseArraysView] = ( + obsp: AlignedMappingProperty[str, PairwiseArrays | PairwiseArraysView] = ( AlignedMappingProperty("obsp", PairwiseArrays, 0) ) """\ @@ -904,7 +759,7 @@ def uns(self): Is sliced with `data` and `obs` but behaves otherwise like a :term:`mapping`. """ - varp: AlignedMappingProperty[PairwiseArrays | PairwiseArraysView] = ( + varp: AlignedMappingProperty[str, PairwiseArrays | PairwiseArraysView] = ( AlignedMappingProperty("varp", PairwiseArrays, 1) ) """\ @@ -990,8 +845,8 @@ def filename(self, filename: PathLike | None): self.write(filename, as_dense=as_dense) # open new file for accessing self.file.open(filename, "r+") - # as the data is stored on disk, we can safely set self._X to None - self._X = None + # as the data is stored on disk, we can safely set self.X to None + del self.X def _set_backed(self, attr, value): from .._io.utils import write_attribute @@ -1006,7 +861,7 @@ def __delitem__(self, index: Index): obs, var = self._normalize_indices(index) # TODO: does this really work? if not self.isbacked: - del self._X[obs, var] + del self.X[obs, var] else: X = self.file["X"] del X[obs, var] @@ -1175,7 +1030,7 @@ def __setitem__( raise ValueError("Object is view and cannot be accessed with `[]`.") obs, var = self._normalize_indices(index) if not self.isbacked: - self._X[obs, var] = val + self.X[obs, var] = val else: X = self.file["X"] X[obs, var] = val @@ -1369,10 +1224,6 @@ def _mutated_copy(self, **kwargs): new[key] = kwargs[key] else: new[key] = getattr(self, key).copy() - if "X" in kwargs: - new["X"] = kwargs["X"] - elif self._has_X(): - new["X"] = self.X.copy() if "uns" in kwargs: new["uns"] = kwargs["uns"] else: diff --git a/src/anndata/_core/raw.py b/src/anndata/_core/raw.py index d138440b5..1804b3e0b 100644 --- a/src/anndata/_core/raw.py +++ b/src/anndata/_core/raw.py @@ -51,7 +51,9 @@ def __init__( self.varm = varm elif X is None: # construct from adata # Move from GPU to CPU since it's large and not always used - if isinstance(adata.X, CupyArray | CupySparseMatrix): + if adata.X is None: + self._X = None + elif isinstance(adata.X, CupyArray | CupySparseMatrix): self._X = adata.X.get() else: self._X = adata.X.copy() @@ -162,7 +164,7 @@ def to_adata(self) -> AnnData: from anndata import AnnData return AnnData( - X=self.X.copy(), + X=None if self.X is None else self.X.copy(), var=self.var.copy(), varm=None if self._varm is None else self._varm.copy(), obs=self._adata.obs.copy(), diff --git a/src/anndata/_io/h5ad.py b/src/anndata/_io/h5ad.py index edf4977cc..baa90d940 100644 --- a/src/anndata/_io/h5ad.py +++ b/src/anndata/_io/h5ad.py @@ -106,7 +106,12 @@ def write_h5ad( write_elem(f, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) write_elem(f, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) write_elem(f, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - write_elem(f, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) + write_elem( + f, + "layers", + {k: v for k, v in adata.layers.items() if k is not None}, + dataset_kwargs=dataset_kwargs, + ) write_elem(f, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) diff --git a/src/anndata/_io/specs/methods.py b/src/anndata/_io/specs/methods.py index 8e02b9283..cae2b3f94 100644 --- a/src/anndata/_io/specs/methods.py +++ b/src/anndata/_io/specs/methods.py @@ -280,7 +280,12 @@ def write_anndata( _writer.write_elem(g, "varm", dict(adata.varm), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "obsp", dict(adata.obsp), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "varp", dict(adata.varp), dataset_kwargs=dataset_kwargs) - _writer.write_elem(g, "layers", dict(adata.layers), dataset_kwargs=dataset_kwargs) + _writer.write_elem( + g, + "layers", + {k: v for k, v in adata.layers.items() if k is not None}, + dataset_kwargs=dataset_kwargs, + ) _writer.write_elem(g, "uns", dict(adata.uns), dataset_kwargs=dataset_kwargs) _writer.write_elem(g, "raw", adata.raw, dataset_kwargs=dataset_kwargs) diff --git a/src/anndata/tests/helpers.py b/src/anndata/tests/helpers.py index 6ed637ed8..f79b8b289 100644 --- a/src/anndata/tests/helpers.py +++ b/src/anndata/tests/helpers.py @@ -524,7 +524,7 @@ def subset_func(request): ################### -def format_msg(elem_name): +def format_msg(elem_name: str | None) -> str: if elem_name is not None: return f"Error raised from element {elem_name!r}." else: @@ -562,6 +562,11 @@ def _assert_equal(a, b): @singledispatch def assert_equal(a, b, exact=False, elem_name=None): + a_handler, b_handler, default_handler = map( + assert_equal.dispatch, (type(a), type(b), object) + ) + if (a_handler is default_handler) and (b_handler is not default_handler): + return assert_equal(b, a, exact=exact, elem_name=elem_name) _assert_equal(a, b, _elem_name=elem_name) @@ -654,7 +659,9 @@ def assert_equal_awkarray(a, b, exact=False, elem_name=None): @assert_equal.register(Mapping) def assert_equal_mapping(a, b, exact=False, elem_name=None): - assert set(a.keys()) == set(b.keys()), format_msg(elem_name) + assert set(a.keys()) == set(b.keys()), ( + format_msg(elem_name) + f" {a.keys()} != {b.keys()}" + ) for k in a.keys(): if elem_name is None: elem_name = "" diff --git a/tests/test_base.py b/tests/test_base.py index e1401ed74..28160b209 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -207,13 +207,15 @@ def test_convert_matrix(attr, when): assert not isinstance(arr, np.matrix), f"{arr} is still a matrix" -def test_attr_deletion(): +@pytest.mark.parametrize( + "attr", ["X", "obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"] +) +def test_attr_deletion(attr: str): full = gen_adata((30, 30)) # Empty has just X, obs_names, var_names empty = AnnData(None, obs=full.obs[[]], var=full.var[[]]) - for attr in ["X", "obs", "var", "obsm", "varm", "obsp", "varp", "layers", "uns"]: - delattr(full, attr) - assert_equal(getattr(full, attr), getattr(empty, attr)) + delattr(full, attr) + assert_equal(getattr(full, attr), getattr(empty, attr)) assert_equal(full, empty, exact=True) diff --git a/tests/test_layers.py b/tests/test_layers.py index ba1f96e49..4b95419a0 100644 --- a/tests/test_layers.py +++ b/tests/test_layers.py @@ -19,7 +19,7 @@ def test_creation(): adata = AnnData(X=X, layers=dict(L=L.copy())) - assert list(adata.layers.keys()) == ["L"] + assert adata.layers.keys() == {"L", None} assert "L" in adata.layers assert "X" not in adata.layers assert "some_other_thing" not in adata.layers