Skip to content

Commit

Permalink
Merge pull request #7 from rkansal47/development
Browse files Browse the repository at this point in the history
Development
  • Loading branch information
rkansal47 authored Oct 26, 2021
2 parents 8a85272 + 04f31f8 commit 4e2ec94
Show file tree
Hide file tree
Showing 164 changed files with 2,797 additions and 756 deletions.
16 changes: 8 additions & 8 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
*
*/**
!final_models
!final_models/**
final_models/*/samples*.npy
!evaluation
!evaluation/**

!mpgan
!mpgan/**
!jetnet
!jetnet/**
!datasets/README.md
!ext_models
!ext_models/**

!trained_models
!trained_models/*
trained_models/*/**
!trained_models/*/args.txt
!trained_models/*/G_best_epoch.pt

!*.py
!LICENSE
!README.md
!Dockerfile
!.gitignore

setup.py
**/__pycache__
**/.DS_Store
Expand Down
2 changes: 2 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,7 @@ RUN pip install qpth cvxpy

RUN pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-1.9.0+cu102.html

RUN pip install jetnet

# Set the default command to python3.
CMD ["python3"]
24 changes: 14 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ Code for Kansal et. al., *Particle Cloud Generation with Message Passing Generat

## Overview

This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen_samples.py) and [plotting](save_outputs.py) the particle clouds.
We include also [weights](final_models) of the fully trained models discussed in the paper.
This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen.py) and [plotting](plotting.py) the particle clouds.
We include also [weights](trained_models) of fully trained models discussed in the paper.

Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, as well as implementations of the evaluation metrics discussed in the paper.
Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, implementations of the evaluation metrics discussed in the paper, and some more useful utilities for development in machine learning + jets.

## Dependencies

