diff --git a/src/mrvi/_model.py b/src/mrvi/_model.py index 96a3bac..3876f13 100644 --- a/src/mrvi/_model.py +++ b/src/mrvi/_model.py @@ -58,19 +58,37 @@ class MrVI(JaxTrainingMixin, BaseModelClass): Parameters ---------- - adata + adata : AnnData AnnData object that has been registered via ``setup_anndata``. - n_latent - Dimensionality of the latent space. - n_latent_sample - Dimensionality of the latent space for sample embeddings. - encoder_n_hidden + n_latent : int + Dimensionality of the latent space for `z`. + n_latent_u : int + Dimensionality of the latent space for `u`. + encoder_n_hidden : int Number of nodes per hidden layer in the encoder. - px_kwargs + encoder_n_layers : int + Number of hidden layers in the encoder. + z_u_prior : bool + Whether to use a prior for `z_u`. + z_u_prior_scale : float + Scale of the prior for the difference between `z` and `u`. + u_prior_scale : float + Scale of the prior for `u`. + u_prior_mixture : bool + Whether to use a mixture model for the `u` prior. + u_prior_mixture_k : int + Number of components in the mixture model for the `u` prior. + learn_z_u_prior_scale : bool + Whether to learn the scale of the `z` and `u` difference prior during training. + laplace_scale : float, optional + Scale parameter for the Laplace distribution in the decoder. + scale_observations : bool + Whether to scale loss by the number of observations per sample. + px_kwargs : dict, optional Keyword args for :class:`~mrvi.DecoderZX`. - qz_kwargs + qz_kwargs : dict, optional Keyword args for :class:`~mrvi.EncoderUZ`. - qu_kwargs + qu_kwargs : dict, optional Keyword args for :class:`~mrvi.EncoderXU`. """ @@ -758,7 +776,7 @@ def differential_abundance( adata: AnnData | None = None, sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, - minibatch_size: int = 128, + batch_size: int = 128, ) -> xr.Dataset: """ Compute the differential abundance between samples. 
@@ -776,7 +794,7 @@ def differential_abundance( when computing the differential abundance. sample_subset Only compute differential abundance for these sample labels. - minibatch_size + batch_size Minibatch size for computing the differential abundance. Returns @@ -787,13 +805,17 @@ def differential_abundance( - `{cov_key}_log_probs`: For each key in `sample_cov_keys`, an array of shape (n_cells, n_cov_values) containing the log probabilities for each cell across covariate values. """ adata = self._validate_anndata(adata) - us = self.get_latent_representation(adata, use_mean=True, give_z=False) + us = self.get_latent_representation( + adata, use_mean=True, give_z=False, batch_size=batch_size + ) log_probs = [] unique_samples = adata.obs[self.sample_key].unique() for sample_name in tqdm(unique_samples): - ap = self.get_aggregated_posterior(adata=adata, sample=sample_name) - n_splits = adata.n_obs // minibatch_size + ap = self.get_aggregated_posterior( + adata=adata, sample=sample_name, batch_size=batch_size + ) + n_splits = max(adata.n_obs // batch_size, 1) log_probs_ = [] for u_rep in np.array_split(us, n_splits): log_probs_.append( @@ -864,7 +886,7 @@ def get_outlier_cell_sample_pairs( subsample_size: int = 5_000, quantile_threshold: float = 0.05, admissibility_threshold: float = 0.0, - minibatch_size: int = 256, + batch_size: int = 256, ) -> xr.Dataset: """Utils function to get outlier cell-sample pairs. @@ -903,7 +925,7 @@ def get_outlier_cell_sample_pairs( log_probs_s = jnp.quantile( ap.log_prob(adata_s.obsm["U"]).sum(axis=1), q=quantile_threshold ) - n_splits = adata.n_obs // minibatch_size + n_splits = max(adata.n_obs // batch_size, 1) log_probs_ = [] for u_rep in np.array_split(adata.obsm["U"], n_splits): log_probs_.append( @@ -998,13 +1020,13 @@ def differential_expression( Epsilon to add to the log-fold changes to avoid detecting genes with low expression. filter_inadmissible_samples Whether to filter out-of-distribution samples prior to performing the analysis. 
- filter_inadmissible_samples_kwargs - Keyword arguments to pass to `get_outlier_cell_sample_pairs`. lambd Regularization parameter for the linear model. delta LFC threshold used to compute posterior DE probabilities. If None does not compute them to save memory consumption. + filter_samples_kwargs + Keyword arguments to pass to `get_outlier_cell_sample_pairs`. Returns ------- @@ -1020,11 +1042,9 @@ def differential_expression( - `baseline_expression`: Baseline expression levels for each covariate across cells and genes, if `store_baseline` is True. - `n_samples`: Number of admissible samples for each cell, if `filter_inadmissible_samples` is True. """ - assert ( - sample_cov_keys is not None - ), ( - "Must assign `sample_cov_keys`" - ) # Keep as kwarg to maintain order of arguments. + if sample_cov_keys is None: + # Hack: kept as kwarg to maintain order of arguments. + raise ValueError("Must assign `sample_cov_keys`") adata = self.adata if adata is None else adata self._check_if_trained(warn=False) # Hack to ensure new AnnDatas have indices and indices have correct dimensions.