Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Mar 27, 2024
1 parent 455693a commit e329ed7
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions src/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,37 @@ class MrVI(JaxTrainingMixin, BaseModelClass):
Parameters
----------
adata
adata : AnnData
AnnData object that has been registered via ``setup_anndata``.
n_latent
Dimensionality of the latent space.
n_latent_sample
Dimensionality of the latent space for sample embeddings.
encoder_n_hidden
n_latent : int
Dimensionality of the latent space for `z`.
n_latent_u : int
Dimensionality of the latent space for `u`.
encoder_n_hidden : int
Number of nodes per hidden layer in the encoder.
px_kwargs
encoder_n_layers : int
Number of hidden layers in the encoder.
z_u_prior : bool
Whether to use a prior for `z_u`.
z_u_prior_scale : float
Scale of the prior for the difference between `z` and `u`.
u_prior_scale : float
Scale of the prior for `u`.
u_prior_mixture : bool
Whether to use a mixture model for the `u` prior.
u_prior_mixture_k : int
Number of components in the mixture model for the `u` prior.
learn_z_u_prior_scale : bool
Whether to learn the scale of the `z` and `u` difference prior during training.
laplace_scale : float, optional
Scale parameter for the Laplace distribution in the decoder.
scale_observations : bool
Whether to scale loss by the number of observations per sample.
px_kwargs : dict, optional
Keyword args for :class:`~mrvi.DecoderZX`.
qz_kwargs
qz_kwargs : dict, optional
Keyword args for :class:`~mrvi.EncoderUZ`.
qu_kwargs
qu_kwargs : dict, optional
Keyword args for :class:`~mrvi.EncoderXU`.
"""

Expand Down Expand Up @@ -758,7 +776,7 @@ def differential_abundance(
adata: AnnData | None = None,
sample_cov_keys: list[str] | None = None,
sample_subset: list[str] | None = None,
minibatch_size: int = 128,
batch_size: int = 128,
) -> xr.Dataset:
"""
Compute the differential abundance between samples.
Expand All @@ -776,7 +794,7 @@ def differential_abundance(
when computing the differential abundance.
sample_subset
Only compute differential abundance for these sample labels.
minibatch_size
batch_size
Minibatch size for computing the differential abundance.
Returns
Expand All @@ -787,13 +805,17 @@ def differential_abundance(
- `{cov_key}_log_probs`: For each key in `sample_cov_keys`, an array of shape (n_cells, n_cov_values) containing the log probabilities for each cell across covariate values.
"""
adata = self._validate_anndata(adata)
us = self.get_latent_representation(adata, use_mean=True, give_z=False)
us = self.get_latent_representation(
adata, use_mean=True, give_z=False, batch_size=batch_size
)

log_probs = []
unique_samples = adata.obs[self.sample_key].unique()
for sample_name in tqdm(unique_samples):
ap = self.get_aggregated_posterior(adata=adata, sample=sample_name)
n_splits = adata.n_obs // minibatch_size
ap = self.get_aggregated_posterior(
adata=adata, sample=sample_name, batch_size=batch_size
)
n_splits = max(adata.n_obs // batch_size, 1)
log_probs_ = []
for u_rep in np.array_split(us, n_splits):
log_probs_.append(
Expand Down Expand Up @@ -864,7 +886,7 @@ def get_outlier_cell_sample_pairs(
subsample_size: int = 5_000,
quantile_threshold: float = 0.05,
admissibility_threshold: float = 0.0,
minibatch_size: int = 256,
batch_size: int = 256,
) -> xr.Dataset:
"""Utils function to get outlier cell-sample pairs.
Expand Down Expand Up @@ -903,7 +925,7 @@ def get_outlier_cell_sample_pairs(
log_probs_s = jnp.quantile(
ap.log_prob(adata_s.obsm["U"]).sum(axis=1), q=quantile_threshold
)
n_splits = adata.n_obs // minibatch_size
n_splits = adata.n_obs // batch_size
log_probs_ = []
for u_rep in np.array_split(adata.obsm["U"], n_splits):
log_probs_.append(
Expand Down Expand Up @@ -998,13 +1020,13 @@ def differential_expression(
Epsilon to add to the log-fold changes to avoid detecting genes with low expression.
filter_inadmissible_samples
Whether to filter out-of-distribution samples prior to performing the analysis.
filter_inadmissible_samples_kwargs
Keyword arguments to pass to `get_outlier_cell_sample_pairs`.
lambd
Regularization parameter for the linear model.
delta
LFC threshold used to compute posterior DE probabilities.
If None does not compute them to save memory consumption.
filter_samples_kwargs
Keyword arguments to pass to `get_outlier_cell_sample_pairs`.
Returns
-------
Expand All @@ -1020,11 +1042,9 @@ def differential_expression(
- `baseline_expression`: Baseline expression levels for each covariate across cells and genes, if `store_baseline` is True.
- `n_samples`: Number of admissible samples for each cell, if `filter_inadmissible_samples` is True.
"""
assert (
sample_cov_keys is not None
), (
"Must assign `sample_cov_keys`"
) # Keep as kwarg to maintain order of arguments.
if sample_cov_keys is not None:
# Hack: kept as kwarg to maintain order of arguments.
raise ValueError("Must assign `sample_cov_keys`")
adata = self.adata if adata is None else adata
self._check_if_trained(warn=False)
# Hack to ensure new AnnDatas have indices and indices have correct dimensions.
Expand Down

0 comments on commit e329ed7

Please sign in to comment.