Skip to content

Commit

Permalink
Update README.md
Browse files Browse the repository at this point in the history
  • Loading branch information
rkansal47 authored Oct 26, 2021
1 parent 4a1a420 commit 04f31f8
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@ Code for Kansal et. al., *Particle Cloud Generation with Message Passing Generat

## Overview

This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen_samples.py) and [plotting](save_outputs.py) the particle clouds.
We include also [weights](final_models) of the fully trained models discussed in the paper.
This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen.py) and [plotting](plotting.py) the particle clouds.
We include also [weights](trained_models) of fully trained models discussed in the paper.

Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, as well as implementations of the evaluation metrics discussed in the paper.
Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, implementations of the evaluation metrics discussed in the paper, and some more useful utilities for development in machine learning + jets.

## Dependencies

Expand All @@ -19,7 +19,7 @@ Additionally, we release the standalone [JetNet](https://github.com/rkansal47/Je
#### Training, Plotting, Evaluation

- `torch >= 1.8.0`
- `jetnet >= 0.0.3`
- `jetnet >= 0.1.0`
- `numpy >= 1.21.0`
- `matplotlib`
- `mplhep`
Expand All @@ -35,7 +35,11 @@ A Docker image containing all necessary libraries can be found [here](https://gi

## Training

Start training with `python train.py --name test_model --jets g [args]`.
Start training with:

```python
python train.py --name test_model --jets g [args]
```

By default, model parameters, figures of particle and jet features, and plots of the training losses and evaluation metrics over time will be saved every five epochs in an automatically created `outputs/[name]` directory.

Expand All @@ -45,9 +49,10 @@ Some notes:
- Run `python train.py --help` or look at [setup_training.py](setup_training.py) for a full list of arguments.


## Generation (Not finalized!)
## Generation

Pre-trained models can be used for data generation using the [gen_samples.py](gen_samples.py) script.
By default it generates samples for all the models listed in the [final_models](final_models) directory and saves them in numpy format as `final_models/[model name]/samples.npy`.
Samples for your own pre-trained models can be generated by adding the training args and model weights in the `final_models` directory with the same format and running `gen_samples.py`.
Pre-trained generators with saved state dictionaries and arguments can be used to generate samples with, for example:

```python
python gen.py --G-state-dict trained_models/mp_g/G_best_epoch.pt --G-args trained_models/mp_g/args.txt --num-samples 50,000 --output-file trained_models/mp_g/gen_jets.npy
```

0 comments on commit 04f31f8

Please sign in to comment.