scVI+MMD: Variable-Strength Batch Correction with scVI #2

Open · wants to merge 14 commits into master
Conversation

@watiss (Owner) commented on May 17, 2021

This change adds a term to the VAE objective function that measures the extent to which the posterior distributions for different batches are dissimilar; in effect, it measures the effectiveness of the batch effect correction. The term is added to the loss with a scaling factor, beta, which can be used to regulate the strength of the batch effect correction (note that, taken to its extreme, batch effect correction can over-mix cells by enforcing over-similarity of the latent distributions).

The newly added term is the aggregate Maximum Mean Discrepancy (MMD; https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf) over pairs of sample sets taken from the latent distributions, grouped by their batch of origin. We provide two modes of computation, ‘normal’ and ‘fast’: the former computes the exact MMD (quadratic runtime in the number of samples), while the latter computes a fast approximation of the MMD (linear runtime in the number of samples). Furthermore, in the spirit of fast approximations, we only compute the MMD for consecutive pairs of batches rather than for all pairwise combinations.

MMD parameters (mode and weight) can be set at model instantiation. The MMD loss is recorded in the trainer's history and is therefore available from the model history once training completes.
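For illustration, usage might look like the following sketch; `mmd_mode` and `mmd_weight` are hypothetical keyword names standing in for the mode and weight parameters described above (not confirmed against this PR's code), while the history keys match the ones exercised in the test snippet further below.

```python
import scvi

# Hypothetical usage sketch: mmd_mode and mmd_weight are assumed keyword
# names, not confirmed against this PR's actual API.
# `adata` is an AnnData object with a batch annotation column.
scvi.data.setup_anndata(adata, batch_key="batch")
model = scvi.model.SCVI(adata, mmd_mode="fast", mmd_weight=1.0)  # weight = beta
model.train()

# The MMD loss is recorded in the training history.
mmd_train = model.history["mmd_loss_train"]
mmd_val = model.history["mmd_loss_validation"]
```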

Details of the MMD computation

The formula used for the normal (exact) computation is formula 5 in the paper linked above, with a Gaussian kernel and gamma=1. X and Y in our case are two sets of samples, Z1 and Z2, taken from the latent space, where Z1 corresponds to cells originating from batch k and Z2 corresponds to cells originating from batch k’. We carry this out for all consecutive (k, k’) pairs and sum over them to obtain the aggregate MMD loss, L_mmd.
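To make this concrete, here is a minimal sketch of the exact (quadratic-time) estimate with a Gaussian kernel and gamma=1, written independently of this PR's code; the actual implementation may differ in details (e.g., whether the square root in the paper's formula is taken).

```python
import torch

def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Pairwise Gaussian kernel k(x_i, y_j) = exp(-gamma * ||x_i - y_j||^2)."""
    sq_dists = torch.cdist(x, y) ** 2  # (n_x, n_y) squared Euclidean distances
    return torch.exp(-gamma * sq_dists)

def mmd_exact(z1: torch.Tensor, z2: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Biased empirical estimate of the squared MMD between z1 (batch k) and z2 (batch k').

    Quadratic in the number of samples, since all pairwise kernel values are computed.
    """
    k11 = gaussian_kernel(z1, z1, gamma).mean()
    k22 = gaussian_kernel(z2, z2, gamma).mean()
    k12 = gaussian_kernel(z1, z2, gamma).mean()
    return k11 + k22 - 2.0 * k12
```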

Finally, the existing SCVI loss (negative ELBO) is updated as follows:
L_scvi-mmd = L_scvi + beta * L_mmd

In fast mode, we proceed the same way as above, except that the MMD is computed with the linear-time estimator presented in Lemma 14 of the paper linked above.
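A minimal sketch of the linear-time estimator, again independent of the PR's actual code; it pairs up consecutive samples and averages the h-statistic from Lemma 14, which mirrors the even/odd pairing visible in the reviewed snippet further below.

```python
import torch

def mmd_linear(z1: torch.Tensor, z2: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Linear-time estimate of the squared MMD (Lemma 14 in Gretton et al., 2012).

    Consecutive samples are paired up and the h-statistic
    h = k(x, x') + k(y, y') - k(x, y') - k(x', y) is averaged over the pairs.
    """
    n = min(z1.shape[0], z2.shape[0]) // 2 * 2  # largest even number of usable rows
    x, x_prime = z1[0:n:2], z1[1:n:2]
    y, y_prime = z2[0:n:2], z2[1:n:2]

    def k(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Gaussian kernel evaluated row-wise on already-paired samples
        return torch.exp(-gamma * ((a - b) ** 2).sum(dim=1))

    h = k(x, x_prime) + k(y, y_prime) - k(x, y_prime) - k(x_prime, y)
    return h.mean()
```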

Results

The following Colab notebooks show runs of the current and updated SCVI models, along with training curves and runtimes: notebook1 for the current model, notebook2 for the updated model.

Changes to poetry.lock

This change also includes a minor update to the poetry.lock file that changes the pinned version of the llvmlite package. Currently, Poetry fails dependency resolution with a SolverProblemError because the declared numba and llvmlite versions are incompatible: the numba 0.51.2 release notes (https://pypi.org/project/numba/0.51.2/) state that it is only compatible with llvmlite 0.34.*, which does not match the llvmlite version currently declared in the lock file (0.35.0rc2).

@watiss changed the title from "SCVI+MMD: Variable-Strength Batch Correction with scVI" to "scVI+MMD: Variable-Strength Batch Correction with scVI" on May 17, 2021
assert len(model.history["mmd_loss_validation"]) == 1
assert not np.isnan(model.history["mmd_loss_train"].values[0][0])
assert not np.isnan(model.history["mmd_loss_validation"].values[0][0])
model.get_mmd_loss()
watiss (Owner, Author) commented:
From what I understand, this call does not appear to actually perform validation. I followed the same pattern as the rest of this test (see, for example, get_elbo a bit further above), but I am curious how these calls (get_elbo, get_marginal_ll, etc.) actually perform validation, if they do.

Comment on lines 397 to 400
z1_even = z1[: batch_size - 1 : 2, :]
z1_odd = z1[1:batch_size:2, :]
z2_even = z2[: batch_size - 1 : 2, :]
z2_odd = z2[1:batch_size:2, :]


No need to take every other row; since the order in this tensor is random anyway, you can just take the first half and the second half.
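For illustration, the suggested half-split could look like this sketch (assuming `z1`, `z2`, and `batch_size` as in the snippet above):

```python
# Sketch of the suggested alternative: row order is already random,
# so split each tensor into first/second halves instead of even/odd rows.
half = batch_size // 2
z1_a, z1_b = z1[:half, :], z1[half : 2 * half, :]
z2_a, z2_b = z2[:half, :], z2[half : 2 * half, :]
```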

watiss (Owner, Author) replied:
Thank you, I made that change. Although I'm curious: if the order is random, shouldn't any pairing scheme work equally well, whether it is every other row or pairs from the first and second halves?
