From 455693a073d1bc23d0bce354982242b01a83a07d Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Mon, 25 Mar 2024 23:03:18 -0400 Subject: [PATCH] change back optional sample_cov_keys for de and add assertion, update docstrings --- src/mrvi/_model.py | 47 +++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/mrvi/_model.py b/src/mrvi/_model.py index cb6e3f2..96a3bac 100644 --- a/src/mrvi/_model.py +++ b/src/mrvi/_model.py @@ -760,6 +760,32 @@ def differential_abundance( sample_subset: list[str] | None = None, minibatch_size: int = 128, ) -> xr.Dataset: + """ + Compute the differential abundance between samples. + + Computes the logarithm of the ratio of the probabilities of each sample + conditioned on the estimated aggregate posterior distribution of each cell. + + Parameters + ---------- + adata + The data object to compute the differential abundance for. + If not given, the data object stored in the model is used. + sample_cov_keys + Keys for covariates (batch, etc.) that should also be taken into account + when computing the differential abundance. + sample_subset + Only compute differential abundance for these sample labels. + minibatch_size + Minibatch size for computing the differential abundance. + + Returns + ------- + :class:`xarray.Dataset` + Returns an xarray.Dataset with data variables: + - `log_probs`: Array of shape (n_cells, n_samples) containing the log probabilities for each cell across samples. + - `{cov_key}_log_probs`: For each key in `sample_cov_keys`, an array of shape (n_cells, n_cov_values) containing the log probabilities for each cell across covariate values. + """ adata = self._validate_anndata(adata) us = self.get_latent_representation(adata, use_mean=True, give_z=False) @@ -912,8 +938,8 @@ def get_outlier_cell_sample_pairs( def differential_expression( self, - sample_cov_keys: list[str], adata: AnnData | None = None, + sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, batch_size: int = 128, use_vmap: bool = True, @@ -979,7 +1005,26 @@ def differential_expression( delta LFC threshold used to compute posterior DE probabilities. If None does not compute them to save memory consumption. + + Returns + ------- + xr.Dataset + An xarray Dataset containing the results of the differential expression analysis. The dataset includes: + - `beta`: Coefficients for each covariate across cells and latent dimensions. + - `effect_size`: Effect sizes for each covariate across cells. + - `pvalue`: P-values for each covariate across cells. + - `padj`: Adjusted P-values for each covariate across cells using the Benjamini-Hochberg procedure. + - `lfc`: Log fold changes for each covariate across cells and genes, if `store_lfc` is True. + - `lfc_std`: Standard deviation of log fold changes, if `store_lfc` is True and `delta` is not None. + - `pde`: Posterior DE probabilities, if `store_lfc` is True and `delta` is not None. + - `baseline_expression`: Baseline expression levels for each covariate across cells and genes, if `store_baseline` is True. + - `n_samples`: Number of admissible samples for each cell, if `filter_inadmissible_samples` is True. """ + assert ( + sample_cov_keys is not None + ), ( + "Must assign `sample_cov_keys`" + ) # Keep as kwarg to maintain order of arguments. adata = self.adata if adata is None else adata self._check_if_trained(warn=False) # Hack to ensure new AnnDatas have indices and indices have correct dimensions.