From 465ef1801ee7ea190a394348177c8dabae0ead42 Mon Sep 17 00:00:00 2001
From: colganwi
Date: Thu, 24 Oct 2024 18:30:15 -0400
Subject: [PATCH] clades dtype argument

Add a `dtype` argument to `clades` so that generated clade names can be
produced as `str` (the default, and the behaviour before this change),
`int` or `float`. Both the type objects and their string names are
accepted.

`_clade_name_generator` validates `dtype` when it is called rather than
on the first `next()`, so an unsupported or unhashable value raises
`ValueError` even if no name is ever drawn. Its default is `str`, matching
`clades` and the names generated before this change.

Also ignore `.ipynb_checkpoints/`.
---
Review note: this patch was amended by hand. The post-image blob ids on
the `index` lines are those of the original commit and were not
recomputed; the pre-image side of every hunk is unchanged.

 .gitignore             |  1 +
 src/pycea/tl/clades.py | 31 ++++++++++++++++++++++++-------
 tests/test_clades.py   | 12 ++++++++++++
 3 files changed, 37 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 7b6c63b..846a601 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 .DS_Store
 *~
 buck-out/
+.ipynb_checkpoints/
 
 # Compiled files
 .venv/
diff --git a/src/pycea/tl/clades.py b/src/pycea/tl/clades.py
index b1b4b68..36cb10b 100755
--- a/src/pycea/tl/clades.py
+++ b/src/pycea/tl/clades.py
@@ -21,11 +21,25 @@ def _nodes_at_depth(tree, parent, nodes, depth, depth_key):
     return nodes
 
 
-def _clade_name_generator():
-    """Generates clade names."""
-    i = 0
-    while True:
-        yield str(i)
-        i += 1
+def _clade_name_generator(dtype=str):
+    """Generates clade names.
+
+    Returns an endless iterator over 0, 1, 2, ... converted to `dtype`.
+    `dtype` is validated when this function is called, not on the first
+    `next()`, so an invalid value raises `ValueError` immediately.
+    """
+    valid_dtypes = {"str": str, "int": int, "float": float, str: str, int: int, float: float}
+    try:
+        converter = valid_dtypes[dtype]
+    except (KeyError, TypeError):  # TypeError: unhashable dtype
+        raise ValueError("dtype must be one of str, int, or float") from None
+
+    def _names():
+        i = 0
+        while True:
+            yield converter(i)
+            i += 1
+
+    return _names()
 
 
@@ -65,6 +79,7 @@ def clades(
     clades: str | Sequence[str] = None,
     key_added: str = "clade",
     update: bool = False,
+    dtype: type | str = str,
     tree: str | Sequence[str] | None = None,
     copy: bool = False,
 ) -> None | Mapping:
@@ -84,6 +99,8 @@ def clades(
         Key to store clades in.
     update
         If True, updates existing clades instead of overwriting.
+    dtype
+        Data type of clade names. One of `str`, `int`, or `float`.
     tree
         The `obst` key or keys of the trees to use. If `None`, all trees are used.
     copy
@@ -107,7 +124,7 @@ def clades(
     if clades and len(trees) > 1:
         raise ValueError("Multiple trees are present. Must specify a single tree if clades are given.")
     # Identify clades
-    name_generator = _clade_name_generator()
+    name_generator = _clade_name_generator(dtype=dtype)
     lcas = []
     for key, tree in trees.items():
         tree_lcas = _clades(tree, depth, depth_key, clades, key_added, name_generator, update)
diff --git a/tests/test_clades.py b/tests/test_clades.py
index 5cc14f7..c2f434b 100755
--- a/tests/test_clades.py
+++ b/tests/test_clades.py
@@ -77,6 +77,18 @@ def test_clades_multiple_trees():
     assert pd.isna(tdata.obs.loc["B", "test"])
 
 
+def test_clades_dtype(tdata):
+    clades(tdata, depth=0, dtype=int)
+    assert tdata.obs["clade"].dtype == int
+    assert tdata.obst["tree"].nodes["A"]["clade"] == 0
+    clades(tdata, depth=0, dtype="int")
+    assert tdata.obs["clade"].dtype == int
+    assert tdata.obst["tree"].nodes["A"]["clade"] == 0
+    clades(tdata, depth=1, dtype=float)
+    assert tdata.obs["clade"].dtype == float
+    assert tdata.obst["tree"].nodes["C"]["clade"] == 1.0
+
+
 def test_clades_invalid(tdata):
     with pytest.raises(ValueError):
         clades(td.TreeData(), clades={"A": 0}, depth=0)