Reporting KL divergence loss for training step #86

Open
chinmay5 opened this issue Jan 17, 2024 · 0 comments

@chinmay5

Thank you for releasing the code. I am using a custom dataset with 10k graphs, and I tried to modify the training loop to log the KL divergence during training, to check whether the model overfits the smaller dataset. While PosMSE looks fine, E_kl and X_kl always come out as NaN on the training samples. Could you please tell me whether something is wrong with my approach?

self.train_metrics = torchmetrics.MetricCollection([custom_metrics.PosMSE(), custom_metrics.XKl(), custom_metrics.EKl()])
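As a side note, I noticed that a torchmetrics-style metric accumulates state across batches, so a single NaN update poisons every later `compute()` for the epoch. A minimal stand-in (plain Python, not the actual torchmetrics or `custom_metrics` implementation) shows the effect:

```python
import math

class RunningMean:
    """Minimal stand-in for a torchmetrics-style averaged metric.

    It accumulates a sum and a count across update() calls, so one NaN
    update contaminates every subsequent compute() until reset.
    """
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    def compute(self):
        return self.total / self.count

m = RunningMean()
m.update(0.3)
m.update(float("nan"))  # one bad batch is enough
m.update(0.2)
print(m.compute())  # nan
```

So even if only a few batches produce a NaN KL term, the logged epoch metric will read NaN throughout.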

In my training_step, I invoke
nll, log_dict = self.compute_train_nll_loss(pred, z_t, clean_data=dense_data)

Finally, the method definition is

def compute_train_nll_loss(self, pred, z_t, clean_data):

    node_mask = z_t.node_mask
    t_int = z_t.t_int
    s_int = t_int - 1
    logger_metric = self.train_metrics
    # 1. Prior on the number of nodes
    N = node_mask.sum(1).long()
    log_pN = self.node_dist.log_prob(N)

    # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero.
    kl_prior = self.kl_prior(clean_data, node_mask)

    # 3. Diffusion loss
    loss_all_t = self.compute_Lt(clean_data, pred, z_t, s_int, node_mask, logger_metric)

    # Combine terms
    nlls = - log_pN + kl_prior + loss_all_t
    # Update NLL metric object and return batch nll
    nll = self.train_nll(nlls)  # Average over the batch

    log_dict = {"train kl prior": kl_prior.mean(),
                "Estimator loss terms": loss_all_t.mean(),
                "log_pn": log_pN.mean(),
                'train_nll': nll}
    return nll, log_dict
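For context, I suspect the NaN comes from zero-probability classes in the KL terms (e.g. padded nodes or absent edge types). A standalone reproduction in plain Python that mimics elementwise tensor arithmetic (this is my own sketch, not code from the repo):

```python
import math

def kl_categorical(p, q, eps=0.0):
    """KL(p || q) for categorical distributions given as probability lists.

    Mimics elementwise tensor semantics: log(0) is -inf, and the resulting
    0 * (-inf - -inf) term is NaN, which is how exact-zero class
    probabilities can poison a KL metric.
    """
    total = 0.0
    for pi, qi in zip(p, q):
        # Optional clamp, as done by some implementations to avoid log(0)
        pi, qi = max(pi, eps), max(qi, eps)
        log_pi = math.log(pi) if pi > 0 else float("-inf")
        log_qi = math.log(qi) if qi > 0 else float("-inf")
        total += pi * (log_pi - log_qi)
    return total

p = [0.5, 0.5, 0.0]
q = [0.5, 0.5, 0.0]  # a class with exact-zero mass on both sides

print(kl_categorical(p, q))            # nan
print(kl_categorical(p, q, eps=1e-8))  # 0.0
```

If this is indeed the cause, should the training-time metrics clamp or mask zero-probability entries before computing the KL, the way the validation path presumably avoids it?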

Any help would be highly appreciated.

Best,
Chinmay
