Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/plot cleanup #138

Merged
merged 38 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
412ff05
Remove old functions from plot2
cweniger Jul 28, 2023
fbb84e5
Consolidate plot2.py and plot.py
cweniger Jul 28, 2023
73f43d6
Merge branch 'master' into feat/plot-cleanup
cweniger Jul 28, 2023
172da05
Add multiple options for corner plot labels.
cweniger Jul 28, 2023
246ffd0
Reorder functions in plot.py
cweniger Jul 29, 2023
e557f2d
Update plot function docstrings.
cweniger Jul 29, 2023
709aae7
Add first grid plot version
cweniger Jul 29, 2023
c0067c2
Fix typo
cweniger Jul 29, 2023
31326e8
Add grid default labeller
cweniger Jul 29, 2023
deb06c8
Renaming of plotting functions
cweniger Jul 30, 2023
43f1c77
propagate cred_level in 2d contour plots
NoemiAM Aug 2, 2023
3ee4703
add true values lines in plotting routines
NoemiAM Aug 2, 2023
926cbd7
add cred_levels in plot_1d
NoemiAM Aug 2, 2023
65579c6
fix plot_posterior
NoemiAM Aug 2, 2023
b2f94ba
update notebooks with new plotting routines names
NoemiAM Aug 2, 2023
43c5fe1
Remove default figsize
cweniger Aug 4, 2023
faabe0d
Renamed functions again
cweniger Aug 4, 2023
ce4fb23
Pretify plot_posterior
cweniger Aug 22, 2023
d7bf2ec
Add plot_pair & some bug fixes
cweniger Aug 23, 2023
fc3ad5e
Update notebook plotting routines
cweniger Aug 24, 2023
f030245
Black
cweniger Sep 1, 2023
2ad1de1
Clean up plot.py
cweniger Sep 1, 2023
452e3b2
Docstring corrections
cweniger Sep 1, 2023
83afc6c
Clean up plot.py
cweniger Sep 2, 2023
f44f1c3
Clean up mass.py
cweniger Sep 2, 2023
97b24ca
Remove histogram.py
cweniger Sep 2, 2023
103c54c
Deprecated files
cweniger Sep 2, 2023
b15bc2e
Bugfix
cweniger Sep 2, 2023
d48ab4b
Merge branch 'rc0.4.5' into feat/plot-cleanup
cweniger Sep 4, 2023
9ed682b
Fix
cweniger Sep 4, 2023
315ec6e
Add `smooth_prior` option to plotting routines.
cweniger Sep 4, 2023
502c446
Update notebook
cweniger Sep 4, 2023
13477d9
Add smooth prior flag to plot_pair
cweniger Sep 12, 2023
4fb1860
Updates some notebooks
cweniger Sep 15, 2023
c23d040
- Verbose reloading of modules
cweniger Sep 15, 2023
b4e3645
Update Hyperparameter notebook
cweniger Sep 15, 2023
f911b0a
More fixes
cweniger Sep 15, 2023
bcd97b6
Add warning to non-updated notebooks.
cweniger Sep 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
764 changes: 679 additions & 85 deletions notebooks/00 - Swyft in 15 minutes.ipynb

Large diffs are not rendered by default.

95 changes: 41 additions & 54 deletions notebooks/0A - SwyftModule.ipynb

Large diffs are not rendered by default.

498 changes: 443 additions & 55 deletions notebooks/0B - Multi-dimensional posteriors.ipynb

Large diffs are not rendered by default.

201 changes: 171 additions & 30 deletions notebooks/0C - Linear regression.ipynb

Large diffs are not rendered by default.

242 changes: 181 additions & 61 deletions notebooks/0D - Data summaries.ipynb

Large diffs are not rendered by default.

269 changes: 217 additions & 52 deletions notebooks/0E - Coverage tests.ipynb

Large diffs are not rendered by default.

663 changes: 201 additions & 462 deletions notebooks/0F - Hyper parameters during training.ipynb

Large diffs are not rendered by default.