Expand All @@ -19,7 +19,7 @@ Additionally, we release the standalone [JetNet](https://github.com/rkansal47/Je
#### Training, Plotting, Evaluation

- `torch >= 1.8.0`
- `jetnet >= 0.0.3`
- `jetnet >= 0.1.0`
- `numpy >= 1.21.0`
- `matplotlib`
- `mplhep`
Expand All @@ -29,14 +29,17 @@ Additionally, we release the standalone [JetNet](https://github.com/rkansal47/Je
- `torch`
- `torch_geometric`

### TODO: add JetNet to Docker image.

A Docker image containing all necessary libraries can be found [here](https://gitlab-registry.nautilus.optiputer.net/raghsthebest/mnist-graph-gan:latest) ([Dockerfile](Dockerfile)).


## Training

Start training with `python train.py --name test_model --jets g [args]`.
Start training with:

```python
python train.py --name test_model --jets g [args]
```

By default, model parameters, figures of particle and jet features, and plots of the training losses and evaluation metrics over time will be saved every five epochs in an automatically created `outputs/[name]` directory.

Expand All @@ -46,9 +49,10 @@ Some notes:
- Run `python train.py --help` or look at [setup_training.py](setup_training.py) for a full list of arguments.


## Generation (Not finalized!)
## Generation

Pre-trained models can be used for data generation using the [gen_samples.py](gen_samples.py) script.
By default it generates samples for all the models listed in the [final_models](final_models) directory and saves them in numpy format as `final_models/[model name]/samples.npy`.
Samples for your own pre-trained models can be generated by adding the training args and model weights in the `final_models` directory with the same format and running `gen_samples.py`.
Pre-trained generators with saved state dictionaries and arguments can be used to generate samples with, for example:

```python
python gen.py --G-state-dict trained_models/mp_g/G_best_epoch.pt --G-args trained_models/mp_g/args.txt --num-samples 50,000 --output-file trained_models/mp_g/gen_jets.npy
```
109 changes: 109 additions & 0 deletions correlation_plots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
import numpy as np
import matplotlib.pyplot as plt
import mplhep as hep


model = "20_t/"
loss_dir = f"outputs/{model}/losses/"
plots_dir = "plots/correlation_plots/"

loss_keys = ["fpnd", "mmd", "coverage"]

losses = {}

for key in loss_keys:
losses[key] = np.loadtxt(loss_dir + key + ".txt")

losses["w1m"] = np.loadtxt(loss_dir + "w1m.txt")[:, 0]
losses["w1p"] = np.mean(np.loadtxt(loss_dir + "w1p.txt")[:, :3], axis=1)
losses["w1efp"] = np.mean(np.loadtxt(loss_dir + "w1efp.txt")[:, :5], axis=1)


def correlation_plot(xkey, ykey, xlabel, ylabel, range, scilimits=False):
plt.rcParams.update({"font.size": 16})
plt.style.use(hep.style.CMS)

fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses[xkey], losses[ykey], bins=50, range=range, cmap="jet")
if scilimits:
plt.ticklabel_format(axis="y", scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(f"{xlabel} vs {ylabel} Correlation")
plt.savefig(f"{plots_dir}/{xkey}v{ykey}.pdf", bbox_inches="tight")
plt.show()


correlation_plot("w1m", "fpnd", "W1-M", "FPND", [[0, 0.01], [0, 10]])
correlation_plot("w1m", "w1efp", "W1-M", "W1-EFP", [[0, 0.01], [0, 0.00025]], True)
correlation_plot("w1m", "w1p", "W1-M", "W1-P", [[0, 0.01], [0, 0.005]])
correlation_plot("w1p", "fpnd", "W1-P", "FPND", [[0, 0.005], [0, 10]])
correlation_plot("w1m", "mmd", "W1-M", "MMD", [[0, 0.01], [0, 0.1]])
correlation_plot("w1m", "coverage", "W1-M", "Coverage", [[0, 0.01], [0, 1]])


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1m"], losses["fpnd"], bins=50, range=[[0, 0.02], [0, 50]], cmap="jet")
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-M")
plt.ylabel("FPND")
plt.title("W1-M vs FPND Correlation")
plt.savefig(f"{plots_dir}/w1mvfpnd.pdf", bbox_inches="tight")


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1m"], losses["w1efp"], bins=50, range=[[0, 0.015], [0, 0.0005]], cmap="jet")
plt.ticklabel_format(axis="y", scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-M")
plt.ylabel("W1-EFP")
plt.title("W1-M vs W1-EFP Correlation")
plt.savefig(f"{plots_dir}/w1mvw1efp.pdf", bbox_inches="tight")


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1m"], losses["w1p"], bins=50, range=[[0, 0.02], [0, 0.01]], cmap="jet")
# plt.ticklabel_format(axis='y', scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-M")
plt.ylabel("W1-P")
plt.title("W1-M vs W1-P Correlation")
plt.savefig(f"{plots_dir}/w1mvw1p.pdf", bbox_inches="tight")


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1p"], losses["fpnd"], bins=50, range=[[0, 0.01], [0, 50]], cmap="jet")
# plt.ticklabel_format(axis='y', scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-P")
plt.ylabel("FPND")
plt.title("W1-P vs FPND Correlation")
plt.savefig(f"{plots_dir}/w1pvfpnd.pdf", bbox_inches="tight")


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1m"], losses["mmd"], bins=50, range=[[0, 0.01], [0, 0.1]], cmap="jet")
# plt.ticklabel_format(axis='y', scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-M")
plt.ylabel("MMD")
plt.title("W1-M vs MMD Correlation")
plt.savefig(f"{plots_dir}/w1mvmmd.pdf", bbox_inches="tight")


fig = plt.figure(figsize=(12, 10))
h = plt.hist2d(losses["w1m"], losses["coverage"], bins=50, range=[[0, 0.01], [0, 1]], cmap="jet")
# plt.ticklabel_format(axis='y', scilimits=(0, 0), useMathText=True)
c = plt.colorbar(h[3])
c.set_label("Number of batches")
plt.xlabel("W1-M")
plt.ylabel("COV")
plt.title("W1-M vs COV Correlation")
plt.savefig(f"{plots_dir}/w1mvcov.pdf", bbox_inches="tight")
Loading

0 comments on commit 4e2ec94

Please sign in to comment.