diff --git a/README.md b/README.md
index 85d7e08..a4712c0 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,10 @@ Code for Kansal et. al., *Particle Cloud Generation with Message Passing Generat
 
 ## Overview
 
-This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen_samples.py) and [plotting](save_outputs.py) the particle clouds.
-We include also [weights](final_models) of the fully trained models discussed in the paper.
+This repository contains PyTorch code for the message passing GAN (MPGAN) [model](mpgan/model.py), as well as scripts for [training](train.py) the models from scratch, [generating](gen.py) and [plotting](plotting.py) the particle clouds.
+We also include [weights](trained_models) of the fully trained models discussed in the paper.
 
-Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, as well as implementations of the evaluation metrics discussed in the paper.
+Additionally, we release the standalone [JetNet](https://github.com/rkansal47/JetNet) library, which provides a PyTorch Dataset class for our JetNet dataset, implementations of the evaluation metrics discussed in the paper, and other useful utilities for machine learning development with jets.
 
 ## Dependencies
@@ -19,7 +19,7 @@ Additionally, we release the standalone [JetNet](https://github.com/rkansal47/Je
 
 #### Training, Plotting, Evaluation
 - `torch >= 1.8.0`
-- `jetnet >= 0.0.3`
+- `jetnet >= 0.1.0`
 - `numpy >= 1.21.0`
 - `matplotlib`
 - `mplhep`
@@ -35,7 +35,11 @@ A Docker image containing all necessary libraries can be found [here](https://gi
 
 ## Training
 
-Start training with `python train.py --name test_model --jets g [args]`.
+Start training with:
+
+```bash
+python train.py --name test_model --jets g [args]
+```
 
 By default, model parameters, figures of particle and jet features, and plots of the training losses and evaluation metrics over time will be saved every five epochs in an automatically created `outputs/[name]` directory.
 
@@ -45,9 +49,10 @@ Some notes:
 
 - Run `python train.py --help` or look at [setup_training.py](setup_training.py) for a full list of arguments.
 
-## Generation (Not finalized!)
+## Generation
 
-Pre-trained models can be used for data generation using the [gen_samples.py](gen_samples.py) script.
-By default it generates samples for all the models listed in the [final_models](final_models) directory and saves them in numpy format as `final_models/[model name]/samples.npy`.
-Samples for your own pre-trained models can be generated by adding the training args and model weights in the `final_models` directory with the same format and running `gen_samples.py`.
+Pre-trained generators with saved state dictionaries and arguments can be used to generate samples with, for example:
+
+```bash
+python gen.py --G-state-dict trained_models/mp_g/G_best_epoch.pt --G-args trained_models/mp_g/args.txt --num-samples 50000 --output-file trained_models/mp_g/gen_jets.npy
+```
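
The generation command in this diff writes samples to a `.npy` file, so downstream analysis only needs NumPy. The snippet below is a minimal sketch of the round trip; the `(num_jets, num_particles, num_features)` shape, the particle-feature names, and the assumption that the output is a standard `np.save` array are illustrative guesses, not taken from the repo:

```python
import numpy as np

# Hypothetical shapes for illustration: 100 jets, 30 particles per jet,
# 3 features per particle (e.g. relative eta, phi, pT).
gen_jets = np.random.rand(100, 30, 3)

# Save in .npy format (assumed output format of gen.py) and reload.
np.save("gen_jets.npy", gen_jets)
loaded = np.load("gen_jets.npy")

print(loaded.shape)  # (100, 30, 3)
```

Loading the file this way gives a plain `ndarray`, which can then be passed to plotting scripts or evaluation metrics directly.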