This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Fixes
martinkim0 committed Jan 31, 2024
1 parent 4b334e1 commit 4053e89
Showing 4 changed files with 62 additions and 32 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ name = "scvi-v2"
 version = "0.0.1"
 description = "Multi-resolution Variational Inference"
 readme = "README.md"
-requires-python = ">=3.8"
+requires-python = ">=3.9"
 license = {file = "LICENSE"}
 authors = [
     {name = "Justin Hong"},
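Note: the declared Python floor moves from 3.8 to 3.9. The commit does not say why, but it plausibly just catches up with the code: _types.py below already contains module-level aliases built from subscripted builtins (e.g. Shape = tuple[int, ...]), which are ordinary assignments evaluated at import time and only supported from Python 3.9 onward. A minimal sketch of the constraint:

# Runs on Python >= 3.9; on 3.8 it raises
# TypeError: 'type' object is not subscriptable.
# `from __future__ import annotations` would not help here, because this
# is an ordinary assignment, not an annotation.
Shape = tuple[int, ...]  # same alias as in src/scvi_v2/_types.py
print(Shape)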
33 changes: 24 additions & 9 deletions src/scvi_v2/_components.py
@@ -6,11 +6,12 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import numpy as np
 import numpyro.distributions as dist
 from flax.linen.dtypes import promote_dtype
 from flax.linen.initializers import variance_scaling

-from ._types import Dtype, NdArray, PRNGKey, Shape
+from ._types import Dtype, PRNGKey, Shape

 _normal_initializer = jax.nn.initializers.normal(stddev=0.1)

@@ -37,7 +38,9 @@ class ResnetBlock(nn.Module):
     training: bool | None = None

     @nn.compact
-    def __call__(self, inputs: NdArray, training: bool | None = None) -> NdArray:
+    def __call__(
+        self, inputs: np.ndarray | jnp.ndarray, training: bool | None = None
+    ) -> np.ndarray | jnp.ndarray:
         training = nn.merge_param("training", self.training, training)
         h = Dense(self.n_hidden)(inputs)
         h = nn.LayerNorm()(h)
@@ -62,7 +65,9 @@ class MLP(nn.Module):
     training: bool | None = None

     @nn.compact
-    def __call__(self, inputs: NdArray, training: bool | None = None) -> dist.Normal:
+    def __call__(
+        self, inputs: np.ndarray | jnp.ndarray, training: bool | None = None
+    ) -> dist.Normal:
         training = nn.merge_param("training", self.training, training)
         h = inputs
         for _ in range(self.n_layers):
@@ -84,7 +89,9 @@ class NormalDistOutputNN(nn.Module):
     training: bool | None = None

     @nn.compact
-    def __call__(self, inputs: NdArray, training: bool | None = None) -> dist.Normal:
+    def __call__(
+        self, inputs: np.ndarray | jnp.ndarray, training: bool | None = None
+    ) -> dist.Normal:
         training = nn.merge_param("training", self.training, training)
         h = inputs
         for _ in range(self.n_layers):
@@ -125,7 +132,10 @@ def init(

     @nn.compact
     def __call__(
-        self, x: NdArray, condition: NdArray, training: bool | None = None
+        self,
+        x: np.ndarray | jnp.ndarray,
+        condition: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ) -> jnp.ndarray:
         training = nn.merge_param("training", self.training, training)
         if self.normalization_type == "batch":
@@ -173,7 +183,10 @@ class AttentionBlock(nn.Module):

     @nn.compact
     def __call__(
-        self, query_embed: NdArray, kv_embed: NdArray, training: bool | None = None
+        self,
+        query_embed: np.ndarray | jnp.ndarray,
+        kv_embed: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ):
         training = nn.merge_param("training", self.training, training)
         has_mc_samples = query_embed.ndim == 3
@@ -250,9 +263,11 @@ class FactorizedEmbedding(nn.Module):
     factorized_features: int
     dtype: Dtype | None = None
     param_dtype: Dtype = jnp.float32
-    embedding_init: callable[[PRNGKey, Shape, Dtype], NdArray] = _normal_initializer
+    embedding_init: callable[
+        [PRNGKey, Shape, Dtype], np.ndarray | jnp.ndarray
+    ] = _normal_initializer

-    embedding: NdArray = dataclasses.field(init=False)
+    embedding: np.ndarray | jnp.ndarray = dataclasses.field(init=False)

     def setup(self) -> None:
         """Initialize the embedding matrix."""
@@ -269,7 +284,7 @@ def setup(self) -> None:
             self.param_dtype,
         )

-    def __call__(self, inputs: NdArray) -> NdArray:
+    def __call__(self, inputs: np.ndarray | jnp.ndarray) -> np.ndarray | jnp.ndarray:
         """
         Embeds the inputs along the last dimension.
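Note: every change in this file is one pattern: the removed NdArray alias is spelled out as an explicit np.ndarray | jnp.ndarray union. The PEP 604 `|` syntax is not evaluable at runtime for these classes on Python 3.9, but each touched module begins with `from __future__ import annotations`, so annotations remain unevaluated strings. A minimal sketch of the pattern; the Scale module is hypothetical, not part of this commit:

from __future__ import annotations  # PEP 563: annotations stay unevaluated

import flax.linen as nn
import jax.numpy as jnp
import numpy as np


class Scale(nn.Module):  # hypothetical example, not scvi-v2 code
    factor: float = 2.0

    @nn.compact
    def __call__(self, inputs: np.ndarray | jnp.ndarray) -> jnp.ndarray:
        # The union annotation is never evaluated, so this runs on 3.9 even
        # though `np.ndarray | jnp.ndarray` would raise TypeError there.
        return self.factor * jnp.asarray(inputs)


y = Scale().apply({}, jnp.ones((2, 3)))  # param-free module: empty variables dict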
48 changes: 32 additions & 16 deletions src/scvi_v2/_module.py
@@ -5,6 +5,7 @@
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
+import numpy as np
 import numpyro.distributions as dist
 from scvi import REGISTRY_KEYS
 from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial
@@ -19,7 +20,6 @@
     NormalDistOutputNN,
 )
 from ._constants import MRVI_REGISTRY_KEYS
-from ._types import NdArray

 DEFAULT_PX_KWARGS = {
     "n_hidden": 32,
@@ -56,10 +56,10 @@ class _DecoderZX(nn.Module):
     @nn.compact
     def __call__(
         self,
-        z: NdArray,
-        batch_covariate: NdArray,
-        size_factor: NdArray,
-        continuous_covariates: NdArray | None,
+        z: np.ndarray | jnp.ndarray,
+        batch_covariate: np.ndarray | jnp.ndarray,
+        size_factor: np.ndarray | jnp.ndarray,
+        continuous_covariates: np.ndarray | jnp.ndarray | None,
         training: bool | None = None,
     ) -> NegativeBinomial:
         h1 = Dense(self.n_out, use_bias=False, name="amat")(z)
@@ -117,10 +117,10 @@ class _DecoderZXAttention(nn.Module):
     @nn.compact
     def __call__(
         self,
-        z: NdArray,
-        batch_covariate: NdArray,
-        size_factor: NdArray,
-        continuous_covariates: NdArray | None,
+        z: np.ndarray | jnp.ndarray,
+        batch_covariate: np.ndarray | jnp.ndarray,
+        size_factor: np.ndarray | jnp.ndarray,
+        continuous_covariates: np.ndarray | jnp.ndarray | None,
         training: bool | None = None,
     ) -> NegativeBinomial:
         has_mc_samples = z.ndim == 3
@@ -212,7 +212,10 @@ def setup(self):
         )

     def __call__(
-        self, u: NdArray, sample_covariate: NdArray, training: bool | None = None
+        self,
+        u: np.ndarray | jnp.ndarray,
+        sample_covariate: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ) -> jnp.ndarray:
         training = nn.merge_param("training", self.training, training)
         sample_covariate = sample_covariate.astype(int).flatten()
@@ -272,7 +275,10 @@ class _EncoderUZ2(nn.Module):

     @nn.compact
     def __call__(
-        self, u: NdArray, sample_covariate: NdArray, training: bool | None = None
+        self,
+        u: np.ndarray | jnp.ndarray,
+        sample_covariate: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ):
         training = nn.merge_param("training", self.training, training)
         sample_covariate = sample_covariate.astype(int).flatten()
@@ -319,7 +325,10 @@ class _EncoderUZ2Attention(nn.Module):

     @nn.compact
     def __call__(
-        self, u: NdArray, sample_covariate: NdArray, training: bool | None = None
+        self,
+        u: np.ndarray | jnp.ndarray,
+        sample_covariate: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ):
         training = nn.merge_param("training", self.training, training)
         sample_covariate = sample_covariate.astype(int).flatten()
@@ -367,7 +376,10 @@ class _EncoderXU(nn.Module):

     @nn.compact
     def __call__(
-        self, x: NdArray, sample_covariate: NdArray, training: bool | None = None
+        self,
+        x: np.ndarray | jnp.ndarray,
+        sample_covariate: np.ndarray | jnp.ndarray,
+        training: bool | None = None,
     ) -> dist.Normal:
         training = nn.merge_param("training", self.training, training)
         x_feat = jnp.log1p(x)
@@ -501,7 +513,9 @@ def setup(self):
     def required_rngs(self):
         return ("params", "u", "dropout", "eps")

-    def _get_inference_input(self, tensors: dict[str, NdArray]) -> dict[str, Any]:
+    def _get_inference_input(
+        self, tensors: dict[str, np.ndarray | jnp.ndarray]
+    ) -> dict[str, Any]:
         x = tensors[REGISTRY_KEYS.X_KEY]
         sample_index = tensors[MRVI_REGISTRY_KEYS.SAMPLE_KEY]
         return {"x": x, "sample_index": sample_index}
@@ -548,7 +562,9 @@ def inference(
         }

     def _get_generative_input(
-        self, tensors: dict[str, NdArray], inference_outputs: dict[str, Any]
+        self,
+        tensors: dict[str, np.ndarray | jnp.ndarray],
+        inference_outputs: dict[str, Any],
     ) -> dict[str, Any]:
         z = inference_outputs["z"]
         library = inference_outputs["library"]
@@ -590,7 +606,7 @@ def generative(self, z, library, batch_index, label_index, continuous_covs):

     def loss(
         self,
-        tensors: dict[str, NdArray],
+        tensors: dict[str, np.ndarray | jnp.ndarray],
         inference_outputs: dict[str, Any],
         generative_outputs: dict[str, Any],
         kl_weight: float = 1.0,
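Note: the reworked signatures in this file all thread a `training: bool | None = None` flag through `nn.merge_param`, the Flax helper that returns whichever of the constructor attribute and the call argument is not None (raising if both are None, or if both are set and disagree). A minimal sketch of that pattern; the DropBlock module is illustrative only, not this commit's code:

from __future__ import annotations

import flax.linen as nn
import jax
import jax.numpy as jnp


class DropBlock(nn.Module):  # illustrative, not scvi-v2 code
    training: bool | None = None

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool | None = None) -> jnp.ndarray:
        # Accept the flag either at construction time or at call time.
        training = nn.merge_param("training", self.training, training)
        return nn.Dropout(rate=0.5, deterministic=not training)(x)


x = jnp.ones((2, 4))
mod = DropBlock(training=False)
variables = mod.init(jax.random.PRNGKey(0), x)
y = mod.apply(variables, x)  # or: DropBlock().apply(variables, x, training=False)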
11 changes: 5 additions & 6 deletions src/scvi_v2/_types.py
@@ -1,12 +1,11 @@
 from __future__ import annotations

 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional, Union
+from typing import Any, Literal

 import jax.numpy as jnp
 import numpy as np
 import xarray as xr

-NdArray = Union[np.ndarray, jnp.ndarray]
 PRNGKey = Any
 Shape = tuple[int, ...]
 Dtype = Any
@@ -38,8 +37,8 @@ class MrVIReduction:
         "sampled_distances",
         "normalized_distances",
     ]
-    fn: Callable[[xr.DataArray], xr.DataArray] = lambda x: xr.DataArray(x)
-    group_by: Optional[str] = None
+    fn: callable[[xr.DataArray], xr.DataArray] = lambda x: xr.DataArray(x)
+    group_by: str | None = None


 @dataclass(frozen=True)
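Note: this hunk swaps typing's Callable and Optional for the builtin `callable` and the `|` union. `str | None` is unproblematic as a deferred annotation, but lowercase `callable[...]` is never a subscriptable type at runtime on any Python version; it only works because `from __future__ import annotations` keeps every annotation an unevaluated string, and tools that resolve hints (e.g. typing.get_type_hints) would still fail on it. A minimal sketch of why the dataclass still constructs; Reduction is a hypothetical stand-in, not the real MrVIReduction:

from __future__ import annotations  # PEP 563: keep annotations as strings

from dataclasses import dataclass


@dataclass
class Reduction:  # hypothetical stand-in for MrVIReduction
    # Stored verbatim as the string "callable[[int], int]" and never
    # evaluated, so the lowercase builtin raises nothing at class creation.
    fn: callable[[int], int] = abs
    group_by: str | None = None


print(Reduction.__annotations__["fn"])  # 'callable[[int], int]'
print(Reduction().fn(-3))               # 3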
