This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Fix tests and comment out
martinkim0 committed Jan 30, 2024
1 parent 13b5703 commit 4b334e1
Showing 3 changed files with 179 additions and 158 deletions.
11 changes: 9 additions & 2 deletions src/scvi_v2/_model.py
@@ -30,6 +30,7 @@
 from statsmodels.stats.multitest import multipletests
 from tqdm import tqdm

+from ._components import MLP
 from ._constants import MRVI_REGISTRY_KEYS
 from ._module import MrVAE
 from ._tree_utils import (
@@ -529,12 +530,18 @@ def _compute_local_baseline_dists(
         def get_A_s(module, u, sample_covariate):
             sample_covariate = sample_covariate.astype(int).flatten()
             if getattr(module.qz, "use_nonlinear", False):
-                A_s = module.qz.A_s_enc(sample_covariate)
+                A_s = module.qz.A_s_enc(sample_covariate, training=False)
             else:
                 # A_s output by a non-linear function without an explicit intercept
                 sample_one_hot = jax.nn.one_hot(sample_covariate, module.qz.n_sample)
                 A_s_dec_inputs = jnp.concatenate([u, sample_one_hot], axis=-1)
-                A_s = module.qz.A_s_enc(A_s_dec_inputs, training=False)
+
+                if isinstance(module.qz.A_s_enc, MLP):
+                    A_s = module.qz.A_s_enc(A_s_dec_inputs, training=False)
+                else:
+                    # nn.Embed does not support training kwarg
+                    A_s = module.qz.A_s_enc(A_s_dec_inputs)
+
             # cells by n_latent by n_latent
             return A_s.reshape(sample_covariate.shape[0], module.qz.n_latent, -1)

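The comment in the diff notes that flax's nn.Embed does not take a `training` kwarg, unlike the dropout-bearing MLP, hence the isinstance branch. Below is a minimal, hypothetical sketch of that dispatch pattern (TinyQZ and TinyMLP are illustrative stand-ins, not the repository's classes, and the Embed branch is fed integer sample indices here for simplicity):

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyMLP(nn.Module):
    # Stand-in for the repo's MLP: dropout means it needs a `training` flag.
    features: int

    @nn.compact
    def __call__(self, x, training: bool = False):
        x = nn.Dense(self.features)(x)
        return nn.Dropout(rate=0.1, deterministic=not training)(x)


class TinyQZ(nn.Module):
    # Hypothetical container holding either an MLP or an nn.Embed as A_s_enc.
    n_sample: int
    n_latent: int
    use_nonlinear: bool = False

    def setup(self):
        out_features = self.n_latent * self.n_latent
        if self.use_nonlinear:
            self.A_s_enc = TinyMLP(features=out_features)
        else:
            self.A_s_enc = nn.Embed(self.n_sample, out_features)

    def __call__(self, u, sample_covariate, training: bool = False):
        if isinstance(self.A_s_enc, TinyMLP):
            # MLP path: concatenate latent u with the one-hot sample covariate.
            one_hot = jax.nn.one_hot(sample_covariate, self.n_sample)
            inputs = jnp.concatenate([u, one_hot], axis=-1)
            A_s = self.A_s_enc(inputs, training=training)
        else:
            # nn.Embed path: integer indices only, no `training` kwarg.
            A_s = self.A_s_enc(sample_covariate)
        return A_s.reshape(sample_covariate.shape[0], self.n_latent, -1)


# Usage: with training=False no dropout RNG is required.
qz = TinyQZ(n_sample=4, n_latent=3, use_nonlinear=True)
u = jnp.ones((5, 3))
s = jnp.array([0, 1, 2, 3, 0])
params = qz.init(jax.random.PRNGKey(0), u, s)
A_s = qz.apply(params, u, s)  # shape (5, 3, 3)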
30 changes: 25 additions & 5 deletions src/scvi_v2/_module.py
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any
+from typing import Any, Literal

 import flax.linen as nn
 import jax
@@ -27,7 +27,12 @@
     "stop_gradients_mlp": True,
     "dropout_rate": 0.03,
 }
-DEFAULT_QZ_KWARGS = {
+DEFAULT_QZ_KWARGS = {}
+DEFAULT_QZ_MLP_KWARGS = {
+    "use_map": True,
+    "stop_gradients": False,
+}
+DEFAULT_QZ_ATTENTION_KWARGS = {
     "use_map": True,
     "stop_gradients": False,
     "stop_gradients_mlp": True,
@@ -230,7 +235,13 @@ def __call__(
         if u_drop.ndim == 3:
             sample_one_hot = jnp.tile(sample_one_hot, (u_drop.shape[0], 1, 1))
         A_s_enc_inputs = jnp.concatenate([u_drop, sample_one_hot], axis=-1)
-        A_s = self.A_s_enc(A_s_enc_inputs, training=training)
+
+        if isinstance(self.A_s_enc, MLP):
+            A_s = self.A_s_enc(A_s_enc_inputs, training=training)
+        else:
+            # nn.Embed does not support training kwarg
+            A_s = self.A_s_enc(A_s_enc_inputs)
+
         if u_drop.ndim == 3:
             A_s = A_s.reshape(
                 u_drop.shape[0], sample_covariate.shape[0], self.n_latent, n_latent_u
@@ -397,7 +408,7 @@ class MrVAE(JaxBaseModuleClass):
     laplace_scale: float = None
     scale_observations: bool = False
     px_nn_flavor: str = "attention"
-    qz_nn_flavor: str = "attention"
+    qz_nn_flavor: Literal["linear", "mlp", "attention"] = "attention"
     px_kwargs: dict | None = None
     qz_kwargs: dict | None = None
     qu_kwargs: dict | None = None
@@ -408,9 +419,18 @@ def setup(self):
         px_kwargs = DEFAULT_PX_KWARGS.copy()
         if self.px_kwargs is not None:
             px_kwargs.update(self.px_kwargs)
-        qz_kwargs = DEFAULT_QZ_KWARGS.copy()
+
+        if self.qz_nn_flavor == "linear":
+            qz_kwargs = DEFAULT_QZ_KWARGS.copy()
+        elif self.qz_nn_flavor == "mlp":
+            qz_kwargs = DEFAULT_QZ_MLP_KWARGS.copy()
+        elif self.qz_nn_flavor == "attention":
+            qz_kwargs = DEFAULT_QZ_ATTENTION_KWARGS.copy()
+        else:
+            raise ValueError(f"Unknown qz_nn_flavor: {self.qz_nn_flavor}")
         if self.qz_kwargs is not None:
             qz_kwargs.update(self.qz_kwargs)
+
         qu_kwargs = DEFAULT_QU_KWARGS.copy()
         if self.qu_kwargs is not None:
             qu_kwargs.update(self.qu_kwargs)
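The setup change above amounts to a per-flavor lookup of default keyword arguments, which user-supplied qz_kwargs then override. A hypothetical, stand-alone sketch of that pattern follows (using only the default values visible in this diff; the real DEFAULT_QZ_ATTENTION_KWARGS may contain additional keys, and resolve_qz_kwargs is not a function in the repository):

from __future__ import annotations

from typing import Literal

DEFAULT_QZ_KWARGS: dict = {}
DEFAULT_QZ_MLP_KWARGS: dict = {"use_map": True, "stop_gradients": False}
DEFAULT_QZ_ATTENTION_KWARGS: dict = {
    "use_map": True,
    "stop_gradients": False,
    "stop_gradients_mlp": True,
}

_FLAVOR_DEFAULTS = {
    "linear": DEFAULT_QZ_KWARGS,
    "mlp": DEFAULT_QZ_MLP_KWARGS,
    "attention": DEFAULT_QZ_ATTENTION_KWARGS,
}


def resolve_qz_kwargs(
    qz_nn_flavor: Literal["linear", "mlp", "attention"],
    qz_kwargs: dict | None = None,
) -> dict:
    # Pick the flavor's defaults, then let explicitly passed kwargs win.
    if qz_nn_flavor not in _FLAVOR_DEFAULTS:
        raise ValueError(f"Unknown qz_nn_flavor: {qz_nn_flavor}")
    resolved = _FLAVOR_DEFAULTS[qz_nn_flavor].copy()
    if qz_kwargs is not None:
        resolved.update(qz_kwargs)
    return resolved


# resolve_qz_kwargs("attention", {"stop_gradients": True})
# -> {"use_map": True, "stop_gradients": True, "stop_gradients_mlp": True}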

