Skip to content

Commit

Permalink
Rc0.4.5 (#144)
Browse files Browse the repository at this point in the history
* Feat/better defaults (#142)

Updates:
* Modularisation of SwyftModule
* Restructuring of `configure_optimizers` and `configure_callbacks`
* New default is AdamW with early stopping
* Alternatives with AdamW and OneCycleLR and ReduceLROnPlateau exist as well
* Automatic reloading of best model at the end of training is default

* Feat/plot cleanup (#138)

Updates:
* Remove plot2.
* Extended options
* Add `smooth_prior` flag for interpolating likelihoods
* Renaming of plotting routines (plot_posterior, plot_corner, plot_pair)
* Update example notebook

---------

Co-authored-by: NoemiAM <[email protected]>
  • Loading branch information
cweniger and NoemiAM authored Sep 15, 2023
1 parent bd1a71b commit ba1b474
Show file tree
Hide file tree
Showing 26 changed files with 3,625 additions and 2,534 deletions.
764 changes: 679 additions & 85 deletions notebooks/00 - Swyft in 15 minutes.ipynb

Large diffs are not rendered by default.

425 changes: 289 additions & 136 deletions notebooks/0A - SwyftModule.ipynb

Large diffs are not rendered by default.

498 changes: 443 additions & 55 deletions notebooks/0B - Multi-dimensional posteriors.ipynb

Large diffs are not rendered by default.

201 changes: 171 additions & 30 deletions notebooks/0C - Linear regression.ipynb

Large diffs are not rendered by default.

242 changes: 181 additions & 61 deletions notebooks/0D - Data summaries.ipynb

Large diffs are not rendered by default.

269 changes: 217 additions & 52 deletions notebooks/0E - Coverage tests.ipynb

Large diffs are not rendered by default.

663 changes: 201 additions & 462 deletions notebooks/0F - Hyper parameters during training.ipynb

Large diffs are not rendered by default.

62 changes: 35 additions & 27 deletions notebooks/0G - Simulator and Graphical models.ipynb

Large diffs are not rendered by default.

476 changes: 308 additions & 168 deletions notebooks/0H - Truncation and bounds.ipynb

Large diffs are not rendered by default.

129 changes: 87 additions & 42 deletions notebooks/0I - Model comparison.ipynb

Large diffs are not rendered by default.

349 changes: 313 additions & 36 deletions notebooks/0J - ZarrStore and Parallel Simulations.ipynb

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions notebooks/1A - Image analysis and rings.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion notebooks/1B - An MNIST example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"source": [
"## Example B - MNIST & CNN\n",
"\n",
"**(works only with Swyft 0.4.4; to be updated soon)**\n",
"\n",
"Authors: Noemi Anau Montel, James Alvey, Christoph Weniger\n",
"\n",
"Last update: 27 April 2023"
Expand Down Expand Up @@ -2084,7 +2086,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.13"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
16 changes: 8 additions & 8 deletions notebooks/1C - Lotka-Volterra.ipynb

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions swyft/lightning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
"Sample",
"Samples",
"SwyftDataModule",
"OptimizerInit",
"AdamOptimizerInit",
"CoverageSamples",
"estimate_coverage",
"equalize_tensors",
Expand All @@ -31,4 +29,7 @@
"LogRatioEstimator_Gaussian",
"RectBoundSampler",
"LogRatioEstimator_1dim_Gaussian",
"AdamW",
"AdamWOneCycleLR",
"AdamWReduceLROnPlateau",
]
99 changes: 48 additions & 51 deletions swyft/lightning/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@
from swyft.lightning.data import *
from swyft.plot.mass import get_empirical_z_score
from swyft.lightning.utils import (
OptimizerInit,
AdamOptimizerInit,
AdamW,
OnFitEndLoadBestModel,
SwyftParameterError,
_collection_mask,
_collection_flatten,
Expand All @@ -51,42 +51,7 @@
#############


class SwyftModule(pl.LightningModule):
r"""This is the central Swyft LightningModule for handling the training of logratio estimators.
Derived classes are supposed to overwrite the `forward` method in order to implement specific inference tasks.
The attribute `optimizer_init` points to the optimizer initializer (default is `AdamOptimizerInit`).
.. note::
The forward method takes as arguments the sample batches `A` and `B`,
which typically include all sample variables. Joined samples correspond to
A=B, whereas marginal samples correspond to samples A != B.
Example usage:
.. code-block:: python
class MyNetwork(swyft.SwyftModule):
def __init__(self):
self.optimizer_init = AdamOptimizerInit(lr = 1e-4)
self.mlp = swyft.LogRatioEstimator_1dim(4, 4)
def forward(A, B);
x = A['x']
z = A['z']
logratios = self.mlp(x, z)
return logratios
"""

def __init__(self):
super().__init__()
self.optimizer_init = AdamOptimizerInit()

def configure_optimizers(self):
return self.optimizer_init(self.parameters())

class LossAggregationSteps:
def _get_logratios(self, out):
if isinstance(out, dict):
out = {k: v for k, v in out.items() if k[:4] != "aux_"}
Expand All @@ -106,10 +71,14 @@ def _get_logratios(self, out):
logratios = None
return logratios

def validation_step(self, batch, batch_idx):
loss = self._calc_loss(batch, randomized=False)
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
return loss
def _get_aux_losses(self, out):
flattened_out = _collection_flatten(out)
filtered_out = [v for v in flattened_out if isinstance(v, swyft.AuxLoss)]
if len(filtered_out) == 0:
return None
else:
losses = torch.cat([v.loss.unsqueeze(-1) for v in filtered_out], dim=1)
return losses

def _calc_loss(self, batch, randomized=True):
"""Calcualte batch-averaged loss summed over ratio estimators.
Expand Down Expand Up @@ -159,20 +128,16 @@ def _calc_loss(self, batch, randomized=True):

return loss_tot

def _get_aux_losses(self, out):
flattened_out = _collection_flatten(out)
filtered_out = [v for v in flattened_out if isinstance(v, swyft.AuxLoss)]
if len(filtered_out) == 0:
return None
else:
losses = torch.cat([v.loss.unsqueeze(-1) for v in filtered_out], dim=1)
return losses

def training_step(self, batch, batch_idx):
loss = self._calc_loss(batch)
self.log("train_loss", loss, on_step=True, on_epoch=False)
return loss

def validation_step(self, batch, batch_idx):
loss = self._calc_loss(batch, randomized=False)
self.log("val_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
return loss

def test_step(self, batch, batch_idx):
loss = self._calc_loss(batch, randomized=False)
self.log("test_loss", loss, on_epoch=True, on_step=False)
Expand All @@ -184,6 +149,38 @@ def predict_step(self, batch, *args, **kwargs):
return self(A, B)


class SwyftModule(
AdamW, OnFitEndLoadBestModel, LossAggregationSteps, pl.LightningModule
):
r"""This is the central Swyft LightningModule for handling the training of logratio estimators.
Derived classes are supposed to overwrite the `forward` method in order to implement specific inference tasks.
.. note::
The forward method takes as arguments the sample batches `A` and `B`,
which typically include all sample variables. Joined samples correspond to
A=B, whereas marginal samples correspond to samples A != B.
Example usage:
.. code-block:: python
class MyNetwork(swyft.SwyftModule):
def __init__(self):
self.mlp = swyft.LogRatioEstimator_1dim(4, 4)
def forward(A, B);
x = A['x']
z = B['z']
logratios = self.mlp(x, z)
return logratios
"""

def __init__(self):
super().__init__()


#################
# LogRatioSamples
#################
Expand Down
30 changes: 19 additions & 11 deletions swyft/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class SwyftDataModule(pl.LightningDataModule):
Args:
data: Simulation data
lenghts: List of number of samples used for [training, validation, testing].
fractions: Fraction of samples used for [training, validation, testing].
val_fraction: Fraction of data used for validation.
batch_size: Minibatch size.
num_workers: Number of workers for dataloader.
shuffle: Shuffle training data.
Expand All @@ -43,18 +42,26 @@ class SwyftDataModule(pl.LightningDataModule):
def __init__(
self,
data,
lengths: Union[Sequence[int], None] = None,
fractions: Union[Sequence[float], None] = None,
#lengths: Union[Sequence[int], None] = None,
#fractions: Union[Sequence[float], None] = None,
val_fraction: float = 0.2,
batch_size: int = 32,
num_workers: int = 0,
shuffle: bool = False,
on_after_load_sample: Optional[callable] = None,
):
super().__init__()
self.data = data
# TODO: Clean up codes
lengths = None
fractions = [1-val_fraction, val_fraction]
if lengths is not None and fractions is None:
assert len(lengths) == 2, "SwyftDataModule only provides training and validation data."
lengths = [lengths[0], lenghts[1], 0]
self.lengths = lengths
elif lengths is None and fractions is not None:
assert len(fractions) == 2, "SwyftDataModule only provides training and validation data."
fractions = [fractions[0], fractions[1], 0]
self.lengths = self._get_lengths(fractions, len(data))
else:
raise ValueError("Either lenghts or fraction must be set, but not both.")
Expand Down Expand Up @@ -114,13 +121,14 @@ def val_dataloader(self):
return dataloader

def test_dataloader(self):
dataloader = torch.utils.data.DataLoader(
self.dataset_test,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return dataloader
return
# dataloader = torch.utils.data.DataLoader(
# self.dataset_test,
# batch_size=self.batch_size,
# shuffle=False,
# num_workers=self.num_workers,
# )
# return dataloader


class SamplesDataset(torch.utils.data.Dataset):
Expand Down
1 change: 0 additions & 1 deletion swyft/lightning/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,4 +517,3 @@ def forward(self, a: torch.Tensor, b: torch.Tensor):
lrs = swyft.LogRatioSamples(logratios, a, self.varnames)

return lrs

Loading

0 comments on commit ba1b474

Please sign in to comment.