From 8e9eb882ddbef3fb2043a93d6d0553813dd2bc2b Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 1 Oct 2024 16:41:31 +0200 Subject: [PATCH] Simplify roundtrip io tests (#1702) --- tests/test_readwrite.py | 91 ++++++++++++++++++++--------------------- 1 file changed, 45 insertions(+), 46 deletions(-) diff --git a/tests/test_readwrite.py b/tests/test_readwrite.py index c8ed19d3f..04d20d272 100644 --- a/tests/test_readwrite.py +++ b/tests/test_readwrite.py @@ -3,6 +3,7 @@ import re import warnings from contextlib import contextmanager +from functools import partial from importlib.util import find_spec from pathlib import Path from string import ascii_letters @@ -25,6 +26,7 @@ if TYPE_CHECKING: from os import PathLike + from typing import Literal HERE = Path(__file__).parent @@ -658,30 +660,13 @@ def random_cats(n): assert_equal(orig, curr) -def test_write_string_types(tmp_path, diskfmt): - # https://github.com/scverse/anndata/issues/456 - adata_pth = tmp_path / f"adata.{diskfmt}" - - adata = ad.AnnData( - obs=pd.DataFrame( - np.ones((3, 2)), - columns=["a", np.str_("b")], - index=["a", "b", "c"], - ), - ) - - write = getattr(adata, f"write_{diskfmt}") - read = getattr(ad, f"read_{diskfmt}") - - write(adata_pth) - from_disk = read(adata_pth) - - assert_equal(adata, from_disk) - +def test_write_string_type_error(tmp_path, diskfmt): + adata = ad.AnnData(obs=dict(obs_names=list("abc"))) adata.obs[b"c"] = np.zeros(3) + # This should error, and tell you which key is at fault with pytest.raises(TypeError, match=r"writing key 'obs'") as exc_info: - write(adata_pth) + getattr(adata, f"write_{diskfmt}")(tmp_path / f"adata.{diskfmt}") assert "b'c'" in str(exc_info.value) @@ -722,15 +707,39 @@ def test_zarr_chunk_X(tmp_path): # Round-tripping scanpy datasets ################################ -diskfmt2 = diskfmt + +def _do_roundtrip( + adata: ad.AnnData, pth: Path, diskfmt: Literal["h5ad", "zarr"] +) -> ad.AnnData: + getattr(adata, f"write_{diskfmt}")(pth) + return getattr(ad, f"read_{diskfmt}")(pth) + + +@pytest.fixture +def roundtrip(diskfmt): + return partial(_do_roundtrip, diskfmt=diskfmt) + + +def test_write_string_types(tmp_path, diskfmt, roundtrip): + # https://github.com/scverse/anndata/issues/456 + adata_pth = tmp_path / f"adata.{diskfmt}" + + adata = ad.AnnData( + obs=pd.DataFrame( + np.ones((3, 2)), + columns=["a", np.str_("b")], + index=["a", "b", "c"], + ), + ) + + from_disk = roundtrip(adata, adata_pth) + + assert_equal(adata, from_disk) @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed") -def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2): - read1 = lambda pth: getattr(ad, f"read_{diskfmt}")(pth) - write1 = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth) - read2 = lambda pth: getattr(ad, f"read_{diskfmt2}")(pth) - write2 = lambda adata, pth: getattr(adata, f"write_{diskfmt2}")(pth) +def test_scanpy_pbmc68k(tmp_path, diskfmt, roundtrip, diskfmt2): + roundtrip2 = partial(_do_roundtrip, diskfmt=diskfmt2) filepth1 = tmp_path / f"test1.{diskfmt}" filepth2 = tmp_path / f"test2.{diskfmt2}" @@ -745,17 +754,15 @@ def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2): warnings.simplefilter("ignore", ad.OldFormatWarning) pbmc = sc.datasets.pbmc68k_reduced() - write1(pbmc, filepth1) - from_disk1 = read1(filepth1) # Do we read okay - write2(from_disk1, filepth2) # Can we round trip - from_disk2 = read2(filepth2) + from_disk1 = roundtrip(pbmc, filepth1) # Do we read okay + from_disk2 = roundtrip2(from_disk1, filepth2) # Can we round trip assert_equal(pbmc, from_disk1) # Not expected to be exact due to `nan`s assert_equal(pbmc, from_disk2) @pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed") -def test_scanpy_krumsiek11(tmp_path, diskfmt): +def test_scanpy_krumsiek11(tmp_path, diskfmt, roundtrip): filepth = tmp_path / f"test.{diskfmt}" with warnings.catch_warnings(): warnings.filterwarnings( @@ -769,11 +776,10 @@ def test_scanpy_krumsiek11(tmp_path, diskfmt): del orig.uns["highlights"] # Can’t write int keys # Can’t write "string" dtype: https://github.com/scverse/anndata/issues/679 orig.obs["cell_type"] = orig.obs["cell_type"].astype(str) - getattr(orig, f"write_{diskfmt}")(filepth) with pytest.warns(UserWarning, match=r"Observation names are not unique"): - read = getattr(ad, f"read_{diskfmt}")(filepth) + curr = roundtrip(orig, filepth) - assert_equal(orig, read, exact=True) + assert_equal(orig, curr, exact=True) # Checking if we can read legacy zarr files @@ -808,11 +814,8 @@ def test_backwards_compat_zarr(): assert_equal(pbmc_zarr, pbmc_orig) -# TODO: use diskfmt fixture once zarr backend implemented -def test_adata_in_uns(tmp_path, diskfmt): +def test_adata_in_uns(tmp_path, diskfmt, roundtrip): pth = tmp_path / f"adatas_in_uns.{diskfmt}" - read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth) - write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth) orig = gen_adata((4, 5)) orig.uns["adatas"] = { @@ -823,20 +826,16 @@ def test_adata_in_uns(tmp_path, diskfmt): another_one.raw = gen_adata((2, 7)) orig.uns["adatas"]["b"].uns["another_one"] = another_one - write(orig, pth) - curr = read(pth) + curr = roundtrip(orig, pth) assert_equal(orig, curr) -def test_io_dtype(tmp_path, diskfmt, dtype): +def test_io_dtype(tmp_path, diskfmt, dtype, roundtrip): pth = tmp_path / f"adata_dtype.{diskfmt}" - read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth) - write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth) orig = ad.AnnData(np.ones((5, 8), dtype=dtype)) - write(orig, pth) - curr = read(pth) + curr = roundtrip(orig, pth) assert curr.X.dtype == dtype