Skip to content

Commit

Permalink
ancestral states
Browse files Browse the repository at this point in the history
  • Loading branch information
colganwi committed May 24, 2024
1 parent e89beb6 commit bae807b
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
.. autosummary::
:toctree: generated
tl.ancestral_states
tl.clades
tl.sort
```
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies = [
"numpy",
"pandas",
"session-info",
"scipy",
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions src/pycea/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .clades import clades
from .sort import sort
from .ancestral_states import ancestral_states
90 changes: 90 additions & 0 deletions src/pycea/tl/ancestral_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from __future__ import annotations

from collections.abc import Sequence

import networkx as nx
import numpy as np
import pandas as pd
import treedata as td

from pycea.utils import get_keyed_node_data, get_keyed_obs_data, get_trees


def _most_common(arr):
    """Return the value that occurs most often in ``arr``.

    When several values share the highest count, the smallest one wins:
    ``np.unique`` hands back its values sorted and ``argmax`` selects the
    first maximum.
    """
    values, frequencies = np.unique(arr, return_counts=True)
    return values[frequencies.argmax()]


def _ancestral_states(tree, key, method="mean"):
    """Infer the value of ``key`` at every internal node of ``tree``.

    Nodes are visited in depth-first post-order, so every child is handled
    before its parent. A leaf (out-degree 0) contributes its own ``key``
    value; an internal node pools the values of all leaves below it and
    stores their summary under ``key``. The tree is modified in place.

    Parameters
    ----------
    tree
        Directed graph whose leaves already carry ``key`` as a node attribute.
    key
        Node attribute to read from leaves and write to internal nodes.
    method
        Summary applied to the pooled leaf values: "mean", "median" or "mode".

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    KeyError
        If a leaf has no ``key`` attribute.
    """
    summarizers = {"mean": np.mean, "median": np.median, "mode": _most_common}
    try:
        summarize = summarizers[method]
    except (KeyError, TypeError):
        raise ValueError(f"Method {method} not recognized.") from None
    # Pooled leaf values are kept in a local dict instead of a temporary node
    # attribute, so the caller's graph never ends up holding scratch data,
    # even when a leaf is missing ``key`` and the traversal aborts.
    pooled_by_node = {}
    for node in nx.dfs_postorder_nodes(tree):
        if tree.out_degree(node) == 0:
            pooled_by_node[node] = np.array([tree.nodes[node][key]])
        else:
            # Post-order guarantees every child already has an entry.
            pooled = np.concatenate([pooled_by_node[child] for child in tree.successors(node)])
            pooled_by_node[node] = pooled
            tree.nodes[node][key] = summarize(pooled)


def ancestral_states(
    tdata: td.TreeData,
    keys: str | Sequence[str],
    method: str = "mean",
    tree: str | Sequence[str] | None = None,
    copy: bool = False,
) -> pd.DataFrame | None:
    """Reconstructs ancestral states for an attribute.

    Leaf values are copied from the observation data onto the matching tree
    nodes, then every internal node receives a summary of the leaves below
    it. The trees stored in ``tdata`` are updated in place.

    Parameters
    ----------
    tdata
        TreeData object.
    keys
        One or more `obs_keys`, `var_names`, `obsm_keys`, or `obsp_keys` to reconstruct.
    method
        Method to reconstruct ancestral states. One of "mean", "median", or "mode".
    tree
        The `obst` key or keys of the trees to use. If `None`, all trees are used.
    copy
        If True, returns a pd.DataFrame with ancestral states.

    Returns
    -------
    `None` when ``copy`` is False. Otherwise a pd.DataFrame with one row per
    node of each selected tree, holding the per-key node data from
    `get_keyed_node_data` plus a "tree" column (the tree's key) and a "node"
    column (the node name).

    Raises
    ------
    ValueError
        If ``method`` is not recognized.
    """
    if isinstance(keys, str):
        keys = [keys]
    trees = get_trees(tdata, tree)
    # The observation data is the same for every tree, so look it up and
    # convert it to plain dicts once instead of once per tree.
    data, _ = get_keyed_obs_data(tdata, keys)
    leaf_values = {key: data[key].to_dict() for key in keys}
    for graph in trees.values():
        for key in keys:
            nx.set_node_attributes(graph, leaf_values[key], key)
            _ancestral_states(graph, key, method)
    if not copy:
        return None
    per_tree = []
    for name, graph in trees.items():
        tree_states = pd.concat([get_keyed_node_data(graph, key) for key in keys], axis=1)
        tree_states["tree"] = name
        per_tree.append(tree_states)
    states = pd.concat(per_tree)
    # Expose the node names as a regular column and fall back to a plain
    # integer index, since node names repeat across trees.
    states["node"] = states.index
    return states.reset_index(drop=True)
2 changes: 1 addition & 1 deletion src/pycea/tl/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _sort_tree(tree, key, reverse=False):


def sort(tdata: td.TreeData, key: str, reverse: bool = False, tree: str | Sequence[str] | None = None) -> None:
"""Sorts the children of each internal node in a tree based on a given key.
"""Reorders branches based on a given key.
Parameters
----------
Expand Down
31 changes: 31 additions & 0 deletions tests/test_ancestral_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import networkx as nx
import pandas as pd
import pytest
import treedata as td

from pycea.tl.ancestral_states import ancestral_states


@pytest.fixture
def tdata():
    """Provide a TreeData object with two small trees and leaf-level values."""
    # tree1: root -> {B, C} and C -> {D, E}; tree2: root -> F.
    tree1_edges = [("root", "B"), ("root", "C"), ("C", "D"), ("C", "E")]
    tree2_edges = [("root", "F")]
    # One numeric and one string column, indexed by the leaves of both trees.
    leaf_obs = pd.DataFrame(
        {"value": [0, 0, 3, 2], "str_value": ["0", "0", "3", "2"]},
        index=["B", "D", "E", "F"],
    )
    tree_data = td.TreeData(
        obs=leaf_obs,
        obst={"tree1": nx.DiGraph(tree1_edges), "tree2": nx.DiGraph(tree2_edges)},
    )
    yield tree_data


def test_ancestral_states(tdata):
    """Check mean, median and mode reconstruction on the fixture trees."""

    def tree1_attr(node, key):
        # Fetch the tree from tdata on every call so each assertion reads the
        # graph as it is stored after the preceding reconstruction.
        return tdata.obst["tree1"].nodes[node][key]

    # Mean: internal nodes average the leaves below them.
    mean_states = ancestral_states(tdata, "value", method="mean", copy=True)
    assert tree1_attr("root", "value") == 1
    assert tree1_attr("C", "value") == 1.5
    assert mean_states["value"].tolist() == [1, 0, 1.5, 0, 3, 2, 2]

    # Median: re-running overwrites the internal values written above.
    ancestral_states(tdata, "value", method="median", copy=True)
    assert tree1_attr("root", "value") == 0

    # Mode: works on string values and on a single selected tree.
    ancestral_states(tdata, "str_value", method="mode", copy=False, tree="tree1")
    assert tree1_attr("root", "str_value") == "0"

0 comments on commit bae807b

Please sign in to comment.