62 changes: 35 additions & 27 deletions notebooks/0G - Simulator and Graphical models.ipynb

Large diffs are not rendered by default.

476 changes: 308 additions & 168 deletions notebooks/0H - Truncation and bounds.ipynb

Large diffs are not rendered by default.

129 changes: 87 additions & 42 deletions notebooks/0I - Model comparison.ipynb

Large diffs are not rendered by default.

349 changes: 313 additions & 36 deletions notebooks/0J - ZarrStore and Parallel Simulations.ipynb

Large diffs are not rendered by default.

17 changes: 9 additions & 8 deletions notebooks/1A - Image analysis and rings.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion notebooks/1B - An MNIST example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"source": [
"## Example B - MNIST & CNN\n",
"\n",
"**(works only with Swyft 0.4.4; to be updated soon)**\n",
"\n",
"Authors: Noemi Anau Montel, James Alvey, Christoph Weniger\n",
"\n",
"Last update: 27 April 2023"
Expand Down Expand Up @@ -2084,7 +2086,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.13"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
Expand Down
16 changes: 8 additions & 8 deletions notebooks/1C - Lotka-Volterra.ipynb

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions swyft/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ class SwyftDataModule(pl.LightningDataModule):

Args:
data: Simulation data
lenghts: List of number of samples used for [training, validation, testing].
fractions: Fraction of samples used for [training, validation, testing].
val_fraction: Fraction of data used for validation.
batch_size: Minibatch size.
num_workers: Number of workers for dataloader.
shuffle: Shuffle training data.
Expand All @@ -43,18 +42,26 @@ class SwyftDataModule(pl.LightningDataModule):
def __init__(
self,
data,
lengths: Union[Sequence[int], None] = None,
fractions: Union[Sequence[float], None] = None,
#lengths: Union[Sequence[int], None] = None,
#fractions: Union[Sequence[float], None] = None,
val_fraction: float = 0.2,
batch_size: int = 32,
num_workers: int = 0,
shuffle: bool = False,
on_after_load_sample: Optional[callable] = None,
):
super().__init__()
self.data = data
# TODO: Clean up codes
lengths = None
fractions = [1-val_fraction, val_fraction]
if lengths is not None and fractions is None:
assert len(lengths) == 2, "SwyftDataModule only provides training and validation data."
lengths = [lengths[0], lenghts[1], 0]
self.lengths = lengths
elif lengths is None and fractions is not None:
assert len(fractions) == 2, "SwyftDataModule only provides training and validation data."
fractions = [fractions[0], fractions[1], 0]
self.lengths = self._get_lengths(fractions, len(data))
else:
raise ValueError("Either lenghts or fraction must be set, but not both.")
Expand Down Expand Up @@ -114,13 +121,14 @@ def val_dataloader(self):
return dataloader

def test_dataloader(self):
dataloader = torch.utils.data.DataLoader(
self.dataset_test,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
)
return dataloader
return
# dataloader = torch.utils.data.DataLoader(
# self.dataset_test,
# batch_size=self.batch_size,
# shuffle=False,
# num_workers=self.num_workers,
# )
# return dataloader


class SamplesDataset(torch.utils.data.Dataset):
Expand Down
16 changes: 11 additions & 5 deletions swyft/lightning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

from swyft.lightning.data import *
import swyft.lightning.simulator
from swyft.plot.mass import get_empirical_z_score

import scipy
from scipy.ndimage import gaussian_filter1d, gaussian_filter
Expand All @@ -54,28 +53,30 @@ class SwyftParameterError(Exception):
############################


def _pdf_from_weighted_samples(v, w, bins=50, smooth=0, v_aux=None):
def _pdf_from_weighted_samples(v, w, bins=50, smooth=0, smooth_prior=False):
"""Take weighted samples and turn them into a pdf on a grid.

Args:
bins
"""
ndim = v.shape[-1]
if v_aux is None:
if not smooth_prior:
return _weighted_smoothed_histogramdd(v, w, bins=bins, smooth=smooth)
else:
h, xy = _weighted_smoothed_histogramdd(v_aux, None, bins=bins, smooth=smooth)
h, xy = _weighted_smoothed_histogramdd(v, w * 0 + 1, bins=bins, smooth=smooth)
if ndim == 2:
X, Y = np.meshgrid(xy[:, 0], xy[:, 1])
n = len(xy)
out = scipy.interpolate.griddata(
v, w, (X.flatten(), Y.flatten()), method="cubic", fill_value=0.0
).reshape(n, n)
out = out * h.numpy()
return out, xy
elif ndim == 1:
out = scipy.interpolate.griddata(
v[:, 0], w, xy[:, 0], method="cubic", fill_value=0.0
)
out = out * h.numpy()
return out, xy
else:
raise KeyError("Not supported")
Expand Down Expand Up @@ -113,6 +114,7 @@ def get_pdf(
aux=None,
bins: int = 50,
smooth: float = 0.0,
smooth_prior=False,
):
"""Generate binned PDF based on input

Expand All @@ -121,6 +123,7 @@ def get_pdf(
params: Parameter names
bins: Number of bins
smooth: Apply Gaussian smoothing
smooth_prior: Smooth prior instead of posterior

Returns:
np.array, np.array: Returns densities and parameter grid.
Expand All @@ -130,7 +133,9 @@ def get_pdf(
z_aux, _ = get_weighted_samples(aux, params)
else:
z_aux = None
return _pdf_from_weighted_samples(z, w, bins=bins, smooth=smooth, v_aux=z_aux)
return _pdf_from_weighted_samples(
z, w, bins=bins, smooth=smooth, smooth_prior=smooth_prior
)


def _get_weights(logratios, normalize: bool = False):
Expand Down Expand Up @@ -528,4 +533,5 @@ class OnFitEndLoadBestModel:
def on_fit_end(self):
self.best_model_path = self.trainer.checkpoint_callback.best_model_path
checkpoint = torch.load(self.best_model_path)
print("Reloading best model:", self.best_model_path)
self.load_state_dict(checkpoint["state_dict"])
26 changes: 8 additions & 18 deletions swyft/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
# from swyft.plot.constraint import diagonal_constraint, lower_constraint
# from swyft.plot.histogram import corner, hist1d
# from swyft.plot.mass import empirical_z_score_corner, plot_empirical_z_score
# from swyft.plot.violin import violin
from swyft.plot.plot2 import plot_1d, plot_2d, corner, plot_zz, plot_pp
from swyft.plot.plot import (
plot_posterior,
plot_corner,
plot_zz,
plot_pp,
plot_pair,
)

__all__ = [
"plot_1d",
"plot_2d",
"corner",
"plot_zz",
"plot_pp"
# "diagonal_constraint",
# "hist1d",
# "lower_constraint",
# "plot_empirical_z_score",
# "empirical_z_score_corner",
# "violin",
]
__all__ = ["plot_corner", "plot_zz", "plot_pp", "plot_posterior", "plot_pair"]
13 changes: 11 additions & 2 deletions swyft/plot/constraint.py → swyft/plot/deprecated/constraint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
import matplotlib.patches as mpatches
import matplotlib.path as mpath

from swyft.plot.histogram import split_corner_axes
from swyft.utils.marginals import get_d_dim_marginal_indices

import numpy as np
from typing import Tuple


def _split_corner_axes(axes: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
diag = np.diag(axes)
lower = axes[np.tril(axes, -1).nonzero()]
upper = axes[np.triu(axes, 1).nonzero()]
return lower, diag, upper


def diagonal_constraint(axes, bounds, alpha=0.25):
_, diag, _ = split_corner_axes(axes)
_, diag, _ = _split_corner_axes(axes)
for i, ax in enumerate(diag):
xlim = ax.get_xlim()
ax.axvspan(xlim[0], bounds[i, 0], alpha=alpha)
Expand Down
File renamed without changes.
Loading