scVI+MMD: Variable-Strength Batch Correction with scVI #2
base: master
Conversation
assert len(model.history["mmd_loss_validation"]) == 1
assert not np.isnan(model.history["mmd_loss_train"].values[0][0])
assert not np.isnan(model.history["mmd_loss_validation"].values[0][0])
model.get_mmd_loss()
From what I understand, this call does not appear to actually perform validation. I followed the same pattern as the rest of this test (see, for example, get_elbo a bit further above), but I am curious how these calls (get_elbo, get_marginal_ll, etc.) actually perform validation, if they do at all.
scvi/module/_vae.py
Outdated
z1_even = z1[: batch_size - 1 : 2, :]
z1_odd = z1[1:batch_size:2, :]
z2_even = z2[: batch_size - 1 : 2, :]
z2_odd = z2[1:batch_size:2, :]
No need to take every other element: the order in this tensor is random anyway, so you can just take the first half and the second half.
Thank you, I made that change. I am curious, though: if the order is random, shouldn't the pairing scheme not matter either way, whether we take every other element or pair the first and second halves?
This change adds a component to the VAE objective function that measures how dissimilar the posterior distributions for different batches are; in essence, it measures the effectiveness of the batch effect correction. The component is added to the loss with a scaling factor, beta, which regulates the strength of the batch effect correction (note that, taken to its extreme, batch effect correction can over-mix cells by enforcing over-similarity of the latent distributions).
The new component is the aggregate Maximum Mean Discrepancy (MMD, https://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf) over pairs of sets of samples taken from the latent distributions, grouped by batch of origin. We provide two modes of computation, ‘normal’ and ‘fast’: the former computes the exact MMD (quadratic runtime in the number of samples), while the latter computes a fast approximation of the MMD (linear runtime in the number of samples). Furthermore, in the spirit of fast approximation, we only compute the MMD for sequential pairs of batches rather than for all pairwise combinations.
The MMD parameters (mode and weight) can be set at model instantiation. The MMD loss is recorded in the trainer's history and is therefore available from the model history once training completes.
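For illustration, a minimal usage sketch. The keyword names `mmd_mode` and `mmd_weight` are assumptions here (the PR exposes a mode and a weight at instantiation, but the exact argument names may differ); the history keys match the ones asserted in the test above.

```python
import scvi

# Hypothetical keyword names for the MMD mode and weight; adjust to the
# actual signature introduced by this change.
model = scvi.model.SCVI(adata, mmd_mode="fast", mmd_weight=1.0)
model.train(max_epochs=100)

# The per-epoch MMD loss is recorded in the trainer history.
mmd_train = model.history["mmd_loss_train"]
mmd_val = model.history["mmd_loss_validation"]
```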
Details of the MMD computation
The formula used for the normal (exact) computation is formula 5 in the paper linked above. We use a Gaussian kernel with gamma=1. X and Y in our case are two sets of samples, Z1 and Z2, taken from the latent space, where Z1 corresponds to cells originating from batch k and Z2 to cells originating from batch k’. We carry this out for all sequential (k, k’) pairs and sum over them to obtain the aggregate MMD loss, L_mmd.
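For reference, a self-contained sketch of the exact estimator under the stated assumptions (unbiased MMD^2 from formula 5 of Gretton et al., 2012, with a Gaussian kernel and gamma=1); this is an illustration, not necessarily the PR's exact implementation:

```python
import torch

def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # k(a, b) = exp(-gamma * ||a - b||^2), evaluated for all pairs of rows.
    return torch.exp(-gamma * torch.cdist(x, y) ** 2)

def mmd_exact(z1: torch.Tensor, z2: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # Unbiased MMD^2 estimator (formula 5 in Gretton et al., 2012);
    # the diagonal (i == j) terms are excluded from the within-set sums.
    m, n = z1.shape[0], z2.shape[0]
    k_xx = gaussian_kernel(z1, z1, gamma)
    k_yy = gaussian_kernel(z2, z2, gamma)
    k_xy = gaussian_kernel(z1, z2, gamma)
    term_xx = (k_xx.sum() - k_xx.diagonal().sum()) / (m * (m - 1))
    term_yy = (k_yy.sum() - k_yy.diagonal().sum()) / (n * (n - 1))
    return term_xx + term_yy - 2.0 * k_xy.mean()
```

The aggregate loss L_mmd is then the sum of `mmd_exact(z_k, z_k')` over all sequential (k, k’) batch pairs.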
Finally, the existing SCVI loss (negative ELBO) is updated as follows:
L_scvi-mmd = L_scvi + beta*L_mmd
In fast mode we proceed in the same way, except that the MMD is computed with the formula presented in Lemma 14 of the paper linked above.
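A sketch of the linear-time estimator under the same assumptions; pairing by first/second halves mirrors the change discussed in the review comments above:

```python
import torch

def mmd_linear(z1: torch.Tensor, z2: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # Linear-time MMD^2 estimator (Lemma 14 in Gretton et al., 2012).
    # Rows are paired by splitting each set into halves; since minibatch
    # order is random, this is equivalent in expectation to pairing
    # every other element.
    n = min(z1.shape[0], z2.shape[0]) // 2 * 2
    x1, x2 = z1[: n // 2], z1[n // 2 : n]
    y1, y2 = z2[: n // 2], z2[n // 2 : n]

    def k(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Gaussian kernel evaluated row-wise on paired samples.
        return torch.exp(-gamma * ((a - b) ** 2).sum(dim=1))

    h = k(x1, x2) + k(y1, y2) - k(x1, y2) - k(x2, y1)
    return h.mean()
```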
Results
The following Colab notebooks show runs of the current and updated SCVI models, along with training curves and training runtimes: notebook1 for the current model and notebook2 for the updated model.
Changes to poetry.lock
This change also includes a minor update to the poetry.lock file that changes the pinned version of the llvmlite package. Poetry currently fails dependency resolution with a SolverProblemError because the declared numba and llvmlite versions are incompatible: the numba 0.51.2 release notes (https://pypi.org/project/numba/0.51.2/) state that it is only compatible with llvmlite 0.34.*, which conflicts with the llvmlite version currently declared in the lock file (0.35.0rc2).