scVI+MMD: Variable-Strength Batch Correction with scVI #2

Open · wants to merge 14 commits into master
4 changes: 4 additions & 0 deletions docs/references.rst
@@ -56,3 +56,7 @@ References
.. [Lopez18] Romain Lopez, Jeffrey Regier, Michael Cole, Michael I. Jordan, Nir Yosef (2018),
*Deep generative modeling for single-cell transcriptomics*,
`Nature Methods <https://www.nature.com/articles/s41592-018-0229-2.epdf?author_access_token=5sMbnZl1iBFitATlpKkddtRgN0jAjWel9jnR3ZoTv0P1-tTjoP-mBfrGiMqpQx63aBtxToJssRfpqQ482otMbBw2GIGGeinWV4cULBLPg4L4DpCg92dEtoMaB1crCRDG7DgtNrM_1j17VfvHfoy1cQ%3D%3D>`__.

.. [Gretton12] Arthur Gretton, Karsten M. Borgwardt, Malte J. Rasch, Bernhard Schölkopf, and Alexander Smola (2012),
*A Kernel Two-Sample Test*,
`Journal of Machine Learning Research <https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf>`__.
33 changes: 17 additions & 16 deletions poetry.lock


15 changes: 14 additions & 1 deletion scvi/model/_scvi.py
@@ -47,6 +47,13 @@ class SCVI(

* ``'normal'`` - Normal distribution
* ``'ln'`` - Logistic normal distribution (Normal(0, I) transformed by softmax)
mmd_mode
Describes how to compute the MMD component of the objective (loss) function.
One of:
* ``'normal'`` - Compute the exact MMD loss
* ``'fast'`` - Compute a fast, linear-time approximation of the MMD loss
mmd_loss_weight
Weight of the MMD component in the overall loss function; larger values enforce stronger batch correction.
**model_kwargs
Keyword args for :class:`~scvi.module.VAE`

@@ -79,6 +86,8 @@ def __init__(
dispersion: Literal["gene", "gene-batch", "gene-label", "gene-cell"] = "gene",
gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb",
latent_distribution: Literal["normal", "ln"] = "normal",
mmd_mode: Literal["normal", "fast"] = "normal",
mmd_loss_weight: float = 1.0,
**model_kwargs,
):
super(SCVI, self).__init__(adata)
@@ -101,11 +110,13 @@ def __init__(
dispersion=dispersion,
gene_likelihood=gene_likelihood,
latent_distribution=latent_distribution,
mmd_mode=mmd_mode,
mmd_loss_weight=mmd_loss_weight,
**model_kwargs,
)
self._model_summary_string = (
"SCVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
"{}, dispersion: {}, gene_likelihood: {}, latent_distribution: {}"
"{}, dispersion: {}, gene_likelihood: {}, latent_distribution: {}, mmd_mode: {}, mmd_loss_weight: {}"
).format(
n_hidden,
n_latent,
@@ -114,5 +125,7 @@ def __init__(
dispersion,
gene_likelihood,
latent_distribution,
mmd_mode,
mmd_loss_weight,
)
self.init_params_ = self._get_init_params(locals())
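
As a usage illustration (an editorial sketch, not part of the diff): the two new arguments slot into the standard SCVI workflow. The ``synthetic_iid`` helper is assumed from scvi-tools' data utilities, and the weight of 10.0 is an arbitrary example value; any AnnData registered with a batch covariate works the same way.

import scvi

# Toy dataset with synthetic batches (helper assumed from scvi-tools).
adata = scvi.data.synthetic_iid()

# Emphasize batch mixing: weight the MMD penalty heavily and use the
# linear-time ("fast") estimator to keep training cheap.
model = scvi.model.SCVI(adata, mmd_mode="fast", mmd_loss_weight=10.0)
model.train()

# The residual batch effect can then be monitored with the new mixin
# method added below in scvi/model/base/_vaemixin.py.
print(model.get_mmd_loss())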
39 changes: 39 additions & 0 deletions scvi/model/base/_vaemixin.py
@@ -119,6 +119,45 @@ def get_reconstruction_error(
reconstruction_error = compute_reconstruction_error(self.module, scdl)
return reconstruction_error

@torch.no_grad()
def get_mmd_loss(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
) -> float:
"""
Computes the MMD loss for the data over the given indices: the per-minibatch
MMD losses are summed, and the total is normalized by the number of cells.

Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

Returns
-------
Float containing the minibatch-summed MMD loss, normalized by the number of cells
"""
adata = self._validate_anndata(adata)
scdl = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

total_mmd = 0.0
for tensors in scdl:
_, _, scvi_loss = self.module(tensors)
if hasattr(scvi_loss, "mmd"):
total_mmd += scvi_loss.mmd.item()

n_samples = len(scdl.indices)
return total_mmd / n_samples
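
    # Editorial sketch (not part of the diff): the method can also score a
    # subset of cells, e.g. only the cells from two batches of interest.
    # The obs column name "batch" is an assumption about the AnnData at hand:
    #   import numpy as np
    #   idx = np.where(adata.obs["batch"].isin(["batch_0", "batch_1"]))[0]
    #   subset_mmd = model.get_mmd_loss(adata, indices=idx, batch_size=256)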

@torch.no_grad()
def get_latent_representation(
self,
174 changes: 171 additions & 3 deletions scvi/module/_vae.py
@@ -23,7 +23,7 @@ class VAE(BaseModuleClass):
"""
Variational auto-encoder model.

This is an implementation of the scVI model described in [Lopez18]_

Parameters
----------
@@ -77,6 +77,13 @@ class VAE(BaseModuleClass):
var_activation
Callable used to ensure positivity of the variational distributions' variance.
When `None`, defaults to `torch.exp`.
mmd_mode
Describes how to compute the MMD component of the objective (loss) function.
One of:
* ``'normal'`` - Compute the exact MMD loss
* ``'fast'`` - Compute a fast, linear-time approximation of the MMD loss
mmd_loss_weight
Weight of the MMD component in the overall loss function; larger values enforce stronger batch correction.
"""

def __init__(
@@ -100,6 +107,8 @@ def __init__(
use_layer_norm: Literal["encoder", "decoder", "none", "both"] = "none",
use_observed_lib_size: bool = True,
var_activation: Optional[Callable] = None,
mmd_mode: Literal["normal", "fast"] = "normal",
mmd_loss_weight: float = 1.0,
):
super().__init__()
self.dispersion = dispersion
@@ -112,6 +121,8 @@ def __init__(
self.latent_distribution = latent_distribution
self.encode_covariates = encode_covariates
self.use_observed_lib_size = use_observed_lib_size
self.mmd_mode = mmd_mode
self.mmd_loss_weight = mmd_loss_weight

if self.dispersion == "gene":
self.px_r = torch.nn.Parameter(torch.randn(n_input))
@@ -289,6 +300,156 @@ def generative(
px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout
)

def _compute_mmd_kernels(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
"""
Compute the pairwise kernel matrix between ``z1`` and ``z2``, which can
be used to compute their MMD. The kernel used here is a Gaussian kernel,
:math:`k(x, y) = \exp(-\gamma \|x - y\|^2)` with :math:`\gamma = 1`.

Parameters
----------
z1
First tensor
z2
Second tensor

Returns
-------
Tensor of shape ``(z1_size, z2_size)`` containing the pairwise kernel values between ``z1`` and ``z2``.
"""
z1_size = z1.size(0)
z2_size = z2.size(0)

d = z1.size(1)
if d != z2.size(1):
raise ValueError(
"z1 and z2 must be defined on the same space, "
"but input was: z1_d={} while "
"z2_d={}.".format(z1.size(1), z2.size(1))
)

z1 = z1.unsqueeze(1) # (z1_size, 1, d)
z2 = z2.unsqueeze(0) # (1, z2_size, d)
z1 = z1.expand(z1_size, z2_size, d)
z2 = z2.expand(z1_size, z2_size, d)

exp_term = (z1 - z2).pow(2).sum(dim=2) # (z1_size, z2_size)
return torch.exp(-exp_term)
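
    # Editorial note (not part of the diff): the expand-based computation above
    # materializes two (z1_size, z2_size, d) tensors. An equivalent and more
    # memory-friendly formulation uses torch.cdist:
    #
    #   sq_dists = torch.cdist(z1, z2, p=2).pow(2)  # (z1_size, z2_size)
    #   kernels = torch.exp(-sq_dists)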

def _compute_mmd(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
"""
Computes the Maximum Mean Discrepancy (MMD) of ``z1`` and ``z2`` as described in [Gretton12]_.
Based on https://github.com/napsternxg/pytorch-practice/blob/master/Pytorch%20-%20MMD%20VAE.ipynb

Parameters
----------
z1
First tensor
z2
Second tensor

Returns
-------
Scalar tensor containing the MMD of ``z1`` and ``z2``.
"""
k_z1_z1 = self._compute_mmd_kernels(z1, z1)
k_z2_z2 = self._compute_mmd_kernels(z2, z2)
k_z1_z2 = self._compute_mmd_kernels(z1, z2)
mmd = k_z1_z1.mean() + k_z2_z2.mean() - 2 * k_z1_z2.mean()
return mmd
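
    # Editorial annotation (not part of the diff): the estimator above is the
    # biased V-statistic form of the squared MMD from [Gretton12]_,
    #   MMD^2(z1, z2) = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)],
    # with each expectation replaced by an empirical mean over all pairs,
    # including the diagonal k(x, x) terms.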

def _compute_fast_mmd(self, z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
"""
Computes a fast, linear-time approximation of the MMD, following the linear-time
estimator of [Gretton12]_. See `_compute_mmd` for the exact version.

Parameters
----------
z1
First tensor
z2
Second tensor

Returns
-------
Scalar tensor containing an approximation of the MMD of ``z1`` and ``z2``.
"""
z1_size = z1.size(0)
z2_size = z2.size(0)

# The estimator pairs samples from z1 and z2, so truncate both to the smaller size
batch_size = min(z1_size, z2_size)

# Drop a sample if batch_size is not even
if batch_size % 2 != 0:
batch_size -= 1

z1 = z1[:batch_size, :]
z2 = z2[:batch_size, :]

if z1.size(1) != z2.size(1):
raise ValueError(
"z1 and z2 must be defined on the same space, "
"but input was: z1_d={} while "
"z2_d={}.".format(z1.size(1), z2.size(1))
)

# The order of the cells does not matter here (and is in fact random).
# Thus we can compute the h(.) terms for any pairs of vectors from z1
# and z2. Here we make these pairs from the first and second halves of
# these two tensors.
z1_first_half = z1[: batch_size // 2, :]
z1_second_half = z1[batch_size // 2 :, :]
z2_first_half = z2[: batch_size // 2, :]
z2_second_half = z2[batch_size // 2 :, :]

# Compute the kernels
z1_z1_kernels = torch.exp(-((z1_first_half - z1_second_half).pow(2).sum(1)))
z2_z2_kernels = torch.exp(-((z2_first_half - z2_second_half).pow(2).sum(1)))
z1_z2_kernels = torch.exp(-((z1_first_half - z2_second_half).pow(2).sum(1)))
z2_z1_kernels = torch.exp(-((z1_second_half - z2_first_half).pow(2).sum(1)))

all_kernels = z1_z1_kernels + z2_z2_kernels - z1_z2_kernels - z2_z1_kernels
mmd = all_kernels.mean()
return mmd
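
    # Editorial annotation (not part of the diff): this is the linear-time MMD
    # estimator of [Gretton12]_, which averages the h-statistic
    #   h((x1, y1), (x2, y2)) = k(x1, x2) + k(y1, y2) - k(x1, y2) - k(x2, y1)
    # over disjoint sample pairs, requiring O(n) kernel evaluations instead of
    # the O(n^2) needed by the exact estimator in `_compute_mmd`.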

def _compute_mmd_loss(
self, z: torch.Tensor, batch_indices: torch.Tensor, mode: str
) -> torch.Tensor:
"""
Computes the overall MMD loss associated with this set of samples. The overall MMD
is the sum of pairwise MMDs between consecutive batches: for each pair of consecutive
batch labels in ``batch_indices`` (sorted, as returned by ``torch.unique``), the MMD
between the corresponding subsets of ``z``. For example, with batch labels {0, 1, 2}
the loss is MMD(z_0, z_1) + MMD(z_1, z_2).

Parameters
----------
z
Set of samples to compute the overall MMD loss on
batch_indices
Batch indices corresponding to each sample in ``z``. Same length as ``z``.
mode
Whether to compute the approximate MMD ("fast" mode) or the exact MMD ("normal" mode)

Returns
-------
Scalar tensor containing the MMD loss for the samples in ``z``
"""
mmd_loss = torch.tensor(0.0, device=z.device)
batches = torch.unique(batch_indices)
for b0, b1 in zip(batches, batches[1:]):
z0 = z[(batch_indices == b0).reshape(-1)]
z1 = z[(batch_indices == b1).reshape(-1)]
if mode == "normal":
mmd_loss += self._compute_mmd(z0, z1)
elif mode == "fast":
mmd_loss += self._compute_fast_mmd(z0, z1)
else:
raise ValueError(
"Invalid mode passed in: {}. Must be one of 'normal' or 'fast'.".format(
mode
)
)
return mmd_loss

def loss(
self,
tensors,
@@ -299,11 +460,13 @@ def loss(
x = tensors[_CONSTANTS.X_KEY]
local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY]
local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY]
batch_index = tensors[_CONSTANTS.BATCH_KEY]

qz_m = inference_outputs["qz_m"]
qz_v = inference_outputs["qz_v"]
ql_m = inference_outputs["ql_m"]
ql_v = inference_outputs["ql_v"]
z = inference_outputs["z"]
px_rate = generative_outputs["px_rate"]
px_r = generative_outputs["px_r"]
px_dropout = generative_outputs["px_dropout"]
@@ -330,13 +493,18 @@

weighted_kl_local = kl_weight * kl_local_for_warmup + kl_local_no_warmup

mmd_loss = self._compute_mmd_loss(z, batch_index, self.mmd_mode)

loss = (
torch.mean(reconst_loss + weighted_kl_local)
+ self.mmd_loss_weight * mmd_loss
)

kl_local = dict(
kl_divergence_l=kl_divergence_l, kl_divergence_z=kl_divergence_z
)
kl_global = 0.0
return LossRecorder(loss, reconst_loss, kl_local, kl_global, mmd=mmd_loss)
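        # Editorial note (not part of the diff): the extra ``mmd`` keyword is
        # assumed to be stored as an attribute by LossRecorder (extra kwargs
        # become attributes in the scvi-tools version targeted here), which is
        # what ``get_mmd_loss`` checks via ``hasattr``.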

@torch.no_grad()
def sample(