Skip to content
This repository has been archived by the owner on Oct 9, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into jhong/normalizationnoclip
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong authored Mar 18, 2024
2 parents eb8270e + 6150961 commit 86e3feb
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions src/mrvi/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,12 @@ def loss(
inference_outputs["qu"], generative_outputs["pu"]
).sum(-1)
inference_outputs["qeps"]

kl_z = 0.0
eps = inference_outputs["z"] - inference_outputs["z_base"]
if self.z_u_prior:
peps = dist.Normal(0, jnp.exp(self.pz_scale))
kl_z = -peps.log_prob(eps).sum(-1)
else:
kl_z = (
-dist.Normal(inference_outputs["z_base"], jnp.exp(self.z_u_prior_scale))
.log_prob(inference_outputs["z"])
.sum(-1)
if self.z_u_prior_scale is not None
else 0
)

weighted_kl_local = kl_weight * (kl_u + kl_z)
loss = reconstruction_loss + weighted_kl_local
Expand Down

0 comments on commit 86e3feb

Please sign in to comment.