From ca6e7ffc20935eb687921cb0de8422182b77b335 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 27 May 2024 18:18:52 +0200 Subject: [PATCH 01/67] Add function for calculating niches --- src/squidpy/gr/_niche.py | 130 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 src/squidpy/gr/_niche.py diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py new file mode 100644 index 00000000..d86e1c9a --- /dev/null +++ b/src/squidpy/gr/_niche.py @@ -0,0 +1,130 @@ +from spatialdata import SpatialData +from anndata import AnnData +import pandas as pd +import numpy as np +import scanpy as sc +from typing import Union, List +from sklearn.neighbors import KDTree +from scipy.spatial import cKDTree +from squidpy._docs import d + +__all__ = ["niche"] + +@d.dedent +def calculate_niche( + adata: Union[AnnData, SpatialData], + groups: str, + flavor: str, + radius: Union[float, None], + n_neighbors: Union[int, None], + limit_to: Union[str, List, None] = None, + table_key: Union[str, None] = None, + spatial_key: str = "spatial", +)-> Union[AnnData, SpatialData]: + + # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present + is_sdata = False + if isinstance(adata, SpatialData): + is_sdata = True + if table_key is not None: + table = adata.tables[table_key] + else: + if len(adata.tables) > 1: + count = 0 + for key in adata.tables.keys(): + if groups in table.obs: + count += 1 + table_key = key + if count > 1: + raise ValueError( + f"Multiple tables in `spatialdata` with group `{groups}` detected. Please specify which table to use in `table_key`." + ) + elif count == 0: + raise ValueError( + f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`." + ) + else: + table = adata.tables[table_key] + else: + (key, table), = adata.tables.items() + if groups not in table.obs: + raise ValueError( + f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." + ) + else: + table = adata + + # check whether to use radius or knn for neighborhood profile calculation + if radius is None and n_neighbors is None: + raise ValueError("Either `radius` or `n_neighbors` must be provided, but both are `None`.") + if radius is not None and n_neighbors is not None: + raise ValueError("Either `radius` and `n_neighbors` must be provided, but both were provided.") + + # subset adata if only observations within specified groups are to be considered + if limit_to is not None: + if isinstance(limit_to, str): + limit_to = [limit_to] + table_subset = table[table.obs[groups].isin([limit_to])] + else: + table_subset = table + + if flavor == "neighborhood": + rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile(table, radius, n_neighbors, table_subset, spatial_key) + nhood_table = _df_to_adata(rel_nhood_profile, table_subset.obs.index) + sc.pp.neighbors(nhood_table, use_rep="X") + sc.tl.leiden(nhood_table) + table.obs["niche"] = nhood_table.obs["leiden"] + if is_sdata: + adata.tables[f"{flavor}_niche"] = nhood_table + else: + rel_nhood_profile = rel_nhood_profile.reindex(table.obs.index) + table.obsm[f"{flavor}_niche"] = rel_nhood_profile + + return + +def _calculate_neighborhood_profile( + adata: Union[AnnData, SpatialData], + groups: str, + radius: float, + n_neighbors: Union[int, None], + subset: AnnData, + spatial_key: str, +)-> tuple[pd.DataFrame, pd.DataFrame]: + + # reset index + adata.obs = adata.obs.reset_index() + + if n_neighbors is not None: + # get k-nearest neighbors for each observation + tree = KDTree(adata[spatial_key]) + _, indices = tree.query(subset[spatial_key], k=n_neighbors) + else: + # get neighbors within a given radius for each observation + tree = cKDTree(adata[spatial_key]) + indices = tree.query_ball_point(subset[spatial_key], r=radius) + + # get unique categories + category_arr = adata.obs[groups].values + unique_categories = np.unique(category_arr) + + # get obs x k matrix where each column is the category of the k-th neighbor + cat_by_id = np.take(category_arr, indices) + + # in obs x k matrix convert categorical values to numerical values + cat_indices = {category: index for index, category in enumerate(unique_categories)} + cat_values = np.vectorize(cat_indices.get)(cat_by_id) + + # For each obs calculate absolute frequency for all (not just k) categories, given the subset of categories present in obs x k matrix + m, k = cat_by_id.shape + abs_freq = np.zeros((m, len(unique_categories)), dtype=int) + np.add.at(abs_freq, (np.arange(m)[:, None], cat_values), 1) + + # normalize by n_neighbors to get relative frequency of each category + rel_freq = abs_freq / k + + return rel_freq, abs_freq + +def _df_to_adata(df: pd.DataFrame, index: pd.core.indexes.base.Index) -> AnnData: + adata = AnnData(X=df.values) + adata.obs.index = index + return adata From e08189a21d06783e27ee4cbf1a74bc93688f31ed Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 May 2024 16:25:29 +0000 Subject: [PATCH 02/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 51 +++++++++++++++++++++++----------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index d86e1c9a..0c2085da 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,15 +1,20 @@ -from spatialdata import SpatialData -from anndata import AnnData -import pandas as pd +from __future__ import annotations + +from typing import List, Union + import numpy as np +import pandas as pd import scanpy as sc -from typing import Union, List -from sklearn.neighbors import KDTree +from anndata import AnnData from scipy.spatial import cKDTree +from sklearn.neighbors import KDTree +from spatialdata import SpatialData + from squidpy._docs import d __all__ = ["niche"] + @d.dedent def calculate_niche( adata: Union[AnnData, SpatialData], @@ -20,8 +25,7 @@ def calculate_niche( limit_to: Union[str, List, None] = None, table_key: Union[str, None] = None, spatial_key: str = "spatial", -)-> Union[AnnData, SpatialData]: - +) -> Union[AnnData, SpatialData]: # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -42,11 +46,11 @@ def calculate_niche( elif count == 0: raise ValueError( f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`." - ) - else: + ) + else: table = adata.tables[table_key] else: - (key, table), = adata.tables.items() + ((key, table),) = adata.tables.items() if groups not in table.obs: raise ValueError( f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." @@ -67,9 +71,11 @@ def calculate_niche( table_subset = table[table.obs[groups].isin([limit_to])] else: table_subset = table - + if flavor == "neighborhood": - rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile(table, radius, n_neighbors, table_subset, spatial_key) + rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( + table, radius, n_neighbors, table_subset, spatial_key + ) nhood_table = _df_to_adata(rel_nhood_profile, table_subset.obs.index) sc.pp.neighbors(nhood_table, use_rep="X") sc.tl.leiden(nhood_table) @@ -79,18 +85,18 @@ def calculate_niche( else: rel_nhood_profile = rel_nhood_profile.reindex(table.obs.index) table.obsm[f"{flavor}_niche"] = rel_nhood_profile - - return + + return + def _calculate_neighborhood_profile( - adata: Union[AnnData, SpatialData], - groups: str, - radius: float, - n_neighbors: Union[int, None], - subset: AnnData, - spatial_key: str, -)-> tuple[pd.DataFrame, pd.DataFrame]: - + adata: Union[AnnData, SpatialData], + groups: str, + radius: float, + n_neighbors: Union[int, None], + subset: AnnData, + spatial_key: str, +) -> tuple[pd.DataFrame, pd.DataFrame]: # reset index adata.obs = adata.obs.reset_index() @@ -124,6 +130,7 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq + def _df_to_adata(df: pd.DataFrame, index: pd.core.indexes.base.Index) -> AnnData: adata = AnnData(X=df.values) adata.obs.index = index From 1b492eada8e47ced9e0ac39325bd075c807d5d7c Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 27 May 2024 18:37:26 +0200 Subject: [PATCH 03/67] Fix pre-commit --- src/squidpy/gr/_niche.py | 61 ++++++++++++++++++++++------------------ 1 file changed, 33 insertions(+), 28 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index d86e1c9a..ecc11230 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,27 +1,31 @@ -from spatialdata import SpatialData -from anndata import AnnData -import pandas as pd +from __future__ import annotations + +from typing import Any + import numpy as np +import pandas as pd import scanpy as sc -from typing import Union, List -from sklearn.neighbors import KDTree +from anndata import AnnData from scipy.spatial import cKDTree +from sklearn.neighbors import KDTree +from spatialdata import SpatialData + from squidpy._docs import d -__all__ = ["niche"] +__all__ = ["calculate_niche"] + @d.dedent def calculate_niche( - adata: Union[AnnData, SpatialData], + adata: AnnData | SpatialData, groups: str, flavor: str, - radius: Union[float, None], - n_neighbors: Union[int, None], - limit_to: Union[str, List, None] = None, - table_key: Union[str, None] = None, + radius: float | None, + n_neighbors: int | None, + limit_to: str | list[Any] | None = None, + table_key: str | None = None, spatial_key: str = "spatial", -)-> Union[AnnData, SpatialData]: - +) -> AnnData | SpatialData: # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -42,11 +46,11 @@ def calculate_niche( elif count == 0: raise ValueError( f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`." - ) - else: + ) + else: table = adata.tables[table_key] else: - (key, table), = adata.tables.items() + ((key, table),) = adata.tables.items() if groups not in table.obs: raise ValueError( f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." @@ -67,9 +71,11 @@ def calculate_niche( table_subset = table[table.obs[groups].isin([limit_to])] else: table_subset = table - + if flavor == "neighborhood": - rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile(table, radius, n_neighbors, table_subset, spatial_key) + rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( + table, groups, radius, n_neighbors, table_subset, spatial_key + ) nhood_table = _df_to_adata(rel_nhood_profile, table_subset.obs.index) sc.pp.neighbors(nhood_table, use_rep="X") sc.tl.leiden(nhood_table) @@ -79,18 +85,16 @@ def calculate_niche( else: rel_nhood_profile = rel_nhood_profile.reindex(table.obs.index) table.obsm[f"{flavor}_niche"] = rel_nhood_profile - - return + def _calculate_neighborhood_profile( - adata: Union[AnnData, SpatialData], - groups: str, - radius: float, - n_neighbors: Union[int, None], - subset: AnnData, - spatial_key: str, -)-> tuple[pd.DataFrame, pd.DataFrame]: - + adata: AnnData | SpatialData, + groups: str, + radius: float | None, + n_neighbors: int | None, + subset: AnnData, + spatial_key: str, +) -> tuple[pd.DataFrame, pd.DataFrame]: # reset index adata.obs = adata.obs.reset_index() @@ -124,6 +128,7 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq + def _df_to_adata(df: pd.DataFrame, index: pd.core.indexes.base.Index) -> AnnData: adata = AnnData(X=df.values) adata.obs.index = index From 139819a3a9a1bb6c42c11a29e3e3f7e2f003bad0 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 27 May 2024 18:43:00 +0200 Subject: [PATCH 04/67] Update __init__.py --- src/squidpy/gr/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/squidpy/gr/__init__.py b/src/squidpy/gr/__init__.py index 9045b6a5..2ff0c890 100644 --- a/src/squidpy/gr/__init__.py +++ b/src/squidpy/gr/__init__.py @@ -5,6 +5,7 @@ from squidpy.gr._build import spatial_neighbors from squidpy.gr._ligrec import ligrec from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment +from squidpy.gr._niche import calculate_niche from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr from squidpy.gr._ripley import ripley from squidpy.gr._sepal import sepal From a5b810edfe1d3131a0d7f52ba08f85a5658f32ae Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 29 May 2024 13:45:47 +0200 Subject: [PATCH 05/67] Add function --- src/squidpy/gr/_niche.py | 63 ++++++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index ecc11230..797d258a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any +from typing import Any, Optional import numpy as np import pandas as pd @@ -9,23 +9,23 @@ from scipy.spatial import cKDTree from sklearn.neighbors import KDTree from spatialdata import SpatialData - -from squidpy._docs import d +from utag import utag __all__ = ["calculate_niche"] -@d.dedent def calculate_niche( adata: AnnData | SpatialData, groups: str, - flavor: str, - radius: float | None, - n_neighbors: int | None, + flavor: str = "neighborhood", + library_key: str | None = None, + radius: float | None = None, + n_neighbors: int | None = None, limit_to: str | list[Any] | None = None, table_key: str | None = None, spatial_key: str = "spatial", -) -> AnnData | SpatialData: + copy: bool = False, +)-> AnnData | pd.DataFrame: # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -76,15 +76,39 @@ def calculate_niche( rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( table, groups, radius, n_neighbors, table_subset, spatial_key ) - nhood_table = _df_to_adata(rel_nhood_profile, table_subset.obs.index) - sc.pp.neighbors(nhood_table, use_rep="X") + df = pd.DataFrame(rel_nhood_profile, index=table_subset.obs.index) + nhood_table = _df_to_adata(df) + sc.pp.neighbors(nhood_table, n_neighbors=n_neighbors, use_rep="X") sc.tl.leiden(nhood_table) table.obs["niche"] = nhood_table.obs["leiden"] if is_sdata: + if copy: + return nhood_table adata.tables[f"{flavor}_niche"] = nhood_table else: - rel_nhood_profile = rel_nhood_profile.reindex(table.obs.index) - table.obsm[f"{flavor}_niche"] = rel_nhood_profile + if copy: + return df + df = df.reindex(table.obs.index) + table.obsm[f"{flavor}_niche"] = df + + elif flavor == "utag": + result = utag( + table_subset, + slide_key=library_key, + max_dist=10, + normalization_mode="l1_norm", + apply_clustering=True, + clustering_method="leiden", + resolutions=1.0) + if is_sdata: + if copy: + return result + adata.tables[f"{flavor}_niche"] = result + else: + if copy: + return result + df = df.reindex(table.obs.index) + table.obsm[f"{flavor}_niche"] = df def _calculate_neighborhood_profile( @@ -100,12 +124,12 @@ def _calculate_neighborhood_profile( if n_neighbors is not None: # get k-nearest neighbors for each observation - tree = KDTree(adata[spatial_key]) - _, indices = tree.query(subset[spatial_key], k=n_neighbors) + tree = KDTree(adata.obsm[spatial_key]) + _, indices = tree.query(subset.obsm[spatial_key], k=n_neighbors) else: # get neighbors within a given radius for each observation - tree = cKDTree(adata[spatial_key]) - indices = tree.query_ball_point(subset[spatial_key], r=radius) + tree = cKDTree(adata.obsm[spatial_key]) + indices = tree.query_ball_point(subset.obsm[spatial_key], r=radius) # get unique categories category_arr = adata.obs[groups].values @@ -129,7 +153,8 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq -def _df_to_adata(df: pd.DataFrame, index: pd.core.indexes.base.Index) -> AnnData: - adata = AnnData(X=df.values) - adata.obs.index = index +def _df_to_adata(df: pd.DataFrame) -> AnnData: + df.index = df.index.map(str) + adata = AnnData(X=df) + adata.obs.index = df.index return adata From 50a8474f7fb68fa15e516f64b74e92cd9a7c7465 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 May 2024 11:46:17 +0000 Subject: [PATCH 06/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 797d258a..67a42363 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -25,7 +25,7 @@ def calculate_niche( table_key: str | None = None, spatial_key: str = "spatial", copy: bool = False, -)-> AnnData | pd.DataFrame: +) -> AnnData | pd.DataFrame: # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -93,13 +93,14 @@ def calculate_niche( elif flavor == "utag": result = utag( - table_subset, - slide_key=library_key, - max_dist=10, - normalization_mode="l1_norm", - apply_clustering=True, + table_subset, + slide_key=library_key, + max_dist=10, + normalization_mode="l1_norm", + apply_clustering=True, clustering_method="leiden", - resolutions=1.0) + resolutions=1.0, + ) if is_sdata: if copy: return result From 38c67fbaa0ee773e5a5fc1ecb99ce1dcae9259e1 Mon Sep 17 00:00:00 2001 From: LLehner Date: Sat, 8 Jun 2024 23:22:19 +0200 Subject: [PATCH 07/67] Update --- src/squidpy/gr/_niche.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 797d258a..762d188b 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -25,7 +25,7 @@ def calculate_niche( table_key: str | None = None, spatial_key: str = "spatial", copy: bool = False, -)-> AnnData | pd.DataFrame: +) -> AnnData | pd.DataFrame: # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -89,17 +89,19 @@ def calculate_niche( if copy: return df df = df.reindex(table.obs.index) + print(df.head()) table.obsm[f"{flavor}_niche"] = df elif flavor == "utag": result = utag( - table_subset, - slide_key=library_key, - max_dist=10, - normalization_mode="l1_norm", - apply_clustering=True, + table_subset, + slide_key=library_key, + max_dist=10, + normalization_mode="l1_norm", + apply_clustering=True, clustering_method="leiden", - resolutions=1.0) + resolutions=1.0, + ) if is_sdata: if copy: return result From 86d5efdfe137756e07e198192166723db08e8701 Mon Sep 17 00:00:00 2001 From: Blampey Quentin Date: Tue, 11 Jun 2024 09:53:12 +0200 Subject: [PATCH 08/67] adding fide score and jsd metrics --- src/squidpy/gr/_niche.py | 103 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 762d188b..1f15c84d 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Iterator from typing import Any, Optional import numpy as np @@ -7,10 +8,15 @@ import scanpy as sc from anndata import AnnData from scipy.spatial import cKDTree +from sklearn import metrics +from sklearn.decomposition import PCA from sklearn.neighbors import KDTree +from sklearn.preprocessing import StandardScaler from spatialdata import SpatialData from utag import utag +from squidpy._utils import NDArrayA + __all__ = ["calculate_niche"] @@ -160,3 +166,100 @@ def _df_to_adata(df: pd.DataFrame) -> AnnData: adata = AnnData(X=df) adata.obs.index = df.index return adata + + +def mean_fide_score( + adatas: AnnData | list[AnnData], + obs_key: str, + slide_key: str | None = None, + n_classes: int | None = None, +) -> float: + """Mean FIDE score over all slides. A low score indicates a great domain continuity.""" + return float(np.mean([fide_score(adata, obs_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)])) + + +def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> float: + """ + F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity. + + The F1-score is computed for every class, then all F1-scores are averaged. If some classes + are not predicted, the `n_classes` argument allows to pad with zeros before averaging the F1-scores. + """ + i_left, i_right = adata.obsp["spatial_connectivities"].nonzero() + classes_left, classes_right = ( + adata.obs.iloc[i_left][obs_key], + adata.obs.iloc[i_right][obs_key], + ) + + f1_scores = metrics.f1_score(classes_left, classes_right, average=None) + + if n_classes is None: + return float(f1_scores.mean()) + + assert n_classes >= len(f1_scores), f"Expected {n_classes:=}, but found {len(f1_scores)}, which is greater" + + return float(np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()) + + +def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None) -> float: + """Jensen-Shannon divergence (JSD) over all slides + + Args: + adata: One or a list of AnnData object(s) + obs_key: The key containing the clusters + slide_key: The slide ID obs key + + Returns: + A float corresponding to the JSD + """ + distributions = [ + adata.obs[obs_key].value_counts(sort=False).values for adata in _iter_uid(adatas, slide_key, obs_key) + ] + + return _jensen_shannon_divergence(np.array(distributions)) + + +def _jensen_shannon_divergence(distributions: NDArrayA) -> float: + """Compute the Jensen-Shannon divergence (JSD) for a multiple probability distributions. + + The lower the score, the better distribution of clusters among the different batches. + + Args: + distributions: An array of shape (B x C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells. + + Returns: + A float corresponding to the JSD + """ + distributions = distributions / distributions.sum(1)[:, None] + mean_distribution = np.mean(distributions, 0) + + return _entropy(mean_distribution) - float(np.mean([_entropy(dist) for dist in distributions])) + + +def _entropy(distribution: NDArrayA) -> float: + """Shannon entropy + + Args: + distribution: An array of probabilities (should sum to one) + + Returns: + The Shannon entropy + """ + return float(-(distribution * np.log(distribution + 1e-8)).sum()) + + +def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None, obs_key: str | None = None) -> Iterator[AnnData]: + if isinstance(adatas, AnnData): + adatas = [adatas] + + if obs_key is not None: + categories = set.union(*[set(adata.obs[obs_key].unique().dropna()) for adata in adatas]) + for adata in adatas: + adata.obs[obs_key] = adata.obs[obs_key].astype("category").cat.set_categories(categories) + + for adata in adatas: + if slide_key is not None: + for slide_id in adata.obs[slide_key].unique(): + yield adata[adata.obs[slide_key] == slide_id] + else: + yield adata From 334b7fb8864180a68f83901607e5bfd27e549b96 Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 11 Jun 2024 14:06:40 +0200 Subject: [PATCH 09/67] Add function to test for niche similarity by comparing max (99th percentile) counts --- src/squidpy/gr/_niche.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 1f15c84d..e5b4cf3e 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -14,6 +14,8 @@ from sklearn.preprocessing import StandardScaler from spatialdata import SpatialData from utag import utag +import itertools +from scipy.stats import ranksums from squidpy._utils import NDArrayA @@ -167,6 +169,29 @@ def _df_to_adata(df: pd.DataFrame) -> AnnData: adata.obs.index = df.index return adata +def pairwise_niche_comparison( + adata: AnnData, + niche_key: str, +) -> pd.DataFrame: + niches = adata.obs[niche_key].unique().tolist() + niche_dict = {} + for niche in adata.obs[niche_key].unique(): + niche_adata = adata[adata.obs[niche_key] == niche] + n_cols = niche_adata.X.shape[1] + arr = np.ones(n_cols) + for i in range(n_cols): + col_data = niche_adata.X.getcol(i).data + percentile_99 = np.percentile(col_data, 99) + arr[i] = percentile_99 + niche_dict[niche] = arr + var_by_niche = pd.DataFrame(niche_dict) + result = pd.DataFrame(index=niches, columns=niches, data=None) + combinations = list(itertools.combinations_with_replacement(niches, 2)) + for pair in combinations: + p_val = ranksums(var_by_niche[pair[0]], var_by_niche[pair[1]], alternative="two-sided")[1] + result.at[pair[0], pair[1]] = p_val + result.at[pair[1], pair[0]] = p_val + return result def mean_fide_score( adatas: AnnData | list[AnnData], From 54936f9a75990075b7c982081f80eeff51db2de1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 12:07:08 +0000 Subject: [PATCH 10/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index e5b4cf3e..a569e157 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,5 +1,6 @@ from __future__ import annotations +import itertools from collections.abc import Iterator from typing import Any, Optional @@ -8,14 +9,13 @@ import scanpy as sc from anndata import AnnData from scipy.spatial import cKDTree +from scipy.stats import ranksums from sklearn import metrics from sklearn.decomposition import PCA from sklearn.neighbors import KDTree from sklearn.preprocessing import StandardScaler from spatialdata import SpatialData from utag import utag -import itertools -from scipy.stats import ranksums from squidpy._utils import NDArrayA @@ -169,9 +169,10 @@ def _df_to_adata(df: pd.DataFrame) -> AnnData: adata.obs.index = df.index return adata + def pairwise_niche_comparison( - adata: AnnData, - niche_key: str, + adata: AnnData, + niche_key: str, ) -> pd.DataFrame: niches = adata.obs[niche_key].unique().tolist() niche_dict = {} @@ -193,6 +194,7 @@ def pairwise_niche_comparison( result.at[pair[1], pair[0]] = p_val return result + def mean_fide_score( adatas: AnnData | list[AnnData], obs_key: str, From 2c5cac8751c8f5e121770611b13e574e4c1e02f2 Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 11 Jun 2024 14:17:05 +0200 Subject: [PATCH 11/67] Fix result dataframe --- src/squidpy/gr/_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index e5b4cf3e..13792a19 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -185,7 +185,7 @@ def pairwise_niche_comparison( arr[i] = percentile_99 niche_dict[niche] = arr var_by_niche = pd.DataFrame(niche_dict) - result = pd.DataFrame(index=niches, columns=niches, data=None) + result = pd.DataFrame(index=niches, columns=niches, data=None, dtype=float) combinations = list(itertools.combinations_with_replacement(niches, 2)) for pair in combinations: p_val = ranksums(var_by_niche[pair[0]], var_by_niche[pair[1]], alternative="two-sided")[1] From b5cb056630e87a0cd8ceb44b7d3206c2934ff5b8 Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 11 Jun 2024 15:00:11 +0200 Subject: [PATCH 12/67] Add scores to compare different niche calculations --- src/squidpy/gr/_niche.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index bf7b06bc..8ec7bc0c 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -14,6 +14,7 @@ from sklearn.decomposition import PCA from sklearn.neighbors import KDTree from sklearn.preprocessing import StandardScaler +from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score from spatialdata import SpatialData from utag import utag @@ -174,8 +175,13 @@ def pairwise_niche_comparison( adata: AnnData, niche_key: str, ) -> pd.DataFrame: + """Do a simple pairwise DE test on the 99th percentile of each gene for each niche. + Can be used to plot heatmap showing similar (large p-value) or different (small p-value) niches. + For validating niche results, the niche pairs that are similar in expression are the ones of interest because + it could hint at niches not being well defined in those cases.""" niches = adata.obs[niche_key].unique().tolist() niche_dict = {} + # for each niche, calculate the 99th percentile of each gene for niche in adata.obs[niche_key].unique(): niche_adata = adata[adata.obs[niche_key] == niche] n_cols = niche_adata.X.shape[1] @@ -185,9 +191,12 @@ def pairwise_niche_comparison( percentile_99 = np.percentile(col_data, 99) arr[i] = percentile_99 niche_dict[niche] = arr + # create 99th percentile count x niche matrix var_by_niche = pd.DataFrame(niche_dict) result = pd.DataFrame(index=niches, columns=niches, data=None, dtype=float) + # construct all pairs (unordered and with pairs of the same niche) combinations = list(itertools.combinations_with_replacement(niches, 2)) + # create a p-value matrix for all niche pairs for pair in combinations: p_val = ranksums(var_by_niche[pair[0]], var_by_niche[pair[1]], alternative="two-sided")[1] result.at[pair[0], pair[1]] = p_val @@ -290,3 +299,14 @@ def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None, obs_key: s yield adata[adata.obs[slide_key] == slide_id] else: yield adata + +def _compare_niche_definitions(adata: AnnData, niche_definitions: list) -> dict[str, pd.DataFrame]: + result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) + combinations = list(itertools.combinations_with_replacement(niche_definitions, 2)) + scores = {"ARI:": adjusted_rand_score, "NMI": normalized_mutual_info_score, "FMI": fowlkes_mallows_score} + for score_name, score_func in scores.items(): + for pair in combinations: + score = score_func(adata.obs[pair[0]], adata.obs[pair[1]]) + result.at[pair[0], pair[1]] = score + result.at[pair[1], pair[0]] = score + adata.uns[f"niche_definition_comparison_{score_name}"] = result From 2b5ef61f5f231663b0a09ad16734b3fa038ed001 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 13:01:34 +0000 Subject: [PATCH 13/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 8ec7bc0c..b1edfbff 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -12,9 +12,9 @@ from scipy.stats import ranksums from sklearn import metrics from sklearn.decomposition import PCA +from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score from sklearn.neighbors import KDTree from sklearn.preprocessing import StandardScaler -from sklearn.metrics import adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score from spatialdata import SpatialData from utag import utag @@ -300,6 +300,7 @@ def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None, obs_key: s else: yield adata + def _compare_niche_definitions(adata: AnnData, niche_definitions: list) -> dict[str, pd.DataFrame]: result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) combinations = list(itertools.combinations_with_replacement(niche_definitions, 2)) From c98ec1bbd0a7fee0ddcd20de7cac5b751739096d Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 11 Jun 2024 17:59:41 +0200 Subject: [PATCH 14/67] Update doc string and param names --- src/squidpy/gr/_niche.py | 105 +++++++++++++++++++++++++-------------- 1 file changed, 68 insertions(+), 37 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 8ec7bc0c..62e00d4f 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -28,13 +28,41 @@ def calculate_niche( groups: str, flavor: str = "neighborhood", library_key: str | None = None, - radius: float | None = None, - n_neighbors: int | None = None, + radius: float | None = None, #deprecate, use spatial graph instead + n_neighbors: int | None = None, #deprecate, use spatial graph instead limit_to: str | list[Any] | None = None, table_key: str | None = None, spatial_key: str = "spatial", copy: bool = False, ) -> AnnData | pd.DataFrame: + """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. + The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods + will be applied and additionally compared using cluster validation scores. + + Parameters + ---------- + %(adata)s + groups + groups based on which to calculate neighborhood profile. + flavor + Method to use for niche calculation. Available options are: + + - `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. + - `{c.SPOT.s!r}` - calculate niches using optimal transport. + - `{c.BANKSY.s!r}`- use Banksy algorithm. + - `{c.CELLCHARTER.s!r}` - use cellcharter. + - `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication). + - `{c.ALL.s!r}` - apply all available methods and compare them using cluster validation scores. + %(library_key)s + limit_to + Restrict niche calculation to a subset of the data. + table_key + Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. + spatial_key + Location of spatial coordinates in `adata.obsm`. + %(copy)s + """ + # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present is_sdata = False if isinstance(adata, SpatialData): @@ -173,17 +201,17 @@ def _df_to_adata(df: pd.DataFrame) -> AnnData: def pairwise_niche_comparison( adata: AnnData, - niche_key: str, + library_key: str, ) -> pd.DataFrame: """Do a simple pairwise DE test on the 99th percentile of each gene for each niche. Can be used to plot heatmap showing similar (large p-value) or different (small p-value) niches. For validating niche results, the niche pairs that are similar in expression are the ones of interest because it could hint at niches not being well defined in those cases.""" - niches = adata.obs[niche_key].unique().tolist() + niches = adata.obs[library_key].unique().tolist() niche_dict = {} # for each niche, calculate the 99th percentile of each gene - for niche in adata.obs[niche_key].unique(): - niche_adata = adata[adata.obs[niche_key] == niche] + for niche in adata.obs[library_key].unique(): + niche_adata = adata[adata.obs[library_key] == niche] n_cols = niche_adata.X.shape[1] arr = np.ones(n_cols) for i in range(n_cols): @@ -206,15 +234,17 @@ def pairwise_niche_comparison( def mean_fide_score( adatas: AnnData | list[AnnData], - obs_key: str, + library_key: str, slide_key: str | None = None, n_classes: int | None = None, ) -> float: """Mean FIDE score over all slides. A low score indicates a great domain continuity.""" - return float(np.mean([fide_score(adata, obs_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)])) + return float(np.mean([fide_score(adata, library_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)])) -def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> float: +def fide_score(adata: AnnData, + library_key: str, + n_classes: int | None = None) -> float: """ F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity. @@ -223,8 +253,8 @@ def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> fl """ i_left, i_right = adata.obsp["spatial_connectivities"].nonzero() classes_left, classes_right = ( - adata.obs.iloc[i_left][obs_key], - adata.obs.iloc[i_right][obs_key], + adata.obs.iloc[i_left][library_key], + adata.obs.iloc[i_right][library_key], ) f1_scores = metrics.f1_score(classes_left, classes_right, average=None) @@ -237,19 +267,11 @@ def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> fl return float(np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()) -def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None) -> float: - """Jensen-Shannon divergence (JSD) over all slides - - Args: - adata: One or a list of AnnData object(s) - obs_key: The key containing the clusters - slide_key: The slide ID obs key - - Returns: - A float corresponding to the JSD - """ +def jensen_shannon_divergence(adatas: AnnData | list[AnnData], + library_key: str, slide_key: str | None = None) -> float: + """Jensen-Shannon divergence (JSD) over all slides""" distributions = [ - adata.obs[obs_key].value_counts(sort=False).values for adata in _iter_uid(adatas, slide_key, obs_key) + adata.obs[library_key].value_counts(sort=False).values for adata in _iter_uid(adatas, slide_key, library_key) ] return _jensen_shannon_divergence(np.array(distributions)) @@ -257,14 +279,15 @@ def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, sli def _jensen_shannon_divergence(distributions: NDArrayA) -> float: """Compute the Jensen-Shannon divergence (JSD) for a multiple probability distributions. - The lower the score, the better distribution of clusters among the different batches. - Args: - distributions: An array of shape (B x C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells. + Parameters + ---------- + distributions + An array of shape (B x C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells. - Returns: - A float corresponding to the JSD + Returns + JSD (float) """ distributions = distributions / distributions.sum(1)[:, None] mean_distribution = np.mean(distributions, 0) @@ -275,35 +298,43 @@ def _jensen_shannon_divergence(distributions: NDArrayA) -> float: def _entropy(distribution: NDArrayA) -> float: """Shannon entropy - Args: + Parameters + ---------- distribution: An array of probabilities (should sum to one) - Returns: + Returns The Shannon entropy """ return float(-(distribution * np.log(distribution + 1e-8)).sum()) -def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None, obs_key: str | None = None) -> Iterator[AnnData]: +def _iter_uid(adatas: AnnData | list[AnnData], + slide_key: str | None, + library_key: str | None = None) -> Iterator[AnnData]: if isinstance(adatas, AnnData): adatas = [adatas] - if obs_key is not None: - categories = set.union(*[set(adata.obs[obs_key].unique().dropna()) for adata in adatas]) + if library_key is not None: + categories = set.union(*[set(adata.obs[library_key].unique().dropna()) for adata in adatas]) for adata in adatas: - adata.obs[obs_key] = adata.obs[obs_key].astype("category").cat.set_categories(categories) + adata.obs[library_key] = adata.obs[library_key].astype("category").cat.set_categories(categories) for adata in adatas: if slide_key is not None: - for slide_id in adata.obs[slide_key].unique(): - yield adata[adata.obs[slide_key] == slide_id] + for slide in adata.obs[slide_key].unique(): + yield adata[adata.obs[slide_key] == slide] else: yield adata -def _compare_niche_definitions(adata: AnnData, niche_definitions: list) -> dict[str, pd.DataFrame]: +def _compare_niche_definitions(adata: AnnData, + niche_definitions: list): + """Given different clustering results, compare them using different scores.""" + result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) combinations = list(itertools.combinations_with_replacement(niche_definitions, 2)) scores = {"ARI:": adjusted_rand_score, "NMI": normalized_mutual_info_score, "FMI": fowlkes_mallows_score} + + # for each score, apply it on all pairs of niche definitions for score_name, score_func in scores.items(): for pair in combinations: score = score_func(adata.obs[pair[0]], adata.obs[pair[1]]) From bb3bdfb6f8fc7f202e6b929b8d3bff328d652813 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Jun 2024 16:04:41 +0000 Subject: [PATCH 15/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 6f24ce06..7b9de77e 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -28,17 +28,17 @@ def calculate_niche( groups: str, flavor: str = "neighborhood", library_key: str | None = None, - radius: float | None = None, #deprecate, use spatial graph instead - n_neighbors: int | None = None, #deprecate, use spatial graph instead + radius: float | None = None, # deprecate, use spatial graph instead + n_neighbors: int | None = None, # deprecate, use spatial graph instead limit_to: str | list[Any] | None = None, table_key: str | None = None, spatial_key: str = "spatial", copy: bool = False, ) -> AnnData | pd.DataFrame: - """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. + """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods will be applied and additionally compared using cluster validation scores. - + Parameters ---------- %(adata)s @@ -239,12 +239,12 @@ def mean_fide_score( n_classes: int | None = None, ) -> float: """Mean FIDE score over all slides. A low score indicates a great domain continuity.""" - return float(np.mean([fide_score(adata, library_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)])) + return float( + np.mean([fide_score(adata, library_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)]) + ) -def fide_score(adata: AnnData, - library_key: str, - n_classes: int | None = None) -> float: +def fide_score(adata: AnnData, library_key: str, n_classes: int | None = None) -> float: """ F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity. @@ -267,8 +267,7 @@ def fide_score(adata: AnnData, return float(np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()) -def jensen_shannon_divergence(adatas: AnnData | list[AnnData], - library_key: str, slide_key: str | None = None) -> float: +def jensen_shannon_divergence(adatas: AnnData | list[AnnData], library_key: str, slide_key: str | None = None) -> float: """Jensen-Shannon divergence (JSD) over all slides""" distributions = [ adata.obs[library_key].value_counts(sort=False).values for adata in _iter_uid(adatas, slide_key, library_key) @@ -308,9 +307,9 @@ def _entropy(distribution: NDArrayA) -> float: return float(-(distribution * np.log(distribution + 1e-8)).sum()) -def _iter_uid(adatas: AnnData | list[AnnData], - slide_key: str | None, - library_key: str | None = None) -> Iterator[AnnData]: +def _iter_uid( + adatas: AnnData | list[AnnData], slide_key: str | None, library_key: str | None = None +) -> Iterator[AnnData]: if isinstance(adatas, AnnData): adatas = [adatas] @@ -326,8 +325,8 @@ def _iter_uid(adatas: AnnData | list[AnnData], else: yield adata -def _compare_niche_definitions(adata: AnnData, - niche_definitions: list): + +def _compare_niche_definitions(adata: AnnData, niche_definitions: list): """Given different clustering results, compare them using different scores.""" result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) From ebdf1d50a39d2aeb70873e6a22967cf17ce39c13 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 17 Jun 2024 15:06:46 +0200 Subject: [PATCH 16/67] Update neighborhood profile, Remove utag import --- src/squidpy/gr/_niche.py | 64 +++++++++++----------------------------- 1 file changed, 18 insertions(+), 46 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 7b9de77e..0766528e 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -8,7 +8,6 @@ import pandas as pd import scanpy as sc from anndata import AnnData -from scipy.spatial import cKDTree from scipy.stats import ranksums from sklearn import metrics from sklearn.decomposition import PCA @@ -16,7 +15,6 @@ from sklearn.neighbors import KDTree from sklearn.preprocessing import StandardScaler from spatialdata import SpatialData -from utag import utag from squidpy._utils import NDArrayA @@ -28,11 +26,12 @@ def calculate_niche( groups: str, flavor: str = "neighborhood", library_key: str | None = None, - radius: float | None = None, # deprecate, use spatial graph instead - n_neighbors: int | None = None, # deprecate, use spatial graph instead + radius: float | None = None, + n_neighbors: int | None = None, limit_to: str | list[Any] | None = None, table_key: str | None = None, spatial_key: str = "spatial", + spatial_connectivities_key: str = "spatial_connectivities", copy: bool = False, ) -> AnnData | pd.DataFrame: """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. @@ -111,71 +110,44 @@ def calculate_niche( if flavor == "neighborhood": rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - table, groups, radius, n_neighbors, table_subset, spatial_key + table, groups, radius, n_neighbors, table_subset, spatial_connectivities_key ) df = pd.DataFrame(rel_nhood_profile, index=table_subset.obs.index) nhood_table = _df_to_adata(df) sc.pp.neighbors(nhood_table, n_neighbors=n_neighbors, use_rep="X") sc.tl.leiden(nhood_table) table.obs["niche"] = nhood_table.obs["leiden"] - if is_sdata: - if copy: - return nhood_table - adata.tables[f"{flavor}_niche"] = nhood_table + if copy: + return nhood_table else: - if copy: - return df - df = df.reindex(table.obs.index) - print(df.head()) - table.obsm[f"{flavor}_niche"] = df - - elif flavor == "utag": - result = utag( - table_subset, - slide_key=library_key, - max_dist=10, - normalization_mode="l1_norm", - apply_clustering=True, - clustering_method="leiden", - resolutions=1.0, - ) - if is_sdata: - if copy: - return result - adata.tables[f"{flavor}_niche"] = result - else: - if copy: - return result - df = df.reindex(table.obs.index) - table.obsm[f"{flavor}_niche"] = df + if is_sdata: + adata.tables[f"{flavor}_niche"] = nhood_table + else: + df = df.reindex(table.obs.index) + print(df.head()) + table.obsm[f"{flavor}_niche"] = df def _calculate_neighborhood_profile( adata: AnnData | SpatialData, groups: str, - radius: float | None, - n_neighbors: int | None, subset: AnnData, - spatial_key: str, + spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: # reset index adata.obs = adata.obs.reset_index() - if n_neighbors is not None: - # get k-nearest neighbors for each observation - tree = KDTree(adata.obsm[spatial_key]) - _, indices = tree.query(subset.obsm[spatial_key], k=n_neighbors) - else: - # get neighbors within a given radius for each observation - tree = cKDTree(adata.obsm[spatial_key]) - indices = tree.query_ball_point(subset.obsm[spatial_key], r=radius) + # get obs x neighbor matrix from sparse matrix + matrix = adata.obsp[spatial_connectivities_key].tocoo() + nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) + neighbor_matrix = pd.DataFrame(nonzero_indices) # get unique categories category_arr = adata.obs[groups].values unique_categories = np.unique(category_arr) # get obs x k matrix where each column is the category of the k-th neighbor - cat_by_id = np.take(category_arr, indices) + cat_by_id = np.take(category_arr, neighbor_matrix) # in obs x k matrix convert categorical values to numerical values cat_indices = {category: index for index, category in enumerate(unique_categories)} From c6d020b68a452c61ad2c1a43718a197e7fdc1280 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 17 Jun 2024 15:08:26 +0200 Subject: [PATCH 17/67] Fix pre-commit --- src/squidpy/gr/_niche.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 0766528e..cfaf1d87 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -110,7 +110,7 @@ def calculate_niche( if flavor == "neighborhood": rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - table, groups, radius, n_neighbors, table_subset, spatial_connectivities_key + table, groups, table_subset, spatial_connectivities_key ) df = pd.DataFrame(rel_nhood_profile, index=table_subset.obs.index) nhood_table = _df_to_adata(df) @@ -298,7 +298,7 @@ def _iter_uid( yield adata -def _compare_niche_definitions(adata: AnnData, niche_definitions: list): +def _compare_niche_definitions(adata: AnnData, niche_definitions: list[str]) -> pd.DataFrame: """Given different clustering results, compare them using different scores.""" result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) From 9fa0157a5a64e2a240f0645c025205c6bc8bae64 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 17 Jun 2024 16:12:02 +0200 Subject: [PATCH 18/67] Add utag inner product step --- src/squidpy/gr/_niche.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index cfaf1d87..f730b705 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -10,10 +10,8 @@ from anndata import AnnData from scipy.stats import ranksums from sklearn import metrics -from sklearn.decomposition import PCA from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score -from sklearn.neighbors import KDTree -from sklearn.preprocessing import StandardScaler +from sklearn.preprocessing import normalize from spatialdata import SpatialData from squidpy._utils import NDArrayA @@ -92,7 +90,7 @@ def calculate_niche( f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." ) else: - table = adata + table = adata.copy() # check whether to use radius or knn for neighborhood profile calculation if radius is None and n_neighbors is None: @@ -127,6 +125,17 @@ def calculate_niche( print(df.head()) table.obsm[f"{flavor}_niche"] = df + elif flavor == "utag": + new_feature_matrix = _utag(table, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) + table.X = new_feature_matrix + if copy: + return table + else: + if is_sdata: + adata.tables[f"{flavor}_niche"] = table + else: + table.obsm[f"{flavor}_niche"] = table.X + def _calculate_neighborhood_profile( adata: AnnData | SpatialData, @@ -164,6 +173,25 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq +def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: + """Performas inner product of adjacency matrix and feature matrix, + such that each observation inherits features from its immediate neighbors as described in UTAG paper. + + Parameters + ---------- + adata + Annotated data matrix. + normalize + If 'True', aggregate by the mean, else aggregate by the sum.""" + + adjacency_matrix = adata.obsp[spatial_connectivity_key] + + if normalize_adj: + return normalize(adjacency_matrix, norm="l1", axis=1) @ adata.X + else: + return adjacency_matrix @ adata.X + + def _df_to_adata(df: pd.DataFrame) -> AnnData: df.index = df.index.map(str) adata = AnnData(X=df) From 49b51caa4ab38af54c9b0633e21c25bf892a24b3 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 10 Jul 2024 11:52:29 +0200 Subject: [PATCH 19/67] Fix output; Remove subsetting, neighborhood options, dimreduction and clustering steps --- src/squidpy/gr/_niche.py | 62 +++++++++++----------------------------- 1 file changed, 17 insertions(+), 45 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index f730b705..771884c3 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -24,9 +24,6 @@ def calculate_niche( groups: str, flavor: str = "neighborhood", library_key: str | None = None, - radius: float | None = None, - n_neighbors: int | None = None, - limit_to: str | list[Any] | None = None, table_key: str | None = None, spatial_key: str = "spatial", spatial_connectivities_key: str = "spatial_connectivities", @@ -51,7 +48,7 @@ def calculate_niche( - `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication). - `{c.ALL.s!r}` - apply all available methods and compare them using cluster validation scores. %(library_key)s - limit_to + subset Restrict niche calculation to a subset of the data. table_key Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. @@ -65,14 +62,15 @@ def calculate_niche( if isinstance(adata, SpatialData): is_sdata = True if table_key is not None: - table = adata.tables[table_key] + sdata = adata + adata = adata.tables[table_key].copy() else: if len(adata.tables) > 1: count = 0 - for key in adata.tables.keys(): + for table in adata.tables.keys(): if groups in table.obs: count += 1 - table_key = key + table_key = table if count > 1: raise ValueError( f"Multiple tables in `spatialdata` with group `{groups}` detected. Please specify which table to use in `table_key`." @@ -82,70 +80,44 @@ def calculate_niche( f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`." ) else: - table = adata.tables[table_key] + adata = adata.tables[table_key].copy() else: - ((key, table),) = adata.tables.items() - if groups not in table.obs: + ((key, adata),) = adata.tables.items() + if groups not in adata.obs: raise ValueError( f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." ) - else: - table = adata.copy() - - # check whether to use radius or knn for neighborhood profile calculation - if radius is None and n_neighbors is None: - raise ValueError("Either `radius` or `n_neighbors` must be provided, but both are `None`.") - if radius is not None and n_neighbors is not None: - raise ValueError("Either `radius` and `n_neighbors` must be provided, but both were provided.") - - # subset adata if only observations within specified groups are to be considered - if limit_to is not None: - if isinstance(limit_to, str): - limit_to = [limit_to] - table_subset = table[table.obs[groups].isin([limit_to])] - else: - table_subset = table if flavor == "neighborhood": rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - table, groups, table_subset, spatial_connectivities_key + adata, groups, spatial_connectivities_key ) - df = pd.DataFrame(rel_nhood_profile, index=table_subset.obs.index) + df = pd.DataFrame(rel_nhood_profile, index=adata.obs.index) nhood_table = _df_to_adata(df) - sc.pp.neighbors(nhood_table, n_neighbors=n_neighbors, use_rep="X") - sc.tl.leiden(nhood_table) - table.obs["niche"] = nhood_table.obs["leiden"] if copy: - return nhood_table + return df else: if is_sdata: - adata.tables[f"{flavor}_niche"] = nhood_table + sdata.tables[f"{flavor}_niche"] = nhood_table else: - df = df.reindex(table.obs.index) - print(df.head()) - table.obsm[f"{flavor}_niche"] = df + adata.obsm["neighborhood_profile"] = df elif flavor == "utag": - new_feature_matrix = _utag(table, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) - table.X = new_feature_matrix + new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) if copy: - return table + return new_feature_matrix else: if is_sdata: - adata.tables[f"{flavor}_niche"] = table + sdata.tables[f"{flavor}_niche"] = new_feature_matrix else: - table.obsm[f"{flavor}_niche"] = table.X + adata.obsm[f"{flavor}_niche"] = new_feature_matrix def _calculate_neighborhood_profile( adata: AnnData | SpatialData, groups: str, - subset: AnnData, spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: - # reset index - adata.obs = adata.obs.reset_index() - # get obs x neighbor matrix from sparse matrix matrix = adata.obsp[spatial_connectivities_key].tocoo() nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) From d6df534230a924178d54560886523c4423ee0b44 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 10 Jul 2024 14:09:26 +0200 Subject: [PATCH 20/67] Update utag output --- src/squidpy/gr/_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 771884c3..710ef658 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -110,7 +110,7 @@ def calculate_niche( if is_sdata: sdata.tables[f"{flavor}_niche"] = new_feature_matrix else: - adata.obsm[f"{flavor}_niche"] = new_feature_matrix + adata.layers["utag"] = new_feature_matrix def _calculate_neighborhood_profile( From f797ef942e403d9a48c3eb533260685fb6bf0d47 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 10 Jul 2024 15:29:43 +0200 Subject: [PATCH 21/67] Add cellcharter aggregation step --- src/squidpy/gr/_niche.py | 67 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 710ef658..354e7c59 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -4,10 +4,12 @@ from collections.abc import Iterator from typing import Any, Optional +import anndata as ad import numpy as np import pandas as pd import scanpy as sc from anndata import AnnData +from scipy.sparse import csr_matrix, vstack from scipy.stats import ranksums from sklearn import metrics from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score @@ -26,7 +28,10 @@ def calculate_niche( library_key: str | None = None, table_key: str | None = None, spatial_key: str = "spatial", + adj_subsets: list[int] | None = None, + aggregation: str = "mean", spatial_connectivities_key: str = "spatial_connectivities", + spatial_distances_key: str = "spatial_distances", copy: bool = False, ) -> AnnData | pd.DataFrame: """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. @@ -112,6 +117,38 @@ def calculate_niche( else: adata.layers["utag"] = new_feature_matrix + elif flavor == "cellcharter": + adj_matrix_subsets = [] + if isinstance(adj_subsets, list): + for k in adj_subsets: + adj_matrix_subsets.append( + _get_adj_matrix_subsets( + adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k + ) + ) + if aggregation == "mean": + inner_products = [adata.X.dot(adj_subset) for adj_subset in adj_matrix_subsets] + elif aggregation == "variance": + inner_products = [ + _aggregate_var(matrix, adata.obsp[spatial_connectivities_key], adata) for matrix in inner_products + ] + else: + raise ValueError( + f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'." + ) + concatenated_matrix = vstack(inner_products) + if copy: + return concatenated_matrix + else: + if is_sdata: + sdata.tables[f"{flavor}_niche"] = ad.AnnData(concatenated_matrix) + else: + adata.obsm[f"{flavor}_niche"] = concatenated_matrix + else: + raise ValueError( + "Flavor 'cellcharter' requires list of neighbors to build adjacency matrices. Please provide a list of k_neighbors for 'adj_subsets'." + ) + def _calculate_neighborhood_profile( adata: AnnData | SpatialData, @@ -164,6 +201,31 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> return adjacency_matrix @ adata.X +def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k_neighbors: int) -> csr_matrix: + # Convert the distance matrix to a dense format for easier manipulation + dist_dense = distances.todense() + + # Find the indices of the k closest neighbors for each row + closest_neighbors_indices = np.argsort(dist_dense, axis=1)[:, :k_neighbors] + + # Initialize lists to collect data for the new sparse matrix + rows = [] + cols = [] + data = [] + + # Iterate over each row to construct the new adjacency matrix + for row in range(dist_dense.shape[0]): + for col in closest_neighbors_indices[row].flat: + rows.append(row) + cols.append(col) + data.append(connectivities[row, col]) + + # Create the new sparse matrix with the reduced neighbors + new_adj_matrix = csr_matrix((data, (rows, cols)), shape=connectivities.shape) + + return new_adj_matrix + + def _df_to_adata(df: pd.DataFrame) -> AnnData: df.index = df.index.map(str) adata = AnnData(X=df) @@ -171,6 +233,11 @@ def _df_to_adata(df: pd.DataFrame) -> AnnData: return adata +def _aggregate_var(product: csr_matrix, connectivities: csr_matrix, adata: AnnData) -> csr_matrix: + mean_squared = connectivities.dot(adata.X.multiply(adata.X)) + return mean_squared - (product.multiply(product)) + + def pairwise_niche_comparison( adata: AnnData, library_key: str, From bd16795af0f529829bbdf8e27b6a2d0ac884692e Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 15 Jul 2024 22:32:55 +0200 Subject: [PATCH 22/67] Update --- src/squidpy/gr/_niche.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 354e7c59..9e11b17a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -121,13 +121,16 @@ def calculate_niche( adj_matrix_subsets = [] if isinstance(adj_subsets, list): for k in adj_subsets: - adj_matrix_subsets.append( - _get_adj_matrix_subsets( - adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k + if k == 0: + adj_matrix_subsets.append(adata.obsp[spatial_connectivities_key]) + else: + adj_matrix_subsets.append( + _get_adj_matrix_subsets( + adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k + ) ) - ) if aggregation == "mean": - inner_products = [adata.X.dot(adj_subset) for adj_subset in adj_matrix_subsets] + inner_products = [adj_subset.dot(adata.X) for adj_subset in adj_matrix_subsets] elif aggregation == "variance": inner_products = [ _aggregate_var(matrix, adata.obsp[spatial_connectivities_key], adata) for matrix in inner_products @@ -222,7 +225,7 @@ def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k # Create the new sparse matrix with the reduced neighbors new_adj_matrix = csr_matrix((data, (rows, cols)), shape=connectivities.shape) - + print(new_adj_matrix.shape) return new_adj_matrix From 79c7b526ad3a8767e275e6c629bd37955a9985b8 Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 16 Jul 2024 14:00:48 +0200 Subject: [PATCH 23/67] Fix sparse matrix output --- src/squidpy/gr/_niche.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 9e11b17a..a476a864 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -9,7 +9,7 @@ import pandas as pd import scanpy as sc from anndata import AnnData -from scipy.sparse import csr_matrix, vstack +from scipy.sparse import csr_matrix, hstack from scipy.stats import ranksums from sklearn import metrics from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score @@ -139,14 +139,24 @@ def calculate_niche( raise ValueError( f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'." ) - concatenated_matrix = vstack(inner_products) + concatenated_matrix = hstack(inner_products) + + # create df from sparse matrix + arr = concatenated_matrix.toarray() + df = pd.DataFrame(arr, index=adata.obs.index) + col_names = [] + for A_i in adj_subsets: + for var in adata.var_names: + col_names.append(f"{var}_Adj_{A_i}") + df.columns = col_names + if copy: return concatenated_matrix else: if is_sdata: sdata.tables[f"{flavor}_niche"] = ad.AnnData(concatenated_matrix) else: - adata.obsm[f"{flavor}_niche"] = concatenated_matrix + adata.obsm[f"{flavor}_niche"] = df else: raise ValueError( "Flavor 'cellcharter' requires list of neighbors to build adjacency matrices. Please provide a list of k_neighbors for 'adj_subsets'." @@ -225,7 +235,6 @@ def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k # Create the new sparse matrix with the reduced neighbors new_adj_matrix = csr_matrix((data, (rows, cols)), shape=connectivities.shape) - print(new_adj_matrix.shape) return new_adj_matrix From 9d2104f821b75007a16cdb676d9ff141d7e5470b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 16 Jul 2024 12:01:20 +0000 Subject: [PATCH 24/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index a476a864..4afd4bc7 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -140,7 +140,7 @@ def calculate_niche( f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'." ) concatenated_matrix = hstack(inner_products) - + # create df from sparse matrix arr = concatenated_matrix.toarray() df = pd.DataFrame(arr, index=adata.obs.index) From 5219a30682bbfd52076900aef384603c2526a8fc Mon Sep 17 00:00:00 2001 From: Laurens Lehner Date: Tue, 16 Jul 2024 16:03:09 +0200 Subject: [PATCH 25/67] Fix neighborhood profile calculation --- src/squidpy/gr/_niche.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index a476a864..19cc24b0 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -27,9 +27,9 @@ def calculate_niche( flavor: str = "neighborhood", library_key: str | None = None, table_key: str | None = None, - spatial_key: str = "spatial", adj_subsets: list[int] | None = None, aggregation: str = "mean", + spatial_key: str = "spatial", spatial_connectivities_key: str = "spatial_connectivities", spatial_distances_key: str = "spatial_distances", copy: bool = False, @@ -174,11 +174,14 @@ def _calculate_neighborhood_profile( neighbor_matrix = pd.DataFrame(nonzero_indices) # get unique categories - category_arr = adata.obs[groups].values - unique_categories = np.unique(category_arr) + unique_categories = np.unique(adata.obs[groups].values) # get obs x k matrix where each column is the category of the k-th neighbor - cat_by_id = np.take(category_arr, neighbor_matrix) + indices_with_nan = neighbor_matrix.to_numpy() + valid_indices = neighbor_matrix.fillna(-1).astype(int).to_numpy() + cat_by_id = adata.obs[groups].values[valid_indices] + cat_by_id[indices_with_nan == -1] = np.nan + # cat_by_id = np.take(category_arr, neighbor_matrix) # in obs x k matrix convert categorical values to numerical values cat_indices = {category: index for index, category in enumerate(unique_categories)} From dcc542dcf752550c8b9e742a8ce7a2818519e897 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 26 Aug 2024 18:28:17 +0200 Subject: [PATCH 26/67] Add consensus function --- src/squidpy/gr/_niche.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index dc2808f8..441bcb5a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -27,6 +27,7 @@ def calculate_niche( flavor: str = "neighborhood", library_key: str | None = None, table_key: str | None = None, + abs_nhood: bool = False, adj_subsets: list[int] | None = None, aggregation: str = "mean", spatial_key: str = "spatial", @@ -37,7 +38,6 @@ def calculate_niche( """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods will be applied and additionally compared using cluster validation scores. - Parameters ---------- %(adata)s @@ -45,7 +45,6 @@ def calculate_niche( groups based on which to calculate neighborhood profile. flavor Method to use for niche calculation. Available options are: - - `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - `{c.SPOT.s!r}` - calculate niches using optimal transport. - `{c.BANKSY.s!r}`- use Banksy algorithm. @@ -97,7 +96,11 @@ def calculate_niche( rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( adata, groups, spatial_connectivities_key ) - df = pd.DataFrame(rel_nhood_profile, index=adata.obs.index) + if not abs_nhood: + nhood_profile = rel_nhood_profile + else: + nhood_profile = abs_nhood_profile + df = pd.DataFrame(nhood_profile, index=adata.obs.index) nhood_table = _df_to_adata(df) if copy: return df @@ -198,6 +201,7 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq + def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: """Performas inner product of adjacency matrix and feature matrix, such that each observation inherits features from its immediate neighbors as described in UTAG paper. @@ -218,7 +222,7 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k_neighbors: int) -> csr_matrix: - # Convert the distance matrix to a dense format for easier manipulation + # Convert the distance matrix to a dense format dist_dense = distances.todense() # Find the indices of the k closest neighbors for each row From 2568f7c1b37337c5877cc7e42d13d03de74dc483 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 27 Aug 2024 13:17:15 +0200 Subject: [PATCH 27/67] Add function to build consensus niche --- src/squidpy/gr/_niche.py | 106 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 441bcb5a..baa33260 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -18,7 +18,7 @@ from squidpy._utils import NDArrayA -__all__ = ["calculate_niche"] +__all__ = ["calculate_niche", "build_consensus_niche"] def calculate_niche( @@ -201,7 +201,6 @@ def _calculate_neighborhood_profile( return rel_freq, abs_freq - def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: """Performas inner product of adjacency matrix and feature matrix, such that each observation inherits features from its immediate neighbors as described in UTAG paper. @@ -398,3 +397,106 @@ def _compare_niche_definitions(adata: AnnData, niche_definitions: list[str]) -> result.at[pair[0], pair[1]] = score result.at[pair[1], pair[0]] = score adata.uns[f"niche_definition_comparison_{score_name}"] = result + + +def _get_subset_indices(df: pd.DataFrame, column: str) -> pd.Series: + return df.groupby(column).apply(lambda x: set(x.index)) + + +def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclude: set[str]) -> tuple[str, float]: + best_match = "" + max_overlap = 0.0 + for other_subset, indices in other_subsets.items(): + if other_subset in exclude: + continue # Skip excluded matches + overlap = len(subset & indices) / len(subset | indices) # jaccard index + if overlap > max_overlap: + max_overlap = overlap + best_match = other_subset + return (best_match, max_overlap) + + +def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[str, set[str]]: + min_niches = {} + min_niche_count = 0 + + for niches in niche_definitions: + niche_count = len(niches) + + if niche_count < min_niche_count: # If this dictionary has fewer keys + min_niches = niches # Update the dictionary with the fewest keys + min_niche_count = niche_count # Update the minimum key count + + return min_niches + + +def _filter_overlap(initial_consensus: dict[str, set[str]]) -> dict[str, str]: + filtered_consensus = {} + processed_elements: set[str] = set() + + for key, values in initial_consensus.items(): + unique_values = values - processed_elements # Remove already processed elements + for value in unique_values: + filtered_consensus[value] = key # Add the value as the new key, with the original key as its value + processed_elements.update(unique_values) # Mark these values as processed + + return filtered_consensus + + +def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: str = "union") -> AnnData: + """Given multiple niche definitions, construct a consensus niche using set matching.""" + + list_of_sets = [] + for definition in niche_definitions: + list_of_sets.append(_get_subset_indices(adata.obs, definition)) + + union_of_matches = _get_initial_niches(list_of_sets) + + avg_jaccard = np.zeros(len(union_of_matches)) + + for set_of_sets in range(len(list_of_sets) - 1): + current_matches = {} + used_matches: set[str] = set() + matches_A_B = { + subset: _find_best_match(indices, list_of_sets[set_of_sets + 1], exclude=used_matches) + for subset, indices in union_of_matches.items() + } + ranked_matches = sorted(matches_A_B.items(), key=lambda x: x[1][1], reverse=True) + for subset_A, (match, jaccard_index) in ranked_matches: + if match not in used_matches: + current_matches[subset_A] = (match, jaccard_index) + used_matches.add(match) + else: + new_match, new_jaccard = _find_best_match( + union_of_matches[subset_A], list_of_sets[set_of_sets + 1], exclude=used_matches + ) + if new_match: + current_matches[subset_A] = (new_match, new_jaccard) + used_matches.add(new_match) + + jaccard = np.asarray([jaccard_index for _, (_, jaccard_index) in current_matches.items()]) + avg_jaccard = (avg_jaccard + jaccard) / (set_of_sets + 1) + + if merge == "union": + consensus = { + subset_A: union_of_matches[subset_A] | list_of_sets[set_of_sets + 1][match] + for subset_A, (match, _) in current_matches.items() + } + if merge == "intersection": + consensus = { + subset_A: union_of_matches[subset_A] & list_of_sets[set_of_sets + 1][match] + for subset_A, (match, _) in current_matches.items() + } + + niche_categories = list(consensus.keys()) + consensus_by_jaccard = dict(zip(niche_categories, avg_jaccard)) + + sorted_by_jaccard = dict( + sorted(consensus_by_jaccard.items(), key=lambda item: item[1], reverse=True), + ) + + sorted_consensus = {key: consensus[key] for key in sorted_by_jaccard} + + filtered_consensus = _filter_overlap(sorted_consensus) + + adata.obs["consensus_niche"] = adata.obs.index.map(filtered_consensus).fillna("None") From 19d45079afc113449658d377fe2578c3308376e7 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 27 Aug 2024 13:21:23 +0200 Subject: [PATCH 28/67] Update init --- src/squidpy/gr/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/__init__.py b/src/squidpy/gr/__init__.py index 0122a249..380cb06d 100644 --- a/src/squidpy/gr/__init__.py +++ b/src/squidpy/gr/__init__.py @@ -5,7 +5,7 @@ from squidpy.gr._build import mask_graph, spatial_neighbors from squidpy.gr._ligrec import ligrec from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment -from squidpy.gr._niche import calculate_niche +from squidpy.gr._niche import calculate_niche, build_consensus_niche from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr from squidpy.gr._ripley import ripley from squidpy.gr._sepal import sepal From 5c9765181bc97bc47d7e3f5fcdc4d7055bd60c40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:21:40 +0000 Subject: [PATCH 29/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/__init__.py b/src/squidpy/gr/__init__.py index 380cb06d..ac558368 100644 --- a/src/squidpy/gr/__init__.py +++ b/src/squidpy/gr/__init__.py @@ -5,7 +5,7 @@ from squidpy.gr._build import mask_graph, spatial_neighbors from squidpy.gr._ligrec import ligrec from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment -from squidpy.gr._niche import calculate_niche, build_consensus_niche +from squidpy.gr._niche import build_consensus_niche, calculate_niche from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr from squidpy.gr._ripley import ripley from squidpy.gr._sepal import sepal From 2b47e81813a0623fb23cb454188fd18fa57adf8c Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 27 Aug 2024 13:56:43 +0200 Subject: [PATCH 30/67] Fix _get_initial_niches --- src/squidpy/gr/_niche.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index baa33260..3726d0df 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -2,7 +2,6 @@ import itertools from collections.abc import Iterator -from typing import Any, Optional import anndata as ad import numpy as np @@ -418,14 +417,14 @@ def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclu def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[str, set[str]]: min_niches = {} - min_niche_count = 0 + min_niche_count = float('inf') for niches in niche_definitions: niche_count = len(niches) - if niche_count < min_niche_count: # If this dictionary has fewer keys - min_niches = niches # Update the dictionary with the fewest keys - min_niche_count = niche_count # Update the minimum key count + if niche_count < min_niche_count: + min_niches = niches + min_niche_count = niche_count return min_niches @@ -494,9 +493,7 @@ def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: s sorted_by_jaccard = dict( sorted(consensus_by_jaccard.items(), key=lambda item: item[1], reverse=True), ) - sorted_consensus = {key: consensus[key] for key in sorted_by_jaccard} - filtered_consensus = _filter_overlap(sorted_consensus) adata.obs["consensus_niche"] = adata.obs.index.map(filtered_consensus).fillna("None") From cf3b51f03130d8e0fb6e14de420602c7f8801244 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 27 Aug 2024 11:57:07 +0000 Subject: [PATCH 31/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 3726d0df..64cef979 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -417,14 +417,14 @@ def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclu def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[str, set[str]]: min_niches = {} - min_niche_count = float('inf') + min_niche_count = float("inf") for niches in niche_definitions: niche_count = len(niches) - if niche_count < min_niche_count: - min_niches = niches - min_niche_count = niche_count + if niche_count < min_niche_count: + min_niches = niches + min_niche_count = niche_count return min_niches From 14ff9a5751b71356f9e85905e83544b1989b48c3 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 27 Aug 2024 14:08:43 +0200 Subject: [PATCH 32/67] Add docstring --- src/squidpy/gr/_niche.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 3726d0df..ad24f726 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -403,6 +403,9 @@ def _get_subset_indices(df: pd.DataFrame, column: str) -> pd.Series: def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclude: set[str]) -> tuple[str, float]: + """Find best matching niche pair between two sets of niche definitions. + Niches which have already been matched, are excluded from further comparisons.""" + best_match = "" max_overlap = 0.0 for other_subset, indices in other_subsets.items(): @@ -416,6 +419,8 @@ def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclu def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[str, set[str]]: + """Select the niche definition with the fewest amount of unique niches.""" + min_niches = {} min_niche_count = float('inf') @@ -430,20 +435,34 @@ def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[st def _filter_overlap(initial_consensus: dict[str, set[str]]) -> dict[str, str]: + """"Remove labels which are present in multiple niches. Labels are always kept in the niche with higher average jaccard index.""" + filtered_consensus = {} processed_elements: set[str] = set() for key, values in initial_consensus.items(): unique_values = values - processed_elements # Remove already processed elements for value in unique_values: - filtered_consensus[value] = key # Add the value as the new key, with the original key as its value - processed_elements.update(unique_values) # Mark these values as processed + filtered_consensus[value] = key # Swap key and value to make further processing easier + processed_elements.update(unique_values) # Mark value as processed return filtered_consensus def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: str = "union") -> AnnData: - """Given multiple niche definitions, construct a consensus niche using set matching.""" + """Given multiple niche definitions, construct a consensus niche using set matching. + Each niche definition is treated as a set of subsets. For each subset in set A we look for the best matching subset in set B. + Once a match has been found, these sets are merged either by union or intersection. This merged set is then used as the new set A for the next iteration. + The final consensus niches are filtered for overlapping labels and stored as a new column in `adata.obs`. + Parameters + ---------- + %(adata)s + niche_definitions + Name of columns in `adata.obs` where previously calculated niches are stored. + merge + - `{c.union.s!r}`- merge niche matches via union join. + - `{c.intersection.s!r}` - merge niche matches by their intersection. + """ list_of_sets = [] for definition in niche_definitions: @@ -451,7 +470,7 @@ def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: s union_of_matches = _get_initial_niches(list_of_sets) - avg_jaccard = np.zeros(len(union_of_matches)) + avg_jaccard = np.zeros(len(union_of_matches)) # the jaccard index is tracked to order the consensus niches later on for set_of_sets in range(len(list_of_sets) - 1): current_matches = {} From 4457929e0b9bb6f2b789c45728dc2309c8034acb Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 27 Aug 2024 14:09:25 +0200 Subject: [PATCH 33/67] Update --- src/squidpy/gr/_niche.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index ad24f726..df8b6815 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -422,20 +422,20 @@ def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[st """Select the niche definition with the fewest amount of unique niches.""" min_niches = {} - min_niche_count = float('inf') + min_niche_count = float("inf") for niches in niche_definitions: niche_count = len(niches) - if niche_count < min_niche_count: - min_niches = niches - min_niche_count = niche_count + if niche_count < min_niche_count: + min_niches = niches + min_niche_count = niche_count return min_niches def _filter_overlap(initial_consensus: dict[str, set[str]]) -> dict[str, str]: - """"Remove labels which are present in multiple niches. Labels are always kept in the niche with higher average jaccard index.""" + """ "Remove labels which are present in multiple niches. Labels are always kept in the niche with higher average jaccard index.""" filtered_consensus = {} processed_elements: set[str] = set() @@ -470,7 +470,7 @@ def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: s union_of_matches = _get_initial_niches(list_of_sets) - avg_jaccard = np.zeros(len(union_of_matches)) # the jaccard index is tracked to order the consensus niches later on + avg_jaccard = np.zeros(len(union_of_matches)) # the jaccard index is tracked to order the consensus niches later on for set_of_sets in range(len(list_of_sets) - 1): current_matches = {} From f449d88c7818c93462be28eaaba13de5be12c6f8 Mon Sep 17 00:00:00 2001 From: Tim Treis Date: Thu, 29 Aug 2024 21:48:10 -0400 Subject: [PATCH 34/67] some type fixes for mypy --- src/squidpy/_utils.py | 8 ++------ src/squidpy/gr/_ppatterns.py | 22 ++++++++++------------ src/squidpy/gr/_sepal.py | 7 +++++-- src/squidpy/pl/_utils.py | 7 +++---- 4 files changed, 20 insertions(+), 24 deletions(-) diff --git a/src/squidpy/_utils.py b/src/squidpy/_utils.py index fb07dce0..9092f4fa 100644 --- a/src/squidpy/_utils.py +++ b/src/squidpy/_utils.py @@ -37,13 +37,9 @@ def wrapper(*args: Any, **kw: Any) -> Any: return wrapper -try: - from numpy.typing import NDArray +from numpy.typing import NDArray - NDArrayA = NDArray[Any] -except (ImportError, TypeError): - NDArray = np.ndarray # type: ignore[misc] - NDArrayA = np.ndarray # type: ignore[misc] +NDArrayA = NDArray[Any] class SigQueue(Queue["Signal"] if TYPE_CHECKING else Queue): # type: ignore[misc] diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index a785706d..2764b2c0 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -53,7 +53,7 @@ def spatial_autocorr( adata: AnnData | SpatialData, connectivity_key: str = Key.obsp.spatial_conn(), genes: str | int | Sequence[str] | Sequence[int] | None = None, - mode: Literal["moran", "geary"] = SpatialAutocorr.MORAN.s, # type: ignore[assignment] + mode: SpatialAutocorr | Literal["moran", "geary"] = "moran", transformation: bool = True, n_perms: int | None = None, two_tailed: bool = False, @@ -164,22 +164,20 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr if layer not in adata.obsm: raise KeyError(f"Key `{layer!r}` not found in `adata.obsm`.") if ixs is None: - ixs = np.arange(adata.obsm[layer].shape[1]) # type: ignore[assignment] + ixs = list(np.arange(adata.obsm[layer].shape[1])) ixs = list(np.ravel([ixs])) return adata.obsm[layer][:, ixs].T, ixs if attr == "X": - vals, index = extract_X(adata, genes) # type: ignore[arg-type] + vals, index = extract_X(adata, genes) # type: ignore elif attr == "obs": - vals, index = extract_obs(adata, genes) # type: ignore[arg-type] + vals, index = extract_obs(adata, genes) # type: ignore elif attr == "obsm": - vals, index = extract_obsm(adata, genes) # type: ignore[arg-type] + vals, index = extract_obsm(adata, genes) # type: ignore else: raise NotImplementedError(f"Extracting from `adata.{attr}` is not yet implemented.") - mode = SpatialAutocorr(mode) # type: ignore[assignment] - if TYPE_CHECKING: - assert isinstance(mode, SpatialAutocorr) + mode = SpatialAutocorr(mode) params = {"mode": mode.s, "transformation": transformation, "two_tailed": two_tailed} if mode == SpatialAutocorr.MORAN: @@ -199,7 +197,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr if transformation: # row-normalize normalize(g, norm="l1", axis=1, copy=False) - score = params["func"](g, vals) + score = params["func"](g, vals) # type: ignore n_jobs = _get_n_cores(n_jobs) start = logg.info(f"Calculating {mode}'s statistic for `{n_perms}` permutations using `{n_jobs}` core(s)") @@ -425,10 +423,10 @@ def co_occurrence( n_splits = max(min(n_splits, n_obs), 1) # split array and labels - spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s)) # type: ignore[arg-type] - labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s)) # type: ignore[arg-type] + spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s)) + labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s)) # create idx array including unique combinations and self-comparison - x, y = np.triu_indices_from(np.empty((n_splits, n_splits))) # type: ignore[arg-type] + x, y = np.triu_indices_from(np.empty((n_splits, n_splits))) idx_splits = list(zip(x, y)) n_jobs = _get_n_cores(n_jobs) diff --git a/src/squidpy/gr/_sepal.py b/src/squidpy/gr/_sepal.py index 57a44cfd..7860c1bf 100644 --- a/src/squidpy/gr/_sepal.py +++ b/src/squidpy/gr/_sepal.py @@ -182,8 +182,11 @@ def _score_helper( score, sparse = [], issparse(vals) for i in ixs: - conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy() # type: ignore[union-attr] - conc = vals[:, i].toarray().flatten() if sparse else vals[:, i].copy() # type: ignore[union-attr] + if sparse and isinstance(vals, spmatrix): + conc = vals[:, i].toarray().flatten() # Safe to call toarray() + else: + conc = vals[:, i].copy() # vals is assumed to be a NumPy array here + time_iter = _diffusion(conc, fun, n_iter, sat, sat_idx, unsat, unsat_idx, dt=dt, thresh=thresh) score.append(dt * time_iter) diff --git a/src/squidpy/pl/_utils.py b/src/squidpy/pl/_utils.py index b30aa297..79aacaf7 100644 --- a/src/squidpy/pl/_utils.py +++ b/src/squidpy/pl/_utils.py @@ -206,11 +206,10 @@ def _min_max_norm(vec: spmatrix | NDArrayA) -> NDArrayA: if vec.ndim != 1: raise ValueError(f"Expected `1` dimension, found `{vec.ndim}`.") - maxx, minn = np.nanmax(vec), np.nanmin(vec) + maxx: float = np.nanmax(vec) + minn: float = np.nanmin(vec) - return ( # type: ignore[no-any-return] - np.ones_like(vec) if np.isclose(minn, maxx) else ((vec - minn) / (maxx - minn)) - ) + return np.ones_like(vec) if np.isclose(minn, maxx) else ((vec - minn) / (maxx - minn)) def _ensure_dense_vector(fn: Callable[..., Vector_name_t]) -> Callable[..., Vector_name_t]: From dcdd9f690475579f684f94ed7267c84bcb6f1624 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 1 Oct 2024 22:29:50 +0200 Subject: [PATCH 35/67] Update utag --- src/squidpy/gr/_niche.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index df8b6815..4fd2689b 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -27,6 +27,8 @@ def calculate_niche( library_key: str | None = None, table_key: str | None = None, abs_nhood: bool = False, + n_neighbors: int = 15, + resolutions: int | list[float] | None = None, adj_subsets: list[int] | None = None, aggregation: str = "mean", spatial_key: str = "spatial", @@ -57,10 +59,16 @@ def calculate_niche( Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. spatial_key Location of spatial coordinates in `adata.obsm`. + n_neighbors + Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm. + Required if flavor == 'neighborhood' or flavor == 'UTAG'. + resolutions + List of resolutions to use for leiden clustering. + Required if flavor == 'neighborhood' or flavor == 'UTAG'. %(copy)s """ - # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present + # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present if no table is specified is_sdata = False if isinstance(adata, SpatialData): is_sdata = True @@ -111,13 +119,19 @@ def calculate_niche( elif flavor == "utag": new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) - if copy: - return new_feature_matrix + adata_utag = ad.AnnData(X=new_feature_matrix) + sc.tl.pca(adata_utag) + sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca") + + if resolutions is not None: + if not isinstance(resolutions, list): + resolutions = [resolutions] else: - if is_sdata: - sdata.tables[f"{flavor}_niche"] = new_feature_matrix - else: - adata.layers["utag"] = new_feature_matrix + raise ValueError("Please provide resolutions for leiden clustering.") + + for res in resolutions: + sc.tl.leiden(adata_utag, resolution=res, key_added=f"utag_res={res}") + adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values elif flavor == "cellcharter": adj_matrix_subsets = [] @@ -201,7 +215,7 @@ def _calculate_neighborhood_profile( def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: - """Performas inner product of adjacency matrix and feature matrix, + """Performs inner product of adjacency matrix and feature matrix, such that each observation inherits features from its immediate neighbors as described in UTAG paper. Parameters From 6638075b5702601df3825410c4b21626943c50f9 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 2 Oct 2024 15:16:50 +0200 Subject: [PATCH 36/67] Update neighborhood profile based approach --- src/squidpy/gr/_niche.py | 81 +++++++++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 17 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 4fd2689b..f543b21a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -26,9 +26,13 @@ def calculate_niche( flavor: str = "neighborhood", library_key: str | None = None, table_key: str | None = None, - abs_nhood: bool = False, + mask: pd.core.series.Series = None, n_neighbors: int = 15, resolutions: int | list[float] | None = None, + subset_groups: list[str] | None = None, + min_niche_size: int | None = None, + scale: bool = True, + abs_nhood: bool = False, adj_subsets: list[int] | None = None, aggregation: str = "mean", spatial_key: str = "spatial", @@ -57,14 +61,26 @@ def calculate_niche( Restrict niche calculation to a subset of the data. table_key Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. - spatial_key - Location of spatial coordinates in `adata.obsm`. + mask + Boolean array to filter cells which won't get assigned to a niche. + Note that if you want to exclude these cells during neighborhood calculation already, you should subset your AnnData table before running 'sq.gr.spatial_neigbors'. n_neighbors Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm. Required if flavor == 'neighborhood' or flavor == 'UTAG'. resolutions List of resolutions to use for leiden clustering. Required if flavor == 'neighborhood' or flavor == 'UTAG'. + subset_groups + Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. + Optional if flavor == 'neighborhood'. + min_niche_size + Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. + Optional if flavor == 'neighborhood'. + scale + If 'True', compute z-scores of neighborhood profiles. + Optional if flavor == 'neighborhood'. + spatial_key + Location of spatial coordinates in `adata.obsm`. %(copy)s """ @@ -101,21 +117,40 @@ def calculate_niche( if flavor == "neighborhood": rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata, groups, spatial_connectivities_key + adata, groups, subset_groups, spatial_connectivities_key ) if not abs_nhood: - nhood_profile = rel_nhood_profile + adata_neighborhood = ad.AnnData(X=rel_nhood_profile) else: - nhood_profile = abs_nhood_profile - df = pd.DataFrame(nhood_profile, index=adata.obs.index) - nhood_table = _df_to_adata(df) - if copy: - return df + adata_neighborhood = ad.AnnData(X=abs_nhood_profile) + + if scale: + sc.pp.scale(adata_neighborhood, zero_center=True) + + if mask is not None: + if subset_groups is not None: + mask = mask[mask.index.isin(adata_neighborhood.obs.index)] + adata_neighborhood = adata_neighborhood[mask] + + sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") + + if resolutions is not None: + if not isinstance(resolutions, list): + resolutions = [resolutions] else: - if is_sdata: - sdata.tables[f"{flavor}_niche"] = nhood_table - else: - adata.obsm["neighborhood_profile"] = df + raise ValueError("Please provide resolutions for leiden clustering.") + + #adata_neighborhood.index = subset_index + + for res in resolutions: + sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") + print(adata_neighborhood.obs) + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map(adata_neighborhood.obs[f"neighborhood_niche_res={res}"]).fillna('not_a_niche') + if min_niche_size is not None: + counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() + to_filter = counts_by_niche[counts_by_niche < min_niche_size].index + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply(lambda x: 'not_a_niche' if x in to_filter else x) + elif flavor == "utag": new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) @@ -180,10 +215,21 @@ def calculate_niche( def _calculate_neighborhood_profile( - adata: AnnData | SpatialData, + adata: AnnData, groups: str, + subset_groups: list[str] | None, spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: + + if subset_groups: + adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() + obs_mask = ~adata.obs[groups].isin(subset_groups) + adata = adata[obs_mask] + + # Update adjacency matrix such that it only contains connections to filtered observations + adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] + adata.obsp[spatial_connectivities_key] = adjacency_matrix.tocsr() + # get obs x neighbor matrix from sparse matrix matrix = adata.obsp[spatial_connectivities_key].tocoo() nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) @@ -203,7 +249,7 @@ def _calculate_neighborhood_profile( cat_indices = {category: index for index, category in enumerate(unique_categories)} cat_values = np.vectorize(cat_indices.get)(cat_by_id) - # For each obs calculate absolute frequency for all (not just k) categories, given the subset of categories present in obs x k matrix + # get obx x category matrix where each column is the absolute amount of a category in the neighborhood m, k = cat_by_id.shape abs_freq = np.zeros((m, len(unique_categories)), dtype=int) np.add.at(abs_freq, (np.arange(m)[:, None], cat_values), 1) @@ -211,7 +257,8 @@ def _calculate_neighborhood_profile( # normalize by n_neighbors to get relative frequency of each category rel_freq = abs_freq / k - return rel_freq, abs_freq + return pd.DataFrame(rel_freq, index=adata.obs.index), pd.DataFrame(abs_freq, index=adata.obs.index) + def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: From ae3b2f80709c0215692c1042bd295d6d354386d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 13:17:16 +0000 Subject: [PATCH 37/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_niche.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index f543b21a..187203e6 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -72,7 +72,7 @@ def calculate_niche( Required if flavor == 'neighborhood' or flavor == 'UTAG'. subset_groups Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. - Optional if flavor == 'neighborhood'. + Optional if flavor == 'neighborhood'. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. Optional if flavor == 'neighborhood'. @@ -126,12 +126,12 @@ def calculate_niche( if scale: sc.pp.scale(adata_neighborhood, zero_center=True) - + if mask is not None: if subset_groups is not None: mask = mask[mask.index.isin(adata_neighborhood.obs.index)] adata_neighborhood = adata_neighborhood[mask] - + sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") if resolutions is not None: @@ -139,18 +139,21 @@ def calculate_niche( resolutions = [resolutions] else: raise ValueError("Please provide resolutions for leiden clustering.") - - #adata_neighborhood.index = subset_index + + # adata_neighborhood.index = subset_index for res in resolutions: sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") print(adata_neighborhood.obs) - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map(adata_neighborhood.obs[f"neighborhood_niche_res={res}"]).fillna('not_a_niche') + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( + adata_neighborhood.obs[f"neighborhood_niche_res={res}"] + ).fillna("not_a_niche") if min_niche_size is not None: counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() to_filter = counts_by_niche[counts_by_niche < min_niche_size].index - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply(lambda x: 'not_a_niche' if x in to_filter else x) - + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply( + lambda x: "not_a_niche" if x in to_filter else x + ) elif flavor == "utag": new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) @@ -220,16 +223,15 @@ def _calculate_neighborhood_profile( subset_groups: list[str] | None, spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: - if subset_groups: adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() obs_mask = ~adata.obs[groups].isin(subset_groups) - adata = adata[obs_mask] + adata = adata[obs_mask] # Update adjacency matrix such that it only contains connections to filtered observations adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] adata.obsp[spatial_connectivities_key] = adjacency_matrix.tocsr() - + # get obs x neighbor matrix from sparse matrix matrix = adata.obsp[spatial_connectivities_key].tocoo() nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) @@ -260,7 +262,6 @@ def _calculate_neighborhood_profile( return pd.DataFrame(rel_freq, index=adata.obs.index), pd.DataFrame(abs_freq, index=adata.obs.index) - def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: """Performs inner product of adjacency matrix and feature matrix, such that each observation inherits features from its immediate neighbors as described in UTAG paper. From 646dbb0c857d23746f73ba5f687f6af8956f182a Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 2 Oct 2024 15:30:04 +0200 Subject: [PATCH 38/67] Update doctstring --- src/squidpy/gr/_niche.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index f543b21a..848832c1 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -72,15 +72,28 @@ def calculate_niche( Required if flavor == 'neighborhood' or flavor == 'UTAG'. subset_groups Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. - Optional if flavor == 'neighborhood'. + Optional if flavor == 'neighborhood'. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. Optional if flavor == 'neighborhood'. scale If 'True', compute z-scores of neighborhood profiles. Optional if flavor == 'neighborhood'. + abs_nhood + If 'True', calculate niches based on absolute neighborhood profile. + Optional if flavor == 'neighborhood'. + adj_subsets + List of adjacency matrices to use e.g. [1,2,3] for 1,2,3 neighbors respectively. + Required if flavor == 'cellcharter'. + aggregation + How to aggregate count matrices. Either 'mean' or 'variance'. + Required if flavor == 'cellcharter'. spatial_key Location of spatial coordinates in `adata.obsm`. + spatial_connectivities_key + Key in `adata.obsp` where spatial connectivities are stored. + spatial_distances_key + Key in `adata.obsp` where spatial distances are stored. %(copy)s """ @@ -126,12 +139,12 @@ def calculate_niche( if scale: sc.pp.scale(adata_neighborhood, zero_center=True) - + if mask is not None: if subset_groups is not None: mask = mask[mask.index.isin(adata_neighborhood.obs.index)] adata_neighborhood = adata_neighborhood[mask] - + sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") if resolutions is not None: @@ -139,18 +152,20 @@ def calculate_niche( resolutions = [resolutions] else: raise ValueError("Please provide resolutions for leiden clustering.") - - #adata_neighborhood.index = subset_index + + # adata_neighborhood.index = subset_index for res in resolutions: sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") - print(adata_neighborhood.obs) - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map(adata_neighborhood.obs[f"neighborhood_niche_res={res}"]).fillna('not_a_niche') + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( + adata_neighborhood.obs[f"neighborhood_niche_res={res}"] + ).fillna("not_a_niche") if min_niche_size is not None: counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() to_filter = counts_by_niche[counts_by_niche < min_niche_size].index - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply(lambda x: 'not_a_niche' if x in to_filter else x) - + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply( + lambda x, to_filter=to_filter: "not_a_niche" if x in to_filter else x + ) elif flavor == "utag": new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) @@ -220,16 +235,15 @@ def _calculate_neighborhood_profile( subset_groups: list[str] | None, spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: - if subset_groups: adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() obs_mask = ~adata.obs[groups].isin(subset_groups) - adata = adata[obs_mask] + adata = adata[obs_mask] # Update adjacency matrix such that it only contains connections to filtered observations adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] adata.obsp[spatial_connectivities_key] = adjacency_matrix.tocsr() - + # get obs x neighbor matrix from sparse matrix matrix = adata.obsp[spatial_connectivities_key].tocoo() nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) @@ -260,7 +274,6 @@ def _calculate_neighborhood_profile( return pd.DataFrame(rel_freq, index=adata.obs.index), pd.DataFrame(abs_freq, index=adata.obs.index) - def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: """Performs inner product of adjacency matrix and feature matrix, such that each observation inherits features from its immediate neighbors as described in UTAG paper. From c9b4dc5ae7e580f50c8a6c9e53d9438d021dabdd Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 21:44:14 +0200 Subject: [PATCH 39/67] Update CellCharter approach --- src/squidpy/gr/_niche.py | 220 +++++++++++++++++++++++++++------------ 1 file changed, 155 insertions(+), 65 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index ed21ed33..b231dd2a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -2,17 +2,20 @@ import itertools from collections.abc import Iterator +from typing import Any import anndata as ad import numpy as np import pandas as pd import scanpy as sc +import scipy.sparse as sps from anndata import AnnData -from scipy.sparse import csr_matrix, hstack +from scipy.sparse import csr_matrix, hstack, issparse, spdiags from scipy.stats import ranksums from sklearn import metrics from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score -from sklearn.preprocessing import normalize +from sklearn.mixture import GaussianMixture +from sklearn.preprocessing import StandardScaler, normalize from spatialdata import SpatialData from squidpy._utils import NDArrayA @@ -33,8 +36,10 @@ def calculate_niche( min_niche_size: int | None = None, scale: bool = True, abs_nhood: bool = False, - adj_subsets: list[int] | None = None, + adj_subsets: int | list[int] | None = None, aggregation: str = "mean", + n_components: int = 3, + random_state: int = 42, spatial_key: str = "spatial", spatial_connectivities_key: str = "spatial_connectivities", spatial_distances_key: str = "spatial_distances", @@ -53,9 +58,8 @@ def calculate_niche( - `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - `{c.SPOT.s!r}` - calculate niches using optimal transport. - `{c.BANKSY.s!r}`- use Banksy algorithm. - - `{c.CELLCHARTER.s!r}` - use cellcharter. + - `{c.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. - `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication). - - `{c.ALL.s!r}` - apply all available methods and compare them using cluster validation scores. %(library_key)s subset Restrict niche calculation to a subset of the data. @@ -89,6 +93,12 @@ def calculate_niche( aggregation How to aggregate count matrices. Either 'mean' or 'variance'. Required if flavor == 'cellcharter'. + n_components + Number of components to use for GMM. + Required if flavor == 'cellcharter'. + random_state + Random state to use for GMM. + Required if flavor == 'cellcharter'. spatial_key Location of spatial coordinates in `adata.obsm`. spatial_connectivities_key @@ -99,11 +109,8 @@ def calculate_niche( """ # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present if no table is specified - is_sdata = False if isinstance(adata, SpatialData): - is_sdata = True if table_key is not None: - sdata = adata adata = adata.tables[table_key].copy() else: if len(adata.tables) > 1: @@ -183,49 +190,40 @@ def calculate_niche( adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values elif flavor == "cellcharter": - adj_matrix_subsets = [] - if isinstance(adj_subsets, list): - for k in adj_subsets: - if k == 0: - adj_matrix_subsets.append(adata.obsp[spatial_connectivities_key]) - else: - adj_matrix_subsets.append( - _get_adj_matrix_subsets( - adata.obsp[spatial_connectivities_key], adata.obsp[spatial_distances_key], k - ) - ) - if aggregation == "mean": - inner_products = [adj_subset.dot(adata.X) for adj_subset in adj_matrix_subsets] - elif aggregation == "variance": - inner_products = [ - _aggregate_var(matrix, adata.obsp[spatial_connectivities_key], adata) for matrix in inner_products - ] + adjacency_matrix = adata.obsp[spatial_connectivities_key] + if not isinstance(adj_subsets, list): + if adj_subsets is not None: + adj_subsets = list(range(adj_subsets + 1)) else: raise ValueError( - f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'." + "flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use." ) - concatenated_matrix = hstack(inner_products) - - # create df from sparse matrix - arr = concatenated_matrix.toarray() - df = pd.DataFrame(arr, index=adata.obs.index) - col_names = [] - for A_i in adj_subsets: - for var in adata.var_names: - col_names.append(f"{var}_Adj_{A_i}") - df.columns = col_names - - if copy: - return concatenated_matrix + + aggregated_matrices = [] + adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0 + adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors + for k in adj_subsets: + if k == 0: + # If k == 0, we're using the original cell features (no neighbors) + aggregated_matrices.append(adata.X) else: - if is_sdata: - sdata.tables[f"{flavor}_niche"] = ad.AnnData(concatenated_matrix) - else: - adata.obsm[f"{flavor}_niche"] = df - else: - raise ValueError( - "Flavor 'cellcharter' requires list of neighbors to build adjacency matrices. Please provide a list of k_neighbors for 'adj_subsets'." - ) + if k > 1: + adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited) + + adj_hop_norm = _normalize(adj_hop) # Normalize adjacency matrix for current hop + + # Apply aggregation, default to "mean" unless specified otherwise + aggregated_matrix = _aggregate(adata, adj_hop_norm, aggregation) + + # Collect the aggregated matrices + aggregated_matrices.append(aggregated_matrix) + + concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally + arr = concatenated_matrix.toarray() # Densify the sparse matrix + + niches = _get_GMM_clusters(arr, n_components, random_state) + + adata.obs[f"{flavor}_niche"] = pd.Categorical(niches) def _calculate_neighborhood_profile( @@ -293,28 +291,120 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> return adjacency_matrix @ adata.X -def _get_adj_matrix_subsets(connectivities: csr_matrix, distances: csr_matrix, k_neighbors: int) -> csr_matrix: - # Convert the distance matrix to a dense format - dist_dense = distances.todense() +def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix: + if issparse(adjacency_matrix): + adjacency_matrix = adjacency_matrix.tolil() + adjacency_matrix.setdiag(value) + adjacency_matrix = adjacency_matrix.tocsr() + if value == 0: + adjacency_matrix.eliminate_zeros() + return adjacency_matrix + + +def _hop( + adj_hop: sps.spmatrix, adj: sps.spmatrix, adj_visited: sps.spmatrix = None +) -> tuple[sps.spmatrix, sps.spmatrix]: + adj_hop = adj_hop @ adj + + if adj_visited is not None: + adj_hop = adj_hop > adj_visited + adj_visited = adj_visited + adj_hop + + return adj_hop, adj_visited + + +def _normalize(adj: sps.spmatrix) -> sps.spmatrix: + deg = np.array(np.sum(adj, axis=1)).squeeze() + with np.errstate(divide="ignore"): + deg_inv = 1 / deg + deg_inv[deg_inv == float("inf")] = 0 + + return spdiags(deg_inv, 0, len(deg_inv), len(deg_inv)) * adj + + +def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggregation: str = "mean") -> Any: + if aggregation == "mean": + aggregated_matrix = normalized_adjacency_matrix @ adata.X + elif aggregation == "variance": + mean_matrix = normalized_adjacency_matrix @ adata.X + mean_squared_matrix = normalized_adjacency_matrix @ (adata.X * adata.X) + aggregated_matrix = mean_squared_matrix - mean_matrix * mean_matrix + else: + raise ValueError(f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'.") + + return aggregated_matrix + + +def _get_GMM_clusters(A: np.ndarray[float, Any], n_components: int, random_state: int) -> Any: + """Returns niche labels generated by GMM clustering. + Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis.""" + + gmm = GaussianMixture(n_components=n_components, random_state=random_state, init_params="random_from_data") + gmm.fit(A) + labels = gmm.predict(A) + + return labels + # scaler = StandardScaler() + # A = scaler.fit_transform(A) + + # results = {} + + # for k in range(k_min, k_max + 1): + # fmi_k_minus_1 = [] + # fmi_k_plus_1 = [] + + # previous_fmi_diff = np.inf + + # for r in range(R): + # # Clustering with K-1, K, and K+1 clusters respectively + # gmm_k_minus_1 = GaussianMixture(n_components=k-1, random_state=r, n_init=10, init_params="random_from_data").fit(A) if k > 1 else None + # gmm_k = GaussianMixture(n_components=k, random_state=r, n_init=10, init_params="random_from_data").fit(A) + # gmm_k_plus_1 = GaussianMixture(n_components=k+1, random_state=r, n_init=10, init_params="random_from_data").fit(A) + + # labels_k_minus_1 = gmm_k_minus_1.predict(A) if k > 1 else None + # labels_k = gmm_k.predict(A) + # labels_k_plus_1 = gmm_k_plus_1.predict(A) + + # # Calculate FMI between K-1 and K, and K and K+1 + # if k > 1: + # fmi_k_minus_1.append(fowlkes_mallows_score(labels_k_minus_1, labels_k)) + # fmi_k_plus_1.append(fowlkes_mallows_score(labels_k, labels_k_plus_1)) + + # # Compute mean FMIs + # mean_fmi_k_minus_1 = np.mean(fmi_k_minus_1) if fmi_k_minus_1 else None + # mean_fmi_k_plus_1 = np.mean(fmi_k_plus_1) + + # # Check convergence (mean average percentage error) + # current_fmi_diff = abs(mean_fmi_k_plus_1 - (mean_fmi_k_minus_1 or 0)) + # if r > 0 and previous_fmi_diff - current_fmi_diff < tolerance: + # print(f"Converged for K={k} after {r+1} runs.") + # break + + # previous_fmi_diff = current_fmi_diff + + # # Store the results for this K + # results[k] = { + # "fmi_k_minus_1": mean_fmi_k_minus_1, + # "fmi_k_plus_1": mean_fmi_k_plus_1, + # "best_model": gmm_k + # } - # Find the indices of the k closest neighbors for each row - closest_neighbors_indices = np.argsort(dist_dense, axis=1)[:, :k_neighbors] + # # Now find the K with the most stable clustering (highest average FMI) + # optimal_k = None + # best_model = None + # best_stability = -np.inf - # Initialize lists to collect data for the new sparse matrix - rows = [] - cols = [] - data = [] + # for k in range(k_min, k_max + 1): + # if k > 1: + # avg_fmi = np.mean([results[k]["fmi_k_minus_1"], results[k]["fmi_k_plus_1"]]) + # if avg_fmi > best_stability: + # best_stability = avg_fmi + # optimal_k = k + # best_model = results[k]["best_model"] - # Iterate over each row to construct the new adjacency matrix - for row in range(dist_dense.shape[0]): - for col in closest_neighbors_indices[row].flat: - rows.append(row) - cols.append(col) - data.append(connectivities[row, col]) + # k_labels = best_model.predict(A) - # Create the new sparse matrix with the reduced neighbors - new_adj_matrix = csr_matrix((data, (rows, cols)), shape=connectivities.shape) - return new_adj_matrix + # return optimal_k, k_labels def _df_to_adata(df: pd.DataFrame) -> AnnData: From b1aa25fcb5935238473a333b937446c09e8e3b00 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 21:55:27 +0200 Subject: [PATCH 40/67] Remove commented-out code --- src/squidpy/gr/_niche.py | 61 ---------------------------------------- 1 file changed, 61 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index b231dd2a..8aaa6a3c 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -344,67 +344,6 @@ def _get_GMM_clusters(A: np.ndarray[float, Any], n_components: int, random_state labels = gmm.predict(A) return labels - # scaler = StandardScaler() - # A = scaler.fit_transform(A) - - # results = {} - - # for k in range(k_min, k_max + 1): - # fmi_k_minus_1 = [] - # fmi_k_plus_1 = [] - - # previous_fmi_diff = np.inf - - # for r in range(R): - # # Clustering with K-1, K, and K+1 clusters respectively - # gmm_k_minus_1 = GaussianMixture(n_components=k-1, random_state=r, n_init=10, init_params="random_from_data").fit(A) if k > 1 else None - # gmm_k = GaussianMixture(n_components=k, random_state=r, n_init=10, init_params="random_from_data").fit(A) - # gmm_k_plus_1 = GaussianMixture(n_components=k+1, random_state=r, n_init=10, init_params="random_from_data").fit(A) - - # labels_k_minus_1 = gmm_k_minus_1.predict(A) if k > 1 else None - # labels_k = gmm_k.predict(A) - # labels_k_plus_1 = gmm_k_plus_1.predict(A) - - # # Calculate FMI between K-1 and K, and K and K+1 - # if k > 1: - # fmi_k_minus_1.append(fowlkes_mallows_score(labels_k_minus_1, labels_k)) - # fmi_k_plus_1.append(fowlkes_mallows_score(labels_k, labels_k_plus_1)) - - # # Compute mean FMIs - # mean_fmi_k_minus_1 = np.mean(fmi_k_minus_1) if fmi_k_minus_1 else None - # mean_fmi_k_plus_1 = np.mean(fmi_k_plus_1) - - # # Check convergence (mean average percentage error) - # current_fmi_diff = abs(mean_fmi_k_plus_1 - (mean_fmi_k_minus_1 or 0)) - # if r > 0 and previous_fmi_diff - current_fmi_diff < tolerance: - # print(f"Converged for K={k} after {r+1} runs.") - # break - - # previous_fmi_diff = current_fmi_diff - - # # Store the results for this K - # results[k] = { - # "fmi_k_minus_1": mean_fmi_k_minus_1, - # "fmi_k_plus_1": mean_fmi_k_plus_1, - # "best_model": gmm_k - # } - - # # Now find the K with the most stable clustering (highest average FMI) - # optimal_k = None - # best_model = None - # best_stability = -np.inf - - # for k in range(k_min, k_max + 1): - # if k > 1: - # avg_fmi = np.mean([results[k]["fmi_k_minus_1"], results[k]["fmi_k_plus_1"]]) - # if avg_fmi > best_stability: - # best_stability = avg_fmi - # optimal_k = k - # best_model = results[k]["best_model"] - - # k_labels = best_model.predict(A) - - # return optimal_k, k_labels def _df_to_adata(df: pd.DataFrame) -> AnnData: From 72d8dc81faba54382815853c4961d17bf2dfcf5b Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 22:01:15 +0200 Subject: [PATCH 41/67] Remove draft validation methods --- src/squidpy/gr/_niche.py | 138 +-------------------------------------- 1 file changed, 1 insertion(+), 137 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 8aaa6a3c..8e893f7a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -20,7 +20,7 @@ from squidpy._utils import NDArrayA -__all__ = ["calculate_niche", "build_consensus_niche"] +__all__ = ["calculate_niche"] def calculate_niche( @@ -483,139 +483,3 @@ def _iter_uid( yield adata[adata.obs[slide_key] == slide] else: yield adata - - -def _compare_niche_definitions(adata: AnnData, niche_definitions: list[str]) -> pd.DataFrame: - """Given different clustering results, compare them using different scores.""" - - result = pd.DataFrame(index=niche_definitions, columns=niche_definitions, data=None, dtype=float) - combinations = list(itertools.combinations_with_replacement(niche_definitions, 2)) - scores = {"ARI:": adjusted_rand_score, "NMI": normalized_mutual_info_score, "FMI": fowlkes_mallows_score} - - # for each score, apply it on all pairs of niche definitions - for score_name, score_func in scores.items(): - for pair in combinations: - score = score_func(adata.obs[pair[0]], adata.obs[pair[1]]) - result.at[pair[0], pair[1]] = score - result.at[pair[1], pair[0]] = score - adata.uns[f"niche_definition_comparison_{score_name}"] = result - - -def _get_subset_indices(df: pd.DataFrame, column: str) -> pd.Series: - return df.groupby(column).apply(lambda x: set(x.index)) - - -def _find_best_match(subset: set[str], other_subsets: dict[str, set[str]], exclude: set[str]) -> tuple[str, float]: - """Find best matching niche pair between two sets of niche definitions. - Niches which have already been matched, are excluded from further comparisons.""" - - best_match = "" - max_overlap = 0.0 - for other_subset, indices in other_subsets.items(): - if other_subset in exclude: - continue # Skip excluded matches - overlap = len(subset & indices) / len(subset | indices) # jaccard index - if overlap > max_overlap: - max_overlap = overlap - best_match = other_subset - return (best_match, max_overlap) - - -def _get_initial_niches(niche_definitions: list[dict[str, set[str]]]) -> dict[str, set[str]]: - """Select the niche definition with the fewest amount of unique niches.""" - - min_niches = {} - min_niche_count = float("inf") - - for niches in niche_definitions: - niche_count = len(niches) - - if niche_count < min_niche_count: - min_niches = niches - min_niche_count = niche_count - - return min_niches - - -def _filter_overlap(initial_consensus: dict[str, set[str]]) -> dict[str, str]: - """ "Remove labels which are present in multiple niches. Labels are always kept in the niche with higher average jaccard index.""" - - filtered_consensus = {} - processed_elements: set[str] = set() - - for key, values in initial_consensus.items(): - unique_values = values - processed_elements # Remove already processed elements - for value in unique_values: - filtered_consensus[value] = key # Swap key and value to make further processing easier - processed_elements.update(unique_values) # Mark value as processed - - return filtered_consensus - - -def build_consensus_niche(adata: AnnData, niche_definitions: list[str], merge: str = "union") -> AnnData: - """Given multiple niche definitions, construct a consensus niche using set matching. - Each niche definition is treated as a set of subsets. For each subset in set A we look for the best matching subset in set B. - Once a match has been found, these sets are merged either by union or intersection. This merged set is then used as the new set A for the next iteration. - The final consensus niches are filtered for overlapping labels and stored as a new column in `adata.obs`. - Parameters - ---------- - %(adata)s - niche_definitions - Name of columns in `adata.obs` where previously calculated niches are stored. - merge - - `{c.union.s!r}`- merge niche matches via union join. - - `{c.intersection.s!r}` - merge niche matches by their intersection. - """ - - list_of_sets = [] - for definition in niche_definitions: - list_of_sets.append(_get_subset_indices(adata.obs, definition)) - - union_of_matches = _get_initial_niches(list_of_sets) - - avg_jaccard = np.zeros(len(union_of_matches)) # the jaccard index is tracked to order the consensus niches later on - - for set_of_sets in range(len(list_of_sets) - 1): - current_matches = {} - used_matches: set[str] = set() - matches_A_B = { - subset: _find_best_match(indices, list_of_sets[set_of_sets + 1], exclude=used_matches) - for subset, indices in union_of_matches.items() - } - ranked_matches = sorted(matches_A_B.items(), key=lambda x: x[1][1], reverse=True) - for subset_A, (match, jaccard_index) in ranked_matches: - if match not in used_matches: - current_matches[subset_A] = (match, jaccard_index) - used_matches.add(match) - else: - new_match, new_jaccard = _find_best_match( - union_of_matches[subset_A], list_of_sets[set_of_sets + 1], exclude=used_matches - ) - if new_match: - current_matches[subset_A] = (new_match, new_jaccard) - used_matches.add(new_match) - - jaccard = np.asarray([jaccard_index for _, (_, jaccard_index) in current_matches.items()]) - avg_jaccard = (avg_jaccard + jaccard) / (set_of_sets + 1) - - if merge == "union": - consensus = { - subset_A: union_of_matches[subset_A] | list_of_sets[set_of_sets + 1][match] - for subset_A, (match, _) in current_matches.items() - } - if merge == "intersection": - consensus = { - subset_A: union_of_matches[subset_A] & list_of_sets[set_of_sets + 1][match] - for subset_A, (match, _) in current_matches.items() - } - - niche_categories = list(consensus.keys()) - consensus_by_jaccard = dict(zip(niche_categories, avg_jaccard)) - - sorted_by_jaccard = dict( - sorted(consensus_by_jaccard.items(), key=lambda item: item[1], reverse=True), - ) - sorted_consensus = {key: consensus[key] for key in sorted_by_jaccard} - filtered_consensus = _filter_overlap(sorted_consensus) - - adata.obs["consensus_niche"] = adata.obs.index.map(filtered_consensus).fillna("None") From a64d0fd6abe0e4da49ed91e119836763db88bced Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 23:06:35 +0200 Subject: [PATCH 42/67] Update init; Fix mypy --- src/squidpy/gr/__init__.py | 2 +- src/squidpy/gr/_niche.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/__init__.py b/src/squidpy/gr/__init__.py index ac558368..0122a249 100644 --- a/src/squidpy/gr/__init__.py +++ b/src/squidpy/gr/__init__.py @@ -5,7 +5,7 @@ from squidpy.gr._build import mask_graph, spatial_neighbors from squidpy.gr._ligrec import ligrec from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment -from squidpy.gr._niche import build_consensus_niche, calculate_niche +from squidpy.gr._niche import calculate_niche from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr from squidpy.gr._ripley import ripley from squidpy.gr._sepal import sepal diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 8e893f7a..e8d1452e 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -335,7 +335,7 @@ def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggreg return aggregated_matrix -def _get_GMM_clusters(A: np.ndarray[float, Any], n_components: int, random_state: int) -> Any: +def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_state: int) -> Any: """Returns niche labels generated by GMM clustering. Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis.""" From 5bde71290a7b3bb89186bd1af66049888a2d52db Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 23:17:23 +0200 Subject: [PATCH 43/67] Fix mypy --- src/squidpy/gr/_ppatterns.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 56bcc629..f34ff9c4 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Sequence from itertools import chain -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, Dict import numba.types as nt import numpy as np @@ -219,21 +219,24 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr with np.errstate(divide="ignore"): pval_results = _p_value_calc(score, score_perms, g, params) - - df = pd.DataFrame({params["stat"]: score, **pval_results}, index=index) + + data_dict: Dict[str, Dict[str, Any]] = {params["stat"]: score, **pval_results} + df = pd.DataFrame(data_dict, index=index) if corr_method is not None: for pv in filter(lambda x: "pval" in x, df.columns): _, pvals_adj, _, _ = multipletests(df[pv].values, alpha=0.05, method=corr_method) df[f"{pv}_{corr_method}"] = pvals_adj - + df.sort_values(by=params["stat"], ascending=params["ascending"], inplace=True) if copy: logg.info("Finish", time=start) return df - _save_data(adata, attr="uns", key=params["mode"] + params["stat"], data=df, time=start) + mode_str = str(params["mode"]) + stat_str = str(params["stat"]) + _save_data(adata, attr="uns", key=mode_str + stat_str, data=df, time=start) def _score_helper( From 79fcbd0f3c1dbe7d4184fef4994e3cc760334186 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 8 Oct 2024 21:17:47 +0000 Subject: [PATCH 44/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/squidpy/gr/_ppatterns.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index f34ff9c4..ce556bec 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Sequence from itertools import chain -from typing import TYPE_CHECKING, Any, Literal, Dict +from typing import TYPE_CHECKING, Any, Dict, Literal import numba.types as nt import numpy as np @@ -219,7 +219,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr with np.errstate(divide="ignore"): pval_results = _p_value_calc(score, score_perms, g, params) - + data_dict: Dict[str, Dict[str, Any]] = {params["stat"]: score, **pval_results} df = pd.DataFrame(data_dict, index=index) @@ -227,7 +227,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr for pv in filter(lambda x: "pval" in x, df.columns): _, pvals_adj, _, _ = multipletests(df[pv].values, alpha=0.05, method=corr_method) df[f"{pv}_{corr_method}"] = pvals_adj - + df.sort_values(by=params["stat"], ascending=params["ascending"], inplace=True) if copy: From 67a513315304d417d4ac50d8d09db0765f79d2d1 Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 23:19:01 +0200 Subject: [PATCH 45/67] Fix mypy --- src/squidpy/gr/_ppatterns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index ce556bec..87100d21 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -4,7 +4,7 @@ from collections.abc import Iterable, Sequence from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Literal +from typing import TYPE_CHECKING, Any, Literal import numba.types as nt import numpy as np @@ -220,7 +220,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr with np.errstate(divide="ignore"): pval_results = _p_value_calc(score, score_perms, g, params) - data_dict: Dict[str, Dict[str, Any]] = {params["stat"]: score, **pval_results} + data_dict: dict[str, dict[str, Any]] = {params["stat"]: score, **pval_results} df = pd.DataFrame(data_dict, index=index) if corr_method is not None: From bc387377848dddbca9dfdfb82b2e049679fe0aae Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 23:20:47 +0200 Subject: [PATCH 46/67] Fix mypy --- src/squidpy/gr/_ppatterns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 87100d21..49dd91b4 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -220,7 +220,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr with np.errstate(divide="ignore"): pval_results = _p_value_calc(score, score_perms, g, params) - data_dict: dict[str, dict[str, Any]] = {params["stat"]: score, **pval_results} + data_dict: dict[str, Any] = {params["stat"]: score, **pval_results} df = pd.DataFrame(data_dict, index=index) if corr_method is not None: From 342375cfadc9e44e67de637ef9a365dfe3aa11fc Mon Sep 17 00:00:00 2001 From: LLehner Date: Tue, 8 Oct 2024 23:25:06 +0200 Subject: [PATCH 47/67] Fix mypy --- src/squidpy/gr/_ppatterns.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 49dd91b4..ccd103d2 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -220,7 +220,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr with np.errstate(divide="ignore"): pval_results = _p_value_calc(score, score_perms, g, params) - data_dict: dict[str, Any] = {params["stat"]: score, **pval_results} + data_dict: dict[str, Any] = {str(params["stat"]): score, **pval_results} df = pd.DataFrame(data_dict, index=index) if corr_method is not None: From cb4e4d1ad52e5e42b3e2f1e543efc95ee0f63cb5 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 9 Oct 2024 17:11:53 +0200 Subject: [PATCH 48/67] Add comments; Remove draft evaluation function --- src/squidpy/gr/_niche.py | 169 +++++++++++++++------------------------ 1 file changed, 63 insertions(+), 106 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index e8d1452e..cc3be22c 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -25,11 +25,11 @@ def calculate_niche( adata: AnnData | SpatialData, - groups: str, flavor: str = "neighborhood", library_key: str | None = None, table_key: str | None = None, mask: pd.core.series.Series = None, + groups: str | None = None, n_neighbors: int = 15, resolutions: int | list[float] | None = None, subset_groups: list[str] | None = None, @@ -40,10 +40,7 @@ def calculate_niche( aggregation: str = "mean", n_components: int = 3, random_state: int = 42, - spatial_key: str = "spatial", spatial_connectivities_key: str = "spatial_connectivities", - spatial_distances_key: str = "spatial_distances", - copy: bool = False, ) -> AnnData | pd.DataFrame: """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods @@ -51,15 +48,13 @@ def calculate_niche( Parameters ---------- %(adata)s - groups - groups based on which to calculate neighborhood profile. flavor Method to use for niche calculation. Available options are: - `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - - `{c.SPOT.s!r}` - calculate niches using optimal transport. - - `{c.BANKSY.s!r}`- use Banksy algorithm. - - `{c.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. - `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication). + - `{c.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. + - `{c.SPOT.s!r}` - calculate niches using optimal transport. (coming soon) + - `{c.BANKSY.s!r}`- use Banksy algorithm. (coming soon) %(library_key)s subset Restrict niche calculation to a subset of the data. @@ -68,6 +63,9 @@ def calculate_niche( mask Boolean array to filter cells which won't get assigned to a niche. Note that if you want to exclude these cells during neighborhood calculation already, you should subset your AnnData table before running 'sq.gr.spatial_neigbors'. + groups + Groups based on which to calculate neighborhood profile (E.g. columns of cell type annotations in adata.obs). + Required if flavor == 'neighborhood'. n_neighbors Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm. Required if flavor == 'neighborhood' or flavor == 'UTAG'. @@ -77,7 +75,6 @@ def calculate_niche( subset_groups Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. Optional if flavor == 'neighborhood'. - Optional if flavor == 'neighborhood'. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. Optional if flavor == 'neighborhood'. @@ -88,7 +85,7 @@ def calculate_niche( If 'True', calculate niches based on absolute neighborhood profile. Optional if flavor == 'neighborhood'. adj_subsets - List of adjacency matrices to use e.g. [1,2,3] for 1,2,3 neighbors respectively. + List of adjacency matrices to use e.g. [1,2,3] for 1-hop,2-hop,3-hop neighbors respectively or "5" for 1-hop,...,5-hop neighbors. 0 (self) is always included. Required if flavor == 'cellcharter'. aggregation How to aggregate count matrices. Either 'mean' or 'variance'. @@ -98,62 +95,46 @@ def calculate_niche( Required if flavor == 'cellcharter'. random_state Random state to use for GMM. - Required if flavor == 'cellcharter'. - spatial_key - Location of spatial coordinates in `adata.obsm`. + Optional if flavor == 'cellcharter'. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. - spatial_distances_key - Key in `adata.obsp` where spatial distances are stored. - %(copy)s """ - # check whether anndata or spatialdata is provided and if spatialdata, check whether a table with the provided groups is present if no table is specified + # check whether anndata or spatialdata is provided and if spatialdata, check whether table_key is provided if isinstance(adata, SpatialData): if table_key is not None: adata = adata.tables[table_key].copy() else: - if len(adata.tables) > 1: - count = 0 - for table in adata.tables.keys(): - if groups in table.obs: - count += 1 - table_key = table - if count > 1: - raise ValueError( - f"Multiple tables in `spatialdata` with group `{groups}` detected. Please specify which table to use in `table_key`." - ) - elif count == 0: - raise ValueError( - f"Group `{groups}` not found in any table in `spatialdata`. Please specify a valid group in `groups`." - ) - else: - adata = adata.tables[table_key].copy() - else: - ((key, adata),) = adata.tables.items() - if groups not in adata.obs: - raise ValueError( - f"Group {groups} not found in table in `spatialdata`. Please specify a valid group in `groups`." - ) + raise ValueError("Please specify which table to use with `table_key`.") + else: + adata = adata if flavor == "neighborhood": + """adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py""" + + # calculate the neighborhood profile for each cell (relative and absolute proportion of e.g. each cell type in the neighborhood) rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( adata, groups, subset_groups, spatial_connectivities_key ) + # create AnnData object from neighborhood profile to perform scanpy functions if not abs_nhood: adata_neighborhood = ad.AnnData(X=rel_nhood_profile) else: adata_neighborhood = ad.AnnData(X=abs_nhood_profile) + # reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis if scale: sc.pp.scale(adata_neighborhood, zero_center=True) + # mask obs to exclude cells for which no niche shall be assigned if mask is not None: - if subset_groups is not None: - mask = mask[mask.index.isin(adata_neighborhood.obs.index)] + mask = mask[mask.index.isin(adata_neighborhood.obs.index)] adata_neighborhood = adata_neighborhood[mask] + # required for leiden clustering (note: no dim reduction performed in original implementation) + print("calculating neighbors...") sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") + print("finished calculating neighbors") if resolutions is not None: if not isinstance(resolutions, list): @@ -161,11 +142,16 @@ def calculate_niche( else: raise ValueError("Please provide resolutions for leiden clustering.") + # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label + print("starting clustering...") for res in resolutions: sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( adata_neighborhood.obs[f"neighborhood_niche_res={res}"] ).fillna("not_a_niche") + print(f"finished clustering at resolution {res}") + + # filter niches with n_cells < min_niche_size if min_niche_size is not None: counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() to_filter = counts_by_niche[counts_by_niche < min_niche_size].index @@ -174,9 +160,11 @@ def calculate_niche( ) elif flavor == "utag": + """adapted from https://github.com/ElementoLab/utag/blob/main/utag/segmentation.py""" + new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) adata_utag = ad.AnnData(X=new_feature_matrix) - sc.tl.pca(adata_utag) + sc.tl.pca(adata_utag) # note: unlike with flavor 'neighborhood' dim reduction is performed here sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca") if resolutions is not None: @@ -185,11 +173,15 @@ def calculate_niche( else: raise ValueError("Please provide resolutions for leiden clustering.") + # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label for res in resolutions: sc.tl.leiden(adata_utag, resolution=res, key_added=f"utag_res={res}") adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values elif flavor == "cellcharter": + """adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py + and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" + adjacency_matrix = adata.obsp[spatial_connectivities_key] if not isinstance(adj_subsets, list): if adj_subsets is not None: @@ -198,29 +190,31 @@ def calculate_niche( raise ValueError( "flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use." ) + else: + if 0 not in adj_subsets: + adj_subsets.insert(0, 0) + if any(x < 0 for x in adj_subsets): + raise ValueError("adj_subsets must contain non-negative integers.") aggregated_matrices = [] adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0 adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors for k in adj_subsets: if k == 0: - # If k == 0, we're using the original cell features (no neighbors) + # get original count matrix (not aggregated) aggregated_matrices.append(adata.X) else: + # get count and adjacency matrix for k-hop (neighbor of neighbor of neighbor ...) and aggregate them if k > 1: adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited) - - adj_hop_norm = _normalize(adj_hop) # Normalize adjacency matrix for current hop - - # Apply aggregation, default to "mean" unless specified otherwise + adj_hop_norm = _normalize(adj_hop) aggregated_matrix = _aggregate(adata, adj_hop_norm, aggregation) - - # Collect the aggregated matrices aggregated_matrices.append(aggregated_matrix) concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally - arr = concatenated_matrix.toarray() # Densify the sparse matrix + arr = concatenated_matrix.toarray() # Densify + # cluster concatenated matrix with GMM, each cluster label equals to a niche label niches = _get_GMM_clusters(arr, n_components, random_state) adata.obs[f"{flavor}_niche"] = pd.Categorical(niches) @@ -228,15 +222,18 @@ def calculate_niche( def _calculate_neighborhood_profile( adata: AnnData, - groups: str, + groups: str | None, subset_groups: list[str] | None, spatial_connectivities_key: str, ) -> tuple[pd.DataFrame, pd.DataFrame]: + """returns an obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood""" + + if groups is None: + raise ValueError("Please specify 'groups' based on which to calculate neighborhood profile.") if subset_groups: adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() obs_mask = ~adata.obs[groups].isin(subset_groups) adata = adata[obs_mask] - adata = adata[obs_mask] # Update adjacency matrix such that it only contains connections to filtered observations adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] @@ -274,14 +271,7 @@ def _calculate_neighborhood_profile( def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: """Performs inner product of adjacency matrix and feature matrix, - such that each observation inherits features from its immediate neighbors as described in UTAG paper. - - Parameters - ---------- - adata - Annotated data matrix. - normalize - If 'True', aggregate by the mean, else aggregate by the sum.""" + such that each observation inherits features from its immediate neighbors as described in UTAG paper.""" adjacency_matrix = adata.obsp[spatial_connectivity_key] @@ -292,6 +282,8 @@ def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix: + """remove self-loops""" + if issparse(adjacency_matrix): adjacency_matrix = adjacency_matrix.tolil() adjacency_matrix.setdiag(value) @@ -304,6 +296,8 @@ def _setdiag(adjacency_matrix: sps.spmatrix, value: int) -> sps.spmatrix: def _hop( adj_hop: sps.spmatrix, adj: sps.spmatrix, adj_visited: sps.spmatrix = None ) -> tuple[sps.spmatrix, sps.spmatrix]: + """get nearest neighbor of neighbors""" + adj_hop = adj_hop @ adj if adj_visited is not None: @@ -314,6 +308,8 @@ def _hop( def _normalize(adj: sps.spmatrix) -> sps.spmatrix: + """normalize adjacency matrix such that nodes with high degree don't disproportionately affect aggregation""" + deg = np.array(np.sum(adj, axis=1)).squeeze() with np.errstate(divide="ignore"): deg_inv = 1 / deg @@ -323,6 +319,8 @@ def _normalize(adj: sps.spmatrix) -> sps.spmatrix: def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggregation: str = "mean") -> Any: + """aggregate count and adjacency matrix either by mean or variance""" + # TODO: add support for other aggregation methods if aggregation == "mean": aggregated_matrix = normalized_adjacency_matrix @ adata.X elif aggregation == "variance": @@ -339,58 +337,17 @@ def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_ """Returns niche labels generated by GMM clustering. Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis.""" + print("initializing GMM...") gmm = GaussianMixture(n_components=n_components, random_state=random_state, init_params="random_from_data") + print("fitting GMM...") gmm.fit(A) + print("predicting labels...") labels = gmm.predict(A) + print("done") return labels -def _df_to_adata(df: pd.DataFrame) -> AnnData: - df.index = df.index.map(str) - adata = AnnData(X=df) - adata.obs.index = df.index - return adata - - -def _aggregate_var(product: csr_matrix, connectivities: csr_matrix, adata: AnnData) -> csr_matrix: - mean_squared = connectivities.dot(adata.X.multiply(adata.X)) - return mean_squared - (product.multiply(product)) - - -def pairwise_niche_comparison( - adata: AnnData, - library_key: str, -) -> pd.DataFrame: - """Do a simple pairwise DE test on the 99th percentile of each gene for each niche. - Can be used to plot heatmap showing similar (large p-value) or different (small p-value) niches. - For validating niche results, the niche pairs that are similar in expression are the ones of interest because - it could hint at niches not being well defined in those cases.""" - niches = adata.obs[library_key].unique().tolist() - niche_dict = {} - # for each niche, calculate the 99th percentile of each gene - for niche in adata.obs[library_key].unique(): - niche_adata = adata[adata.obs[library_key] == niche] - n_cols = niche_adata.X.shape[1] - arr = np.ones(n_cols) - for i in range(n_cols): - col_data = niche_adata.X.getcol(i).data - percentile_99 = np.percentile(col_data, 99) - arr[i] = percentile_99 - niche_dict[niche] = arr - # create 99th percentile count x niche matrix - var_by_niche = pd.DataFrame(niche_dict) - result = pd.DataFrame(index=niches, columns=niches, data=None, dtype=float) - # construct all pairs (unordered and with pairs of the same niche) - combinations = list(itertools.combinations_with_replacement(niches, 2)) - # create a p-value matrix for all niche pairs - for pair in combinations: - p_val = ranksums(var_by_niche[pair[0]], var_by_niche[pair[1]], alternative="two-sided")[1] - result.at[pair[0], pair[1]] = p_val - result.at[pair[1], pair[0]] = p_val - return result - - def mean_fide_score( adatas: AnnData | list[AnnData], library_key: str, From e0a67f5749b22c80f6057cec2ca467a14173d51d Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 9 Oct 2024 23:00:22 +0200 Subject: [PATCH 49/67] Add tests --- tests/graph/test_niche.py | 118 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 tests/graph/test_niche.py diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py new file mode 100644 index 00000000..da5d8330 --- /dev/null +++ b/tests/graph/test_niche.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import numpy as np +import pytest +import scipy +from anndata import AnnData +from pandas.testing import assert_frame_equal +from scipy.sparse import issparse + +from squidpy.gr import calculate_niche, spatial_neighbors +from squidpy.gr._niche import _aggregate, _calculate_neighborhood_profile, _hop, _normalize, _setdiag, _utag + +SPATIAL_CONNECTIVITIES_KEY = "spatial_connectivities" +N_NEIGHBORS = 20 + + +def test_neighborhood_profile_calculation(adata_seqfish: AnnData): + """Check whether niche calculation using neighborhood profile approach works as intended.""" + spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) + calculate_niche( + adata_seqfish, + groups="celltype_mapped_refined", + flavor="neighborhood", + n_neighbors=N_NEIGHBORS, + resolutions=[0.1], + min_niche_size=100, + ) + niches = adata_seqfish.obs["neighborhood_niche_res=0.1"] + + # assert no nans, more niche labels than non-niche labels, and at least 100 obs per niche + assert niches.isna().sum() == 0 + assert len(niches[niches != "not_a_niche"]) > len(niches[niches == "not_a_niche"]) + for label in niches.unique(): + if label != "not_a_niche": + assert len(niches[niches == label]) >= 100 + + rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( + adata_seqfish, groups="celltype_mapped_refined", spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY + ) + # assert shape obs x groups + assert rel_nhood_profile.shape == ( + adata_seqfish.n_obs, + len(adata_seqfish.obs["celltype_mapped_refined"].cat.categories), + ) + assert abs_nhood_profile.shape == rel_nhood_profile.shape + # normalization + assert int(rel_nhood_profile.sum(axis=1).sum()) == adata_seqfish.n_obs + assert rel_nhood_profile.sum(axis=1).max() == 1 + # maximum amount of categories equals n_neighbors + assert abs_nhood_profile.sum(axis=1).max() == N_NEIGHBORS + + +def test_utag(adata_seqfish: AnnData): + """Check whether niche calculation using UTAG approach works as intended.""" + spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) + calculate_niche(adata_seqfish, flavor="utag", n_neighbors=N_NEIGHBORS, resolutions=[0.1, 1.0]) + + niches = adata_seqfish.obs["utag_niche_res=1.0"] + niches_low_res = adata_seqfish.obs["utag_niche_res=0.1"] + + assert niches.isna().sum() == 0 + assert niches.nunique() > niches_low_res.nunique() + + # assert shape obs x var and sparsity in new feature matrix + new_feature_matrix = _utag(adata_seqfish, normalize_adj=True, spatial_connectivity_key=SPATIAL_CONNECTIVITIES_KEY) + assert new_feature_matrix.shape == adata_seqfish.X.shape + assert issparse(new_feature_matrix) + + spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=40) + new_feature_matrix_more_neighs = _utag( + adata_seqfish, normalize_adj=True, spatial_connectivity_key=SPATIAL_CONNECTIVITIES_KEY + ) + + # matrix products should differ when using different amount of neighbors + try: + assert_frame_equal(new_feature_matrix, new_feature_matrix_more_neighs) + except AssertionError: + pass + else: + raise AssertionError + + +def test_cellcharter_approach(adata_seqfish: AnnData): + """Check whether niche calculation using CellCharter approach works as intended.""" + + spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) + calculate_niche( + adata_seqfish, groups="celltype_mapped_refined", flavor="cellcharter", adj_subsets=3, n_components=5 + ) + niches = adata_seqfish.obs["cellcharter_niche"] + + assert niches.nunique() == 5 + assert niches.isna().sum() == 0 + + adj = adata_seqfish.obsp[SPATIAL_CONNECTIVITIES_KEY] + adj_hop = _setdiag(adj, value=0) + assert adj_hop.shape == adj.shape + assert issparse(adj_hop) + assert isinstance(adj_hop, scipy.sparse.csrmatrix) + + adj_visited = _setdiag(adj.copy(), 1) # Track visited neighbors + adj_hop, adj_visited = _hop(adj_hop, adj, adj_visited) + assert adj_hop.shape == adj.shape + assert adj_hop.shape == adj_visited.shape + + assert np.array(np.sum(adj, axis=1)).squeeze() == N_NEIGHBORS + adj_hop_norm = _normalize(adj_hop) + assert adj_hop_norm.shape == adj.shape + + mean_aggr_matrix = _aggregate(adata_seqfish, adj_hop_norm, "mean") + assert mean_aggr_matrix.shape == adata_seqfish.X.shape + var_aggr_matrix = _aggregate(adata_seqfish, adj_hop_norm, "variance") + assert var_aggr_matrix.shape == adata_seqfish.X.shape + + # TODO: add test for GMM + + +# TODO: comppare results to previously calculated niches From ae3b4db8c71c273ffeb29650b34131d14e48d7f4 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 9 Oct 2024 23:04:05 +0200 Subject: [PATCH 50/67] Remove unused imports and print statements --- src/squidpy/gr/_niche.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index cc3be22c..10457449 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,6 +1,5 @@ from __future__ import annotations -import itertools from collections.abc import Iterator from typing import Any @@ -10,12 +9,10 @@ import scanpy as sc import scipy.sparse as sps from anndata import AnnData -from scipy.sparse import csr_matrix, hstack, issparse, spdiags -from scipy.stats import ranksums +from scipy.sparse import hstack, issparse, spdiags from sklearn import metrics -from sklearn.metrics import adjusted_rand_score, fowlkes_mallows_score, normalized_mutual_info_score from sklearn.mixture import GaussianMixture -from sklearn.preprocessing import StandardScaler, normalize +from sklearn.preprocessing import normalize from spatialdata import SpatialData from squidpy._utils import NDArrayA @@ -26,7 +23,7 @@ def calculate_niche( adata: AnnData | SpatialData, flavor: str = "neighborhood", - library_key: str | None = None, + library_key: str | None = None, # TODO: calculate niches on a per-slide basis table_key: str | None = None, mask: pd.core.series.Series = None, groups: str | None = None, @@ -132,9 +129,7 @@ def calculate_niche( adata_neighborhood = adata_neighborhood[mask] # required for leiden clustering (note: no dim reduction performed in original implementation) - print("calculating neighbors...") sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") - print("finished calculating neighbors") if resolutions is not None: if not isinstance(resolutions, list): @@ -143,13 +138,11 @@ def calculate_niche( raise ValueError("Please provide resolutions for leiden clustering.") # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label - print("starting clustering...") for res in resolutions: sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( adata_neighborhood.obs[f"neighborhood_niche_res={res}"] ).fillna("not_a_niche") - print(f"finished clustering at resolution {res}") # filter niches with n_cells < min_niche_size if min_niche_size is not None: @@ -337,13 +330,9 @@ def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_ """Returns niche labels generated by GMM clustering. Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis.""" - print("initializing GMM...") gmm = GaussianMixture(n_components=n_components, random_state=random_state, init_params="random_from_data") - print("fitting GMM...") gmm.fit(A) - print("predicting labels...") labels = gmm.predict(A) - print("done") return labels From 5a4cfc4c768adda0d920dd76a6d0d6ef86b1ac88 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 9 Oct 2024 23:26:20 +0200 Subject: [PATCH 51/67] Fix tests --- tests/graph/test_niche.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index da5d8330..35354151 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -35,7 +35,10 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): assert len(niches[niches == label]) >= 100 rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata_seqfish, groups="celltype_mapped_refined", spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY + adata_seqfish, + groups="celltype_mapped_refined", + subset_groups=None, + spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY, ) # assert shape obs x groups assert rel_nhood_profile.shape == ( @@ -55,8 +58,8 @@ def test_utag(adata_seqfish: AnnData): spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) calculate_niche(adata_seqfish, flavor="utag", n_neighbors=N_NEIGHBORS, resolutions=[0.1, 1.0]) - niches = adata_seqfish.obs["utag_niche_res=1.0"] - niches_low_res = adata_seqfish.obs["utag_niche_res=0.1"] + niches = adata_seqfish.obs["utag_res=1.0"] + niches_low_res = adata_seqfish.obs["utag_res=0.1"] assert niches.isna().sum() == 0 assert niches.nunique() > niches_low_res.nunique() @@ -96,7 +99,7 @@ def test_cellcharter_approach(adata_seqfish: AnnData): adj_hop = _setdiag(adj, value=0) assert adj_hop.shape == adj.shape assert issparse(adj_hop) - assert isinstance(adj_hop, scipy.sparse.csrmatrix) + assert isinstance(adj_hop, scipy.sparse.csr_matrix) adj_visited = _setdiag(adj.copy(), 1) # Track visited neighbors adj_hop, adj_visited = _hop(adj_hop, adj, adj_visited) From 3e47df8cbf8af6f0d93a28ac45aa569d32061da0 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 9 Oct 2024 23:55:20 +0200 Subject: [PATCH 52/67] Fix test --- tests/graph/test_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 35354151..e729a28f 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -48,7 +48,7 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): assert abs_nhood_profile.shape == rel_nhood_profile.shape # normalization assert int(rel_nhood_profile.sum(axis=1).sum()) == adata_seqfish.n_obs - assert rel_nhood_profile.sum(axis=1).max() == 1 + assert round(rel_nhood_profile.sum(axis=1).max(), 2) == 1 # maximum amount of categories equals n_neighbors assert abs_nhood_profile.sum(axis=1).max() == N_NEIGHBORS From 6cbc09ebf352e330b50dcd7334a5938cda5ae2f0 Mon Sep 17 00:00:00 2001 From: LLehner Date: Thu, 10 Oct 2024 14:12:13 +0200 Subject: [PATCH 53/67] Fix test --- src/squidpy/gr/_niche.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 10457449..e7231f9a 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -317,8 +317,9 @@ def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggreg if aggregation == "mean": aggregated_matrix = normalized_adjacency_matrix @ adata.X elif aggregation == "variance": - mean_matrix = normalized_adjacency_matrix @ adata.X - mean_squared_matrix = normalized_adjacency_matrix @ (adata.X * adata.X) + mean_matrix = (normalized_adjacency_matrix @ adata.X).toarray() + X_to_arr = adata.X.toarray() + mean_squared_matrix = normalized_adjacency_matrix @ (X_to_arr * X_to_arr) aggregated_matrix = mean_squared_matrix - mean_matrix * mean_matrix else: raise ValueError(f"Invalid aggregation method '{aggregation}'. Please choose either 'mean' or 'variance'.") From 6185c0e48a6117bf01cf7c420fbb8445a64b6eea Mon Sep 17 00:00:00 2001 From: LLehner Date: Thu, 10 Oct 2024 14:28:56 +0200 Subject: [PATCH 54/67] Fix test --- tests/graph/test_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index e729a28f..54ec4586 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -106,7 +106,7 @@ def test_cellcharter_approach(adata_seqfish: AnnData): assert adj_hop.shape == adj.shape assert adj_hop.shape == adj_visited.shape - assert np.array(np.sum(adj, axis=1)).squeeze() == N_NEIGHBORS + assert np.array(np.sum(adj, axis=1)).squeeze().max() == N_NEIGHBORS adj_hop_norm = _normalize(adj_hop) assert adj_hop_norm.shape == adj.shape From 33436bf1f76651bfcc6f549a650d4a8f42f3b2d1 Mon Sep 17 00:00:00 2001 From: LLehner <64135338+LLehner@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:36:46 +0100 Subject: [PATCH 55/67] Fix sepal --- src/squidpy/gr/_sepal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/squidpy/gr/_sepal.py b/src/squidpy/gr/_sepal.py index a1dd91b2..95d74099 100644 --- a/src/squidpy/gr/_sepal.py +++ b/src/squidpy/gr/_sepal.py @@ -182,7 +182,7 @@ def _score_helper( score = [] for i in ixs: - if sparse and isinstance(vals, spmatrix): + if isinstance(vals, spmatrix): conc = vals[:, i].toarray().flatten() # Safe to call toarray() else: conc = vals[:, i].copy() # vals is assumed to be a NumPy array here From ca753a28904a130a0a11baeeae1db54efbefb3fb Mon Sep 17 00:00:00 2001 From: LLehner Date: Sat, 30 Nov 2024 23:30:31 +0100 Subject: [PATCH 56/67] Add suggested changes from review --- src/squidpy/_constants/_constants.py | 9 + src/squidpy/gr/_niche.py | 458 ++++++++++++++++----------- src/squidpy/gr/_ppatterns.py | 2 +- 3 files changed, 277 insertions(+), 192 deletions(-) diff --git a/src/squidpy/_constants/_constants.py b/src/squidpy/_constants/_constants.py index 6d6bdab7..403f072b 100644 --- a/src/squidpy/_constants/_constants.py +++ b/src/squidpy/_constants/_constants.py @@ -123,3 +123,12 @@ class TenxVersions(str, ModeEnum): V1 = "1.1.0" V2 = "1.2.0" V3 = "1.3.0" + + +@unique +class NicheDefinitions(ModeEnum): + NEIGHBORHOOD = "neighborhood" + UTAG = "utag" + CELLCHARTER = "cellcharter" + SPOT = "spot" + BANKSY = "banksy" diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index e7231f9a..8a72c507 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,7 +1,8 @@ from __future__ import annotations +import warnings from collections.abc import Iterator -from typing import Any +from typing import Any, Literal, Union import anndata as ad import numpy as np @@ -9,37 +10,44 @@ import scanpy as sc import scipy.sparse as sps from anndata import AnnData +from numpy.typing import NDArray from scipy.sparse import hstack, issparse, spdiags -from sklearn import metrics +from scipy.spatial import distance +from sklearn.metrics import f1_score from sklearn.mixture import GaussianMixture from sklearn.preprocessing import normalize from spatialdata import SpatialData +from squidpy._constants._constants import NicheDefinitions +from squidpy._docs import d, inject_docs from squidpy._utils import NDArrayA __all__ = ["calculate_niche"] +@d.dedent +@inject_docs(m=NicheDefinitions) def calculate_niche( adata: AnnData | SpatialData, - flavor: str = "neighborhood", + flavor: Literal["neighborhood", "utag", "cellcharter"] = "neighborhood", library_key: str | None = None, # TODO: calculate niches on a per-slide basis table_key: str | None = None, mask: pd.core.series.Series = None, groups: str | None = None, - n_neighbors: int = 15, - resolutions: int | list[float] | None = None, + n_neighbors: int | None = None, + resolutions: float | list[float] | None = None, subset_groups: list[str] | None = None, min_niche_size: int | None = None, scale: bool = True, abs_nhood: bool = False, adj_subsets: int | list[int] | None = None, aggregation: str = "mean", - n_components: int = 3, + n_components: int | None = None, random_state: int = 42, spatial_connectivities_key: str = "spatial_connectivities", ) -> AnnData | pd.DataFrame: - """Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. + """ + Calculate niches (spatial clusters) based on a user-defined method in 'flavor'. The resulting niche labels with be stored in 'adata.obs'. If flavor = 'all' then all available methods will be applied and additionally compared using cluster validation scores. Parameters @@ -53,8 +61,6 @@ def calculate_niche( - `{c.SPOT.s!r}` - calculate niches using optimal transport. (coming soon) - `{c.BANKSY.s!r}`- use Banksy algorithm. (coming soon) %(library_key)s - subset - Restrict niche calculation to a subset of the data. table_key Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. mask @@ -62,37 +68,37 @@ def calculate_niche( Note that if you want to exclude these cells during neighborhood calculation already, you should subset your AnnData table before running 'sq.gr.spatial_neigbors'. groups Groups based on which to calculate neighborhood profile (E.g. columns of cell type annotations in adata.obs). - Required if flavor == 'neighborhood'. + Required if flavor == `{c.NEIGHBORHOOD.s!r}`. n_neighbors Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm. - Required if flavor == 'neighborhood' or flavor == 'UTAG'. + Required if flavor == `{c.NEIGHBORHOOD.s!r}` or flavor == `{c.UTAG.s!r}`. resolutions List of resolutions to use for leiden clustering. - Required if flavor == 'neighborhood' or flavor == 'UTAG'. + Required if flavor == `{c.NEIGHBORHOOD.s!r}` or flavor == `{c.UTAG.s!r}`. subset_groups Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. - Optional if flavor == 'neighborhood'. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. - Optional if flavor == 'neighborhood'. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. scale If 'True', compute z-scores of neighborhood profiles. - Optional if flavor == 'neighborhood'. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. abs_nhood If 'True', calculate niches based on absolute neighborhood profile. - Optional if flavor == 'neighborhood'. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. adj_subsets List of adjacency matrices to use e.g. [1,2,3] for 1-hop,2-hop,3-hop neighbors respectively or "5" for 1-hop,...,5-hop neighbors. 0 (self) is always included. - Required if flavor == 'cellcharter'. + Required if flavor == `{c.CELLCHARTER.s!r}`. aggregation How to aggregate count matrices. Either 'mean' or 'variance'. - Required if flavor == 'cellcharter'. + Required if flavor == `{c.CELLCHARTER.s!r}`. n_components Number of components to use for GMM. - Required if flavor == 'cellcharter'. + Required if flavor == `{c.CELLCHARTER.s!r}`. random_state Random state to use for GMM. - Optional if flavor == 'cellcharter'. + Optional if flavor == `{c.CELLCHARTER.s!r}`. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. """ @@ -106,111 +112,166 @@ def calculate_niche( else: adata = adata - if flavor == "neighborhood": - """adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py""" - - # calculate the neighborhood profile for each cell (relative and absolute proportion of e.g. each cell type in the neighborhood) - rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata, groups, subset_groups, spatial_connectivities_key + # check whether neighborhood graph exists + if spatial_connectivities_key not in adata.obsp.keys(): + raise KeyError( + f"Key '{spatial_connectivities_key}' not found in `adata.obsp`. If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`." ) - # create AnnData object from neighborhood profile to perform scanpy functions - if not abs_nhood: - adata_neighborhood = ad.AnnData(X=rel_nhood_profile) - else: - adata_neighborhood = ad.AnnData(X=abs_nhood_profile) - # reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis - if scale: - sc.pp.scale(adata_neighborhood, zero_center=True) + _validate_args( + adata, + mask, + flavor, + groups, + n_neighbors, + resolutions, + subset_groups, + min_niche_size, + scale, + abs_nhood, + adj_subsets, + aggregation, + n_components, + random_state, + spatial_connectivities_key, + ) - # mask obs to exclude cells for which no niche shall be assigned - if mask is not None: - mask = mask[mask.index.isin(adata_neighborhood.obs.index)] - adata_neighborhood = adata_neighborhood[mask] - # required for leiden clustering (note: no dim reduction performed in original implementation) - sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") +def _get_nhood_profile_niches( + adata: AnnData, + mask: pd.core.series.Series | None, + groups: str | None, + n_neighbors: int, + resolutions: float | list[float], + subset_groups: list[str] | None, + min_niche_size: int | None, + scale: bool, + abs_nhood: bool, + spatial_connectivities_key: str, +) -> None: + """ + adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py + """ + # calculate the neighborhood profile for each cell (relative and absolute proportion of e.g. each cell type in the neighborhood) + rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( + adata, groups, subset_groups, spatial_connectivities_key + ) + # create AnnData object from neighborhood profile to perform scanpy functions + if not abs_nhood: + adata_neighborhood = ad.AnnData(X=rel_nhood_profile) + else: + adata_neighborhood = ad.AnnData(X=abs_nhood_profile) - if resolutions is not None: - if not isinstance(resolutions, list): - resolutions = [resolutions] - else: - raise ValueError("Please provide resolutions for leiden clustering.") - - # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label - for res in resolutions: - sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( - adata_neighborhood.obs[f"neighborhood_niche_res={res}"] - ).fillna("not_a_niche") - - # filter niches with n_cells < min_niche_size - if min_niche_size is not None: - counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() - to_filter = counts_by_niche[counts_by_niche < min_niche_size].index - adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply( - lambda x, to_filter=to_filter: "not_a_niche" if x in to_filter else x - ) + # reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis + if scale: + sc.pp.scale(adata_neighborhood, zero_center=True) - elif flavor == "utag": - """adapted from https://github.com/ElementoLab/utag/blob/main/utag/segmentation.py""" + # mask obs to exclude cells for which no niche shall be assigned + if mask is not None: + mask = mask[mask.index.isin(adata_neighborhood.obs.index)] + adata_neighborhood = adata_neighborhood[mask] - new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) - adata_utag = ad.AnnData(X=new_feature_matrix) - sc.tl.pca(adata_utag) # note: unlike with flavor 'neighborhood' dim reduction is performed here - sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca") + # required for leiden clustering (note: no dim reduction performed in original implementation) + sc.pp.neighbors(adata_neighborhood, n_neighbors=n_neighbors, use_rep="X") - if resolutions is not None: - if not isinstance(resolutions, list): - resolutions = [resolutions] - else: - raise ValueError("Please provide resolutions for leiden clustering.") + resolutions = [resolutions] if not isinstance(resolutions, list) else resolutions - # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label - for res in resolutions: - sc.tl.leiden(adata_utag, resolution=res, key_added=f"utag_res={res}") - adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values + # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label + for res in resolutions: + sc.tl.leiden(adata_neighborhood, resolution=res, key_added=f"neighborhood_niche_res={res}") + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs.index.map( + adata_neighborhood.obs[f"neighborhood_niche_res={res}"] + ).fillna("not_a_niche") - elif flavor == "cellcharter": - """adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py - and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" - - adjacency_matrix = adata.obsp[spatial_connectivities_key] - if not isinstance(adj_subsets, list): - if adj_subsets is not None: - adj_subsets = list(range(adj_subsets + 1)) - else: - raise ValueError( - "flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use." - ) + # filter niches with n_cells < min_niche_size + if min_niche_size is not None: + counts_by_niche = adata.obs[f"neighborhood_niche_res={res}"].value_counts() + to_filter = counts_by_niche[counts_by_niche < min_niche_size].index + adata.obs[f"neighborhood_niche_res={res}"] = adata.obs[f"neighborhood_niche_res={res}"].apply( + lambda x, to_filter=to_filter: "not_a_niche" if x in to_filter else x + ) + + return + + +def _get_utag_niches( + adata: AnnData, + subset_groups: list[str] | None, + n_neighbors: int, + resolutions: float | list[float] | None, + spatial_connectivities_key: str, +) -> None: + """ + Adapted from https://github.com/ElementoLab/utag/blob/main/utag/segmentation.py + """ + + new_feature_matrix = _utag(adata, normalize_adj=True, spatial_connectivity_key=spatial_connectivities_key) + adata_utag = ad.AnnData(X=new_feature_matrix) + sc.tl.pca(adata_utag) # note: unlike with flavor 'neighborhood' dim reduction is performed here + sc.pp.neighbors(adata_utag, n_neighbors=n_neighbors, use_rep="X_pca") + + if resolutions is not None: + if not isinstance(resolutions, list): + resolutions = [resolutions] + else: + raise ValueError("Please provide resolutions for leiden clustering.") + + # For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label + for res in resolutions: + sc.tl.leiden(adata_utag, resolution=res, key_added=f"utag_res={res}") + adata.obs[f"utag_res={res}"] = adata_utag.obs[f"utag_res={res}"].values + return + + +def _get_cellcharter_niches( + adata: AnnData, + subset_groups: list[str] | None, + adj_subsets: int | list[int] | None, + aggregation: str, + n_components: int, + random_state: int, + spatial_connectivities_key: str, +) -> None: + """adapted from https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/gr/_aggr.py + and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" + + adjacency_matrix = adata.obsp[spatial_connectivities_key] + if not isinstance(adj_subsets, list): + if adj_subsets is not None: + adj_subsets = list(range(adj_subsets + 1)) + else: + raise ValueError( + "flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use." + ) + else: + if 0 not in adj_subsets: + adj_subsets.insert(0, 0) + if any(x < 0 for x in adj_subsets): + raise ValueError("adj_subsets must contain non-negative integers.") + + aggregated_matrices = [] + adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0 + adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors + for k in adj_subsets: + if k == 0: + # get original count matrix (not aggregated) + aggregated_matrices.append(adata.X) else: - if 0 not in adj_subsets: - adj_subsets.insert(0, 0) - if any(x < 0 for x in adj_subsets): - raise ValueError("adj_subsets must contain non-negative integers.") - - aggregated_matrices = [] - adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0 - adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors - for k in adj_subsets: - if k == 0: - # get original count matrix (not aggregated) - aggregated_matrices.append(adata.X) - else: - # get count and adjacency matrix for k-hop (neighbor of neighbor of neighbor ...) and aggregate them - if k > 1: - adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited) - adj_hop_norm = _normalize(adj_hop) - aggregated_matrix = _aggregate(adata, adj_hop_norm, aggregation) - aggregated_matrices.append(aggregated_matrix) - - concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally - arr = concatenated_matrix.toarray() # Densify - - # cluster concatenated matrix with GMM, each cluster label equals to a niche label - niches = _get_GMM_clusters(arr, n_components, random_state) - - adata.obs[f"{flavor}_niche"] = pd.Categorical(niches) + # get count and adjacency matrix for k-hop (neighbor of neighbor of neighbor ...) and aggregate them + if k > 1: + adj_hop, adj_visited = _hop(adj_hop, adjacency_matrix, adj_visited) + adj_hop_norm = _normalize(adj_hop) + aggregated_matrix = _aggregate(adata, adj_hop_norm, aggregation) + aggregated_matrices.append(aggregated_matrix) + + concatenated_matrix = hstack(aggregated_matrices) # Stack all matrices horizontally + arr = concatenated_matrix.toarray() # Densify + + # cluster concatenated matrix with GMM, each cluster label equals to a niche label + niches = _get_GMM_clusters(arr, n_components, random_state) + + adata.obs["cellcharter_niche"] = pd.Categorical(niches) + return def _calculate_neighborhood_profile( @@ -263,8 +324,10 @@ def _calculate_neighborhood_profile( def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: - """Performs inner product of adjacency matrix and feature matrix, - such that each observation inherits features from its immediate neighbors as described in UTAG paper.""" + """ + Performs inner product of adjacency matrix and feature matrix, + such that each observation inherits features from its immediate neighbors as described in UTAG paper. + """ adjacency_matrix = adata.obsp[spatial_connectivity_key] @@ -327,7 +390,7 @@ def _aggregate(adata: AnnData, normalized_adjacency_matrix: sps.spmatrix, aggreg return aggregated_matrix -def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_state: int) -> Any: +def _get_GMM_clusters(A: NDArray[np.float64], n_components: int, random_state: int) -> Any: """Returns niche labels generated by GMM clustering. Compared to cellcharter this approach is simplified by using sklearn's GaussianMixture model without stability analysis.""" @@ -338,95 +401,108 @@ def _get_GMM_clusters(A: np.ndarray[np.float64, Any], n_components: int, random_ return labels -def mean_fide_score( - adatas: AnnData | list[AnnData], - library_key: str, - slide_key: str | None = None, - n_classes: int | None = None, -) -> float: - """Mean FIDE score over all slides. A low score indicates a great domain continuity.""" - return float( - np.mean([fide_score(adata, library_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key)]) - ) - - -def fide_score(adata: AnnData, library_key: str, n_classes: int | None = None) -> float: +def _fide_score(adata: AnnData, niche_key: str, average: bool) -> Any: """ F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity. The F1-score is computed for every class, then all F1-scores are averaged. If some classes are not predicted, the `n_classes` argument allows to pad with zeros before averaging the F1-scores. """ - i_left, i_right = adata.obsp["spatial_connectivities"].nonzero() - classes_left, classes_right = ( - adata.obs.iloc[i_left][library_key], - adata.obs.iloc[i_right][library_key], + i, j = adata.obsp["spatial_connectivities"].nonzero() # get row and column indices of non-zero elements + niche_labels, neighbor_niche_labels = ( + adata.obs.iloc[i][niche_key], + adata.obs.iloc[j][niche_key], ) - f1_scores = metrics.f1_score(classes_left, classes_right, average=None) - - if n_classes is None: - return float(f1_scores.mean()) - - assert n_classes >= len(f1_scores), f"Expected {n_classes:=}, but found {len(f1_scores)}, which is greater" - - return float(np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()) - - -def jensen_shannon_divergence(adatas: AnnData | list[AnnData], library_key: str, slide_key: str | None = None) -> float: - """Jensen-Shannon divergence (JSD) over all slides""" - distributions = [ - adata.obs[library_key].value_counts(sort=False).values for adata in _iter_uid(adatas, slide_key, library_key) - ] - - return _jensen_shannon_divergence(np.array(distributions)) - + if not average: + fide = f1_score(niche_labels, neighbor_niche_labels, average=None) + else: + fide = f1_score(niche_labels, neighbor_niche_labels, average="macro") -def _jensen_shannon_divergence(distributions: NDArrayA) -> float: - """Compute the Jensen-Shannon divergence (JSD) for a multiple probability distributions. - The lower the score, the better distribution of clusters among the different batches. + return fide - Parameters - ---------- - distributions - An array of shape (B x C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells. - Returns - JSD (float) +def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) -> Any: """ - distributions = distributions / distributions.sum(1)[:, None] - mean_distribution = np.mean(distributions, 0) - - return _entropy(mean_distribution) - float(np.mean([_entropy(dist) for dist in distributions])) - - -def _entropy(distribution: NDArrayA) -> float: - """Shannon entropy - - Parameters - ---------- - distribution: An array of probabilities (should sum to one) - - Returns - The Shannon entropy + Calculate Jensen-Shannon divergence (JSD) over all slides. + This metric measures how well niche label distributions match across different slides. """ - return float(-(distribution * np.log(distribution + 1e-8)).sum()) + niche_labels = sorted(adata.obs[niche_key].unique()) + label_distributions = [] + for _, slide in adata.obs.groupby(library_key): + counts = slide[niche_key].value_counts(normalize=True) + relative_freq = [counts.get(label, 0) for label in niche_labels] + label_distributions.append(relative_freq) -def _iter_uid( - adatas: AnnData | list[AnnData], slide_key: str | None, library_key: str | None = None -) -> Iterator[AnnData]: - if isinstance(adatas, AnnData): - adatas = [adatas] + return distance.jensenshannon(np.array(label_distributions)) - if library_key is not None: - categories = set.union(*[set(adata.obs[library_key].unique().dropna()) for adata in adatas]) - for adata in adatas: - adata.obs[library_key] = adata.obs[library_key].astype("category").cat.set_categories(categories) - for adata in adatas: - if slide_key is not None: - for slide in adata.obs[slide_key].unique(): - yield adata[adata.obs[slide_key] == slide] +def _validate_args( + adata: AnnData, + mask: pd.core.series.Series | None, + flavor: Literal["neighborhood", "utag", "cellcharter"], + groups: str | None, + n_neighbors: int | None, + resolutions: float | list[float] | None, + subset_groups: list[str] | None, + min_niche_size: int | None, + scale: bool, + abs_nhood: bool, + adj_subsets: int | list[int] | None, + aggregation: str, + n_components: int | None, + random_state: int, + spatial_connectivities_key: str, +) -> str | None: + """ + Validate whether necessary arguments are provided for a given niche flavor. + If required arguments are provided, run respective niche calculation function. + Also warns whether unnecessary optional arguments are supplied. + """ + if flavor == "neighborhood": + if any(arg is not None for arg in ([random_state])): + warnings.warn("param 'random_state' is not used for neighborhood flavor.", stacklevel=2) + if groups is not None and n_neighbors is not None and resolutions is not None: + _get_nhood_profile_niches( + adata, + mask, + groups, + n_neighbors, + resolutions, + subset_groups, + min_niche_size, + scale, + abs_nhood, + spatial_connectivities_key, + ) else: - yield adata + raise ValueError( + "One of required args 'groups', 'n_neighbors' and 'resolutions' for flavor 'neighborhood' is 'None'." + ) + elif flavor == "utag": + if any(arg is not None for arg in (subset_groups, min_niche_size, scale, abs_nhood, random_state)): + warnings.warn( + "param 'subset_groups', 'min_niche_size', 'scale', 'abs_nhood', 'random_state' are not used for utag flavor.", + stacklevel=2, + ) + if n_neighbors is not None and resolutions is not None: + _get_utag_niches(adata, subset_groups, n_neighbors, resolutions, spatial_connectivities_key) + else: + raise ValueError("One of required args 'n_neighbors' and 'resolutions' for flavor 'utag' is 'None'.") + elif flavor == "cellcharter": + if any(arg is not None for arg in (groups, subset_groups, min_niche_size, scale, abs_nhood)): + warnings.warn( + "param 'groups', 'subset_groups', 'min_niche_size', 'scale', 'abs_nhood' are not used for cellcharter flavor.", + stacklevel=2, + ) + if adj_subsets is not None and aggregation is not None and n_components is not None: + _get_cellcharter_niches( + adata, subset_groups, adj_subsets, aggregation, n_components, random_state, spatial_connectivities_key + ) + else: + raise ValueError( + "One of required args 'adj_subsets', 'aggregation' and 'n_components' for flavor 'cellcharter' is 'None'." + ) + else: + raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.") diff --git a/src/squidpy/gr/_ppatterns.py b/src/squidpy/gr/_ppatterns.py index 985992fa..187f1156 100644 --- a/src/squidpy/gr/_ppatterns.py +++ b/src/squidpy/gr/_ppatterns.py @@ -204,7 +204,7 @@ def extract_obsm(adata: AnnData, ixs: int | Sequence[int] | None) -> tuple[NDArr start = logg.info(f"Calculating {mode}'s statistic for `{n_perms}` permutations using `{n_jobs}` core(s)") if n_perms is not None: _assert_positive(n_perms, name="n_perms") - perms = np.arange(n_perms) + perms = list(np.arange(n_perms)) score_perms = parallelize( _score_helper, From e1b2814c57c0def7a326dd26816e8a1d7e1b6294 Mon Sep 17 00:00:00 2001 From: LLehner Date: Sun, 1 Dec 2024 23:54:54 +0100 Subject: [PATCH 57/67] Add distance option to neighborhood approach --- .pre-commit-config.yaml | 4 +- src/squidpy/gr/_niche.py | 121 ++++++++++++++++++++++----------------- 2 files changed, 71 insertions(+), 54 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5d09ab60..15ff6913 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ fail_fast: false default_stages: - - pre-commit - - pre-push + - commit + - push minimum_pre_commit_version: 2.9.3 repos: - repo: https://github.com/astral-sh/ruff-pre-commit diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 8a72c507..4a7a6cab 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -11,7 +11,7 @@ import scipy.sparse as sps from anndata import AnnData from numpy.typing import NDArray -from scipy.sparse import hstack, issparse, spdiags +from scipy.sparse import coo_matrix, hstack, issparse, spdiags from scipy.spatial import distance from sklearn.metrics import f1_score from sklearn.mixture import GaussianMixture @@ -26,7 +26,7 @@ @d.dedent -@inject_docs(m=NicheDefinitions) +# @inject_docs(m=NicheDefinitions) def calculate_niche( adata: AnnData | SpatialData, flavor: Literal["neighborhood", "utag", "cellcharter"] = "neighborhood", @@ -40,7 +40,8 @@ def calculate_niche( min_niche_size: int | None = None, scale: bool = True, abs_nhood: bool = False, - adj_subsets: int | list[int] | None = None, + distance: int = 1, + n_hop_weights: list[float] | None = None, aggregation: str = "mean", n_components: int | None = None, random_state: int = 42, @@ -87,9 +88,13 @@ def calculate_niche( abs_nhood If 'True', calculate niches based on absolute neighborhood profile. Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. - adj_subsets - List of adjacency matrices to use e.g. [1,2,3] for 1-hop,2-hop,3-hop neighbors respectively or "5" for 1-hop,...,5-hop neighbors. 0 (self) is always included. + distance + n-hop neighbor adjacency matrices to use e.g. [1,2,3] for 1-hop,2-hop,3-hop neighbors respectively or "5" for 1-hop,...,5-hop neighbors. 0 (self) is always included. Required if flavor == `{c.CELLCHARTER.s!r}`. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + n_hop_weights + How to weight subsequent n-hop adjacency matrices. E.g. [1, 0.5, 0.25] for weights of 1-hop, 2-hop, 3-hop adjacency matrices respectively. + Optional if flavor == `{c.NEIGHBORHOOD.s!r}` and `distance` > 1. aggregation How to aggregate count matrices. Either 'mean' or 'variance'. Required if flavor == `{c.CELLCHARTER.s!r}`. @@ -129,7 +134,8 @@ def calculate_niche( min_niche_size, scale, abs_nhood, - adj_subsets, + distance, + n_hop_weights, aggregation, n_components, random_state, @@ -140,27 +146,54 @@ def calculate_niche( def _get_nhood_profile_niches( adata: AnnData, mask: pd.core.series.Series | None, - groups: str | None, + groups: str, n_neighbors: int, resolutions: float | list[float], subset_groups: list[str] | None, min_niche_size: int | None, scale: bool, abs_nhood: bool, + distance: int, + n_hop_weights: list[float] | None, spatial_connectivities_key: str, ) -> None: """ adapted from https://github.com/immunitastx/monkeybread/blob/main/src/monkeybread/calc/_neighborhood_profile.py """ - # calculate the neighborhood profile for each cell (relative and absolute proportion of e.g. each cell type in the neighborhood) - rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata, groups, subset_groups, spatial_connectivities_key - ) + # If subsetting, filter connections from adjacency matrix + if subset_groups: + adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() + obs_mask = ~adata.obs[groups].isin(subset_groups) + adata = adata[obs_mask] + + adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] + adata.obsp[spatial_connectivities_key] = adjacency_matrix.tocsr() + + # get obs x neighbor matrix from sparse matrix + matrix = adata.obsp[spatial_connectivities_key].tocoo() + + # get obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood + nhood_profile = _calculate_neighborhood_profile(adata, groups, matrix, abs_nhood) + + # Additionally use n-hop neighbors if distance > 1. This sums up the (weighted) neighborhood profiles of all n-hop neighbors. + if distance > 1: + n_hop_adjacency_matrix = adata.obsp[spatial_connectivities_key].copy() + # if no weights are provided, use 1 for all n_hop neighbors + if n_hop_weights is None: + n_hop_weights = [1] * distance + # if weights are provided, start with applying weight to the original neighborhood profile + else: + nhood_profile = n_hop_weights[0] * nhood_profile + # get n_hop neighbor adjacency matrices by multiplying the original adjacency matrix with itself n times and get corresponding neighborhood profiles. + for n_hop in range(distance - 1): + n_hop_adjacency_matrix = n_hop_adjacency_matrix @ adata.obsp[spatial_connectivities_key] + matrix = n_hop_adjacency_matrix.tocoo() + nhood_profile += n_hop_weights[n_hop + 1] * _calculate_neighborhood_profile( + adata, groups, matrix, abs_nhood + ) + # create AnnData object from neighborhood profile to perform scanpy functions - if not abs_nhood: - adata_neighborhood = ad.AnnData(X=rel_nhood_profile) - else: - adata_neighborhood = ad.AnnData(X=abs_nhood_profile) + adata_neighborhood = ad.AnnData(X=nhood_profile) # reason for scaling see https://monkeybread.readthedocs.io/en/latest/notebooks/tutorial.html#niche-analysis if scale: @@ -226,7 +259,7 @@ def _get_utag_niches( def _get_cellcharter_niches( adata: AnnData, subset_groups: list[str] | None, - adj_subsets: int | list[int] | None, + distance: int, aggregation: str, n_components: int, random_state: int, @@ -236,23 +269,12 @@ def _get_cellcharter_niches( and https://github.com/CSOgroup/cellcharter/blob/main/src/cellcharter/tl/_gmm.py""" adjacency_matrix = adata.obsp[spatial_connectivities_key] - if not isinstance(adj_subsets, list): - if adj_subsets is not None: - adj_subsets = list(range(adj_subsets + 1)) - else: - raise ValueError( - "flavor 'cellcharter' requires adj_subsets to not be None. Specify list of values or maximum value of neighbors to use." - ) - else: - if 0 not in adj_subsets: - adj_subsets.insert(0, 0) - if any(x < 0 for x in adj_subsets): - raise ValueError("adj_subsets must contain non-negative integers.") + layers = list(range(distance + 1)) aggregated_matrices = [] adj_hop = _setdiag(adjacency_matrix, 0) # Remove self-loops, set diagonal to 0 adj_visited = _setdiag(adjacency_matrix.copy(), 1) # Track visited neighbors - for k in adj_subsets: + for k in layers: if k == 0: # get original count matrix (not aggregated) aggregated_matrices.append(adata.X) @@ -276,25 +298,14 @@ def _get_cellcharter_niches( def _calculate_neighborhood_profile( adata: AnnData, - groups: str | None, - subset_groups: list[str] | None, - spatial_connectivities_key: str, -) -> tuple[pd.DataFrame, pd.DataFrame]: - """returns an obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood""" - - if groups is None: - raise ValueError("Please specify 'groups' based on which to calculate neighborhood profile.") - if subset_groups: - adjacency_matrix = adata.obsp[spatial_connectivities_key].tocsc() - obs_mask = ~adata.obs[groups].isin(subset_groups) - adata = adata[obs_mask] - - # Update adjacency matrix such that it only contains connections to filtered observations - adjacency_matrix = adjacency_matrix[obs_mask, :][:, obs_mask] - adata.obsp[spatial_connectivities_key] = adjacency_matrix.tocsr() + groups: str, + matrix: coo_matrix, + abs_nhood: bool, +) -> pd.DataFrame: + """ + Returns an obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood + """ - # get obs x neighbor matrix from sparse matrix - matrix = adata.obsp[spatial_connectivities_key].tocoo() nonzero_indices = np.split(matrix.col, matrix.row.searchsorted(np.arange(1, matrix.shape[0]))) neighbor_matrix = pd.DataFrame(nonzero_indices) @@ -320,7 +331,10 @@ def _calculate_neighborhood_profile( # normalize by n_neighbors to get relative frequency of each category rel_freq = abs_freq / k - return pd.DataFrame(rel_freq, index=adata.obs.index), pd.DataFrame(abs_freq, index=adata.obs.index) + if abs_nhood: + return pd.DataFrame(abs_freq, index=adata.obs.index) + else: + return pd.DataFrame(rel_freq, index=adata.obs.index) def _utag(adata: AnnData, normalize_adj: bool, spatial_connectivity_key: str) -> AnnData: @@ -449,7 +463,8 @@ def _validate_args( min_niche_size: int | None, scale: bool, abs_nhood: bool, - adj_subsets: int | list[int] | None, + distance: int, + n_hop_weights: list[float] | None, aggregation: str, n_components: int | None, random_state: int, @@ -474,6 +489,8 @@ def _validate_args( min_niche_size, scale, abs_nhood, + distance, + n_hop_weights, spatial_connectivities_key, ) else: @@ -496,13 +513,13 @@ def _validate_args( "param 'groups', 'subset_groups', 'min_niche_size', 'scale', 'abs_nhood' are not used for cellcharter flavor.", stacklevel=2, ) - if adj_subsets is not None and aggregation is not None and n_components is not None: + if distance is not None and aggregation is not None and n_components is not None: _get_cellcharter_niches( - adata, subset_groups, adj_subsets, aggregation, n_components, random_state, spatial_connectivities_key + adata, subset_groups, distance, aggregation, n_components, random_state, spatial_connectivities_key ) else: raise ValueError( - "One of required args 'adj_subsets', 'aggregation' and 'n_components' for flavor 'cellcharter' is 'None'." + "One of required args 'distance', 'aggregation' and 'n_components' for flavor 'cellcharter' is 'None'." ) else: raise ValueError(f"Invalid flavor '{flavor}'. Please choose one of 'neighborhood', 'utag', 'cellcharter'.") From 91895cdbba924d3591df7c336dc1e5356dcf3e87 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 4 Dec 2024 17:04:17 +0100 Subject: [PATCH 58/67] Fix docs; Fix relative frequency calculation; Update arg validation function name --- src/squidpy/gr/_niche.py | 48 ++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/src/squidpy/gr/_niche.py b/src/squidpy/gr/_niche.py index 4a7a6cab..911ca2a7 100644 --- a/src/squidpy/gr/_niche.py +++ b/src/squidpy/gr/_niche.py @@ -1,8 +1,7 @@ from __future__ import annotations import warnings -from collections.abc import Iterator -from typing import Any, Literal, Union +from typing import Any, Literal import anndata as ad import numpy as np @@ -20,13 +19,12 @@ from squidpy._constants._constants import NicheDefinitions from squidpy._docs import d, inject_docs -from squidpy._utils import NDArrayA __all__ = ["calculate_niche"] @d.dedent -# @inject_docs(m=NicheDefinitions) +@inject_docs(fla=NicheDefinitions) def calculate_niche( adata: AnnData | SpatialData, flavor: Literal["neighborhood", "utag", "cellcharter"] = "neighborhood", @@ -56,11 +54,11 @@ def calculate_niche( %(adata)s flavor Method to use for niche calculation. Available options are: - - `{c.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. - - `{c.UTAG.s!r}` - use utag algorithm (matrix multiplication). - - `{c.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. - - `{c.SPOT.s!r}` - calculate niches using optimal transport. (coming soon) - - `{c.BANKSY.s!r}`- use Banksy algorithm. (coming soon) + - `{fla.NEIGHBORHOOD.s!r}` - cluster the neighborhood profile. + - `{fla.UTAG.s!r}` - use utag algorithm (matrix multiplication). + - `{fla.CELLCHARTER.s!r}` - cluster adjacency matrix with Gaussian Mixture Model (GMM) using CellCharter's approach. + - `{fla.SPOT.s!r}` - calculate niches using optimal transport. (coming soon) + - `{fla.BANKSY.s!r}`- use Banksy algorithm. (coming soon) %(library_key)s table_key Key in `spatialdata.tables` to specify an 'anndata' table. Only necessary if 'sdata' is passed. @@ -69,41 +67,41 @@ def calculate_niche( Note that if you want to exclude these cells during neighborhood calculation already, you should subset your AnnData table before running 'sq.gr.spatial_neigbors'. groups Groups based on which to calculate neighborhood profile (E.g. columns of cell type annotations in adata.obs). - Required if flavor == `{c.NEIGHBORHOOD.s!r}`. + Required if flavor == `{fla.NEIGHBORHOOD.s!r}`. n_neighbors Number of neighbors to use for 'scanpy.pp.neighbors' before clustering using leiden algorithm. - Required if flavor == `{c.NEIGHBORHOOD.s!r}` or flavor == `{c.UTAG.s!r}`. + Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`. resolutions List of resolutions to use for leiden clustering. - Required if flavor == `{c.NEIGHBORHOOD.s!r}` or flavor == `{c.UTAG.s!r}`. + Required if flavor == `{fla.NEIGHBORHOOD.s!r}` or flavor == `{fla.UTAG.s!r}`. subset_groups Groups (e.g. cell type categories) to ignore when calculating the neighborhood profile. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. min_niche_size Minimum required size of a niche. Niches with fewer cells will be labeled as 'not_a_niche'. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. scale If 'True', compute z-scores of neighborhood profiles. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. abs_nhood If 'True', calculate niches based on absolute neighborhood profile. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. distance n-hop neighbor adjacency matrices to use e.g. [1,2,3] for 1-hop,2-hop,3-hop neighbors respectively or "5" for 1-hop,...,5-hop neighbors. 0 (self) is always included. - Required if flavor == `{c.CELLCHARTER.s!r}`. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}`. + Required if flavor == `{fla.CELLCHARTER.s!r}`. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}`. n_hop_weights How to weight subsequent n-hop adjacency matrices. E.g. [1, 0.5, 0.25] for weights of 1-hop, 2-hop, 3-hop adjacency matrices respectively. - Optional if flavor == `{c.NEIGHBORHOOD.s!r}` and `distance` > 1. + Optional if flavor == `{fla.NEIGHBORHOOD.s!r}` and `distance` > 1. aggregation How to aggregate count matrices. Either 'mean' or 'variance'. - Required if flavor == `{c.CELLCHARTER.s!r}`. + Required if flavor == `{fla.CELLCHARTER.s!r}`. n_components Number of components to use for GMM. - Required if flavor == `{c.CELLCHARTER.s!r}`. + Required if flavor == `{fla.CELLCHARTER.s!r}`. random_state Random state to use for GMM. - Optional if flavor == `{c.CELLCHARTER.s!r}`. + Optional if flavor == `{fla.CELLCHARTER.s!r}`. spatial_connectivities_key Key in `adata.obsp` where spatial connectivities are stored. """ @@ -123,7 +121,7 @@ def calculate_niche( f"Key '{spatial_connectivities_key}' not found in `adata.obsp`. If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`." ) - _validate_args( + _validate_niche_args( adata, mask, flavor, @@ -191,6 +189,8 @@ def _get_nhood_profile_niches( nhood_profile += n_hop_weights[n_hop + 1] * _calculate_neighborhood_profile( adata, groups, matrix, abs_nhood ) + if not abs_nhood: + nhood_profile = nhood_profile / sum(n_hop_weights) # create AnnData object from neighborhood profile to perform scanpy functions adata_neighborhood = ad.AnnData(X=nhood_profile) @@ -452,7 +452,7 @@ def _jensen_shannon_divergence(adata: AnnData, niche_key: str, library_key: str) return distance.jensenshannon(np.array(label_distributions)) -def _validate_args( +def _validate_niche_args( adata: AnnData, mask: pd.core.series.Series | None, flavor: Literal["neighborhood", "utag", "cellcharter"], From e771218547fce4302c5be7893ded6bfaeba14c1a Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 4 Dec 2024 17:17:51 +0100 Subject: [PATCH 59/67] Fix tests --- tests/graph/test_niche.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 54ec4586..4873c734 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -35,10 +35,7 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): assert len(niches[niches == label]) >= 100 rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata_seqfish, - groups="celltype_mapped_refined", - subset_groups=None, - spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY, + adata_seqfish, groups="celltype_mapped_refined", spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY ) # assert shape obs x groups assert rel_nhood_profile.shape == ( @@ -87,9 +84,7 @@ def test_cellcharter_approach(adata_seqfish: AnnData): """Check whether niche calculation using CellCharter approach works as intended.""" spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) - calculate_niche( - adata_seqfish, groups="celltype_mapped_refined", flavor="cellcharter", adj_subsets=3, n_components=5 - ) + calculate_niche(adata_seqfish, groups="celltype_mapped_refined", flavor="cellcharter", distance=3, n_components=5) niches = adata_seqfish.obs["cellcharter_niche"] assert niches.nunique() == 5 From 48d495765a3a22e786197c4335d77374312e3fb4 Mon Sep 17 00:00:00 2001 From: LLehner Date: Wed, 4 Dec 2024 17:42:14 +0100 Subject: [PATCH 60/67] Fix tests --- tests/graph/test_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 4873c734..432dca27 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -35,7 +35,7 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): assert len(niches[niches == label]) >= 100 rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata_seqfish, groups="celltype_mapped_refined", spatial_connectivities_key=SPATIAL_CONNECTIVITIES_KEY + adata_seqfish, groups="celltype_mapped_refined" ) # assert shape obs x groups assert rel_nhood_profile.shape == ( From 0c691a62b82330606aa4de9c7f37c5cd0b19183f Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 14:20:25 +0100 Subject: [PATCH 61/67] Update tests --- tests/conftest.py | 17 +++++++++++++++++ tests/graph/test_niche.py | 25 +++++++++++++++++++------ 2 files changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 01dfe1e7..6662426e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -448,3 +448,20 @@ def pytest_collection_modifyitems(config, items): @pytest.fixture(scope="session") def _test_napari(pytestconfig): _ = pytestconfig.getoption("--test-napari", skip=True) + + +@pytest.fixture() +def adjacency_matrix(): + return np.array( + [ + [0, 1, 1, 0], + [1, 0, 1, 0], + [1, 1, 0, 1], + [0, 0, 1, 0], + ] + ) + + +@pytest.fixture() +def nhop_matrix(): + return np.array([[2, 1, 1, 1], [1, 2, 1, 1], [1, 1, 3, 0], [1, 1, 0, 1]]) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 432dca27..eb39a4bd 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -12,6 +12,7 @@ SPATIAL_CONNECTIVITIES_KEY = "spatial_connectivities" N_NEIGHBORS = 20 +GROUPS = "celltype_mapped_refined" def test_neighborhood_profile_calculation(adata_seqfish: AnnData): @@ -19,7 +20,7 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) calculate_niche( adata_seqfish, - groups="celltype_mapped_refined", + groups=GROUPS, flavor="neighborhood", n_neighbors=N_NEIGHBORS, resolutions=[0.1], @@ -34,13 +35,16 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): if label != "not_a_niche": assert len(niches[niches == label]) >= 100 - rel_nhood_profile, abs_nhood_profile = _calculate_neighborhood_profile( - adata_seqfish, groups="celltype_mapped_refined" - ) + # get obs x neighbor matrix from sparse matrix + matrix = adata_seqfish.obsp[SPATIAL_CONNECTIVITIES_KEY].tocoo() + + # get obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood + rel_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=True) + abs_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=False) # assert shape obs x groups assert rel_nhood_profile.shape == ( adata_seqfish.n_obs, - len(adata_seqfish.obs["celltype_mapped_refined"].cat.categories), + len(adata_seqfish.obs[GROUPS].cat.categories), ) assert abs_nhood_profile.shape == rel_nhood_profile.shape # normalization @@ -84,7 +88,7 @@ def test_cellcharter_approach(adata_seqfish: AnnData): """Check whether niche calculation using CellCharter approach works as intended.""" spatial_neighbors(adata_seqfish, coord_type="generic", delaunay=False, n_neighs=N_NEIGHBORS) - calculate_niche(adata_seqfish, groups="celltype_mapped_refined", flavor="cellcharter", distance=3, n_components=5) + calculate_niche(adata_seqfish, groups=GROUPS, flavor="cellcharter", distance=3, n_components=5) niches = adata_seqfish.obs["cellcharter_niche"] assert niches.nunique() == 5 @@ -113,4 +117,13 @@ def test_cellcharter_approach(adata_seqfish: AnnData): # TODO: add test for GMM +def test_nhop(adjacency_matrix: np.array, n_hop_matrix: np.array): + """Test if n-hop neighbor computation works as expected.""" + + assert adjacency_matrix**2 == n_hop_matrix + adj_sparse = scipy.sparse.csr_matrix(adjacency_matrix) + nhop_sparse = scipy.sparse.csr_matrix(n_hop_matrix) + assert (adj_sparse.dot(adj_sparse)) == nhop_sparse + + # TODO: comppare results to previously calculated niches From c7ef9057c2d5408d9b044abbf619d4a8d5c3117f Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 15:06:17 +0100 Subject: [PATCH 62/67] Fix tests --- tests/graph/test_niche.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index eb39a4bd..75d02c83 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -39,8 +39,8 @@ def test_neighborhood_profile_calculation(adata_seqfish: AnnData): matrix = adata_seqfish.obsp[SPATIAL_CONNECTIVITIES_KEY].tocoo() # get obs x category matrix where each column is the absolute/relative frequency of a category in the neighborhood - rel_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=True) - abs_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=False) + rel_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=False) + abs_nhood_profile = _calculate_neighborhood_profile(adata_seqfish, groups=GROUPS, matrix=matrix, abs_nhood=True) # assert shape obs x groups assert rel_nhood_profile.shape == ( adata_seqfish.n_obs, From 87cf01aba4aaa8276b523ced7c4fcbcfd9890f50 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 15:23:27 +0100 Subject: [PATCH 63/67] Fix tests --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6662426e..016d20d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -463,5 +463,5 @@ def adjacency_matrix(): @pytest.fixture() -def nhop_matrix(): +def n_hop_matrix(): return np.array([[2, 1, 1, 1], [1, 2, 1, 1], [1, 1, 3, 0], [1, 1, 0, 1]]) From 1a3361b8fbb3ee72905608e70387de201c84732c Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 15:50:18 +0100 Subject: [PATCH 64/67] Fix tests --- tests/graph/test_niche.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 75d02c83..edb3d54b 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -120,10 +120,10 @@ def test_cellcharter_approach(adata_seqfish: AnnData): def test_nhop(adjacency_matrix: np.array, n_hop_matrix: np.array): """Test if n-hop neighbor computation works as expected.""" - assert adjacency_matrix**2 == n_hop_matrix + assert np.array_equal(adjacency_matrix**2, n_hop_matrix) adj_sparse = scipy.sparse.csr_matrix(adjacency_matrix) nhop_sparse = scipy.sparse.csr_matrix(n_hop_matrix) - assert (adj_sparse.dot(adj_sparse)) == nhop_sparse + assert (adj_sparse.dot(adj_sparse)) != nhop_sparse).nnz == 0 # TODO: comppare results to previously calculated niches From 192f14616ae0fd81c592b3a243662dab07bb7beb Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 15:58:06 +0100 Subject: [PATCH 65/67] Fix tests --- tests/graph/test_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index edb3d54b..24f15b90 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -123,7 +123,7 @@ def test_nhop(adjacency_matrix: np.array, n_hop_matrix: np.array): assert np.array_equal(adjacency_matrix**2, n_hop_matrix) adj_sparse = scipy.sparse.csr_matrix(adjacency_matrix) nhop_sparse = scipy.sparse.csr_matrix(n_hop_matrix) - assert (adj_sparse.dot(adj_sparse)) != nhop_sparse).nnz == 0 + assert (adj_sparse.dot(adj_sparse) != nhop_sparse).nnz == 0 # TODO: comppare results to previously calculated niches From 1b41546e3db6684928011afc247759cd4594f1d2 Mon Sep 17 00:00:00 2001 From: LLehner Date: Mon, 9 Dec 2024 16:19:47 +0100 Subject: [PATCH 66/67] Fix tests --- tests/graph/test_niche.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 24f15b90..1611d5de 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -120,10 +120,10 @@ def test_cellcharter_approach(adata_seqfish: AnnData): def test_nhop(adjacency_matrix: np.array, n_hop_matrix: np.array): """Test if n-hop neighbor computation works as expected.""" - assert np.array_equal(adjacency_matrix**2, n_hop_matrix) + assert np.array_equal(adjacency_matrix @ adjacency_matrix , n_hop_matrix) adj_sparse = scipy.sparse.csr_matrix(adjacency_matrix) nhop_sparse = scipy.sparse.csr_matrix(n_hop_matrix) - assert (adj_sparse.dot(adj_sparse) != nhop_sparse).nnz == 0 + assert (adj_sparse @ adj_sparse != nhop_sparse).nnz == 0 # TODO: comppare results to previously calculated niches From 6dcf561cb1947a39de0b0fdfaa0c9e2ab9321577 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 9 Dec 2024 15:20:43 +0000 Subject: [PATCH 67/67] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/graph/test_niche.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/graph/test_niche.py b/tests/graph/test_niche.py index 1611d5de..a5225ca6 100644 --- a/tests/graph/test_niche.py +++ b/tests/graph/test_niche.py @@ -120,7 +120,7 @@ def test_cellcharter_approach(adata_seqfish: AnnData): def test_nhop(adjacency_matrix: np.array, n_hop_matrix: np.array): """Test if n-hop neighbor computation works as expected.""" - assert np.array_equal(adjacency_matrix @ adjacency_matrix , n_hop_matrix) + assert np.array_equal(adjacency_matrix @ adjacency_matrix, n_hop_matrix) adj_sparse = scipy.sparse.csr_matrix(adjacency_matrix) nhop_sparse = scipy.sparse.csr_matrix(n_hop_matrix) assert (adj_sparse @ adj_sparse != nhop_sparse).nnz == 0