Commit

Merge remote-tracking branch 'upstream/main' into two_sample_test
voetberg committed May 23, 2024
2 parents d6d6bda + 98b95a9 commit 37893c0
Showing 16 changed files with 214 additions and 76 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -19,7 +19,6 @@ numpy = "^1.26.4"
 matplotlib = "^3.8.3"
 tarp = "^0.1.1"
 deprecation = "^2.1.0"
-scipy = "1.12.0"
 
 
 [tool.poetry.group.dev.dependencies]
14 changes: 4 additions & 10 deletions src/client/client.py
@@ -97,15 +97,9 @@ def main():
     plots = config.get_section("plots", raise_exception=False)
 
     for metrics_name, metrics_args in metrics.items():
-        try:
-            Metrics[metrics_name](model, data, **metrics_args)()
-        except (NotImplementedError, RuntimeError) as error:
-            print(f"WARNING - skipping metric {metrics_name} due to error: {error}")
+        Metrics[metrics_name](model, data, **metrics_args)()
 
     for plot_name, plot_args in plots.items():
-        try:
-            Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
-                **plot_args
-            )
-        except (NotImplementedError, RuntimeError) as error:
-            print(f"WARNING - skipping plot {plot_name} due to error: {error}")
+        Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
+            **plot_args
+        )
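For readers skimming the diff: both loops rely on a name-to-class registry, and with the try/except removed, metric and plot failures now propagate instead of being logged and skipped. A minimal sketch of the pattern — `DummyMetric`, `run_all`, and the config shape are illustrative stand-ins, not the repo's API:

```python
# Hypothetical stand-ins to illustrate the registry dispatch used in main().
class DummyMetric:
    def __init__(self, model, data, **kwargs):
        self.model, self.data, self.kwargs = model, data, kwargs

    def __call__(self):
        print(f"running {type(self).__name__} with {self.kwargs}")

Metrics = {DummyMetric.__name__: DummyMetric}

def run_all(model, data, metrics_config):
    for name, args in metrics_config.items():
        # Fail fast: a bad metric name or a runtime error now surfaces
        # immediately rather than being swallowed by a blanket except.
        Metrics[name](model, data, **args)()

run_all(model=None, data=None, metrics_config={"DummyMetric": {"n": 3}})
```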
2 changes: 1 addition & 1 deletion src/data/data.py
@@ -125,4 +125,4 @@ def load_prior(self, prior, prior_kwargs):
             return lambda size: choices[prior](**prior_kwargs, size=size)
 
         except KeyError as e:
-            raise RuntimeError(f"Data missing a prior specification - {e}")
+            raise RuntimeError(f"Data missing a prior specification - {e}")
3 changes: 1 addition & 2 deletions src/data/h5_data.py
@@ -10,8 +10,7 @@
 class H5Data(Data):
     def __init__(self, path: str, simulator: Callable):
         super().__init__(path, simulator)
-        self.theta_true = self.get_theta_true()
 
 
     def _load(self, path):
         assert path.split(".")[-1] == "h5", "File extension must be h5"
         loaded_data = {}
5 changes: 1 addition & 4 deletions src/metrics/__init__.py
@@ -1,7 +1,4 @@
 from metrics.all_sbc import AllSBC
 from metrics.coverage_fraction import CoverageFraction
-from metrics.local_two_sample import LocalTwoSampleTest
 
-
-_all = [CoverageFraction, AllSBC, LocalTwoSampleTest]
-Metrics = {m.__name__: m for m in _all}
+Metrics = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}
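The explicit dict and the removed comprehension build the same mapping; a small sketch of the equivalence, with stub classes standing in for the real metrics:

```python
class CoverageFraction: ...
class AllSBC: ...

# Comprehension form (removed above):
_all = [CoverageFraction, AllSBC]
via_comprehension = {m.__name__: m for m in _all}

# Explicit form (kept above):
explicit = {CoverageFraction.__name__: CoverageFraction, AllSBC.__name__: AllSBC}

assert via_comprehension == explicit  # both: {'CoverageFraction': ..., 'AllSBC': ...}
```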
2 changes: 1 addition & 1 deletion src/metrics/local_two_sample.py
@@ -82,7 +82,7 @@ def _cross_eval_score(self, p, q, x_p, x_q, classifier, classifier_kwargs, n_cross_folds):
         probabilities = []
         self.evaluation_data = np.zeros((n_cross_folds, len(next(cv_splits)[1]), self.evaluation_context.shape[-1]))
         self.prior_evaluation = np.zeros_like(p)
-
+
         kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)
         cv_splits = kf.split(p)
         for cross_trial, (train_index, val_index) in enumerate(cv_splits):
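Note that, as displayed, `self.evaluation_data` is sized with `len(next(cv_splits)[1])` before `cv_splits` is assigned, and `next()` consumes a fold from the generator. A sketch of the KFold pattern under that reading — the array names and shapes are assumptions, not the repo's exact code:

```python
import numpy as np
from sklearn.model_selection import KFold

p = np.random.rand(100, 2)  # stand-in for the classifier's evaluation inputs
n_cross_folds = 5

kf = KFold(n_splits=n_cross_folds, shuffle=True, random_state=42)

# kf.split(p) returns a generator, so calling next() on it to size an array
# consumes a fold; computing the fold size arithmetically avoids that.
fold_size = len(p) // n_cross_folds
evaluation_data = np.zeros((n_cross_folds, fold_size, p.shape[-1]))

for cross_trial, (train_index, val_index) in enumerate(kf.split(p)):
    evaluation_data[cross_trial, : len(val_index)] = p[val_index]
```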
4 changes: 1 addition & 3 deletions src/models/sbi_model.py
@@ -24,8 +24,6 @@ def sample_posterior(self, n_samples: int, y_true): # TODO typing
     def predict_posterior(self, data):
         posterior_samples = self.sample_posterior(data.y_true)
         posterior_predictive_samples = data.simulator(
-            data.get_theta_true(), posterior_samples
+            data.theta_true(), posterior_samples
         )
         return posterior_predictive_samples
-
-
9 changes: 6 additions & 3 deletions src/plots/__init__.py
@@ -1,8 +1,11 @@
 from plots.cdf_ranks import CDFRanks
 from plots.coverage_fraction import CoverageFraction
 from plots.ranks import Ranks
-from plots.local_two_sample import LocalTwoSampleTest
 from plots.tarp import TARP
 
-_all = [CoverageFraction, CDFRanks, Ranks, LocalTwoSampleTest, TARP]
-Plots = {m.__name__: m for m in _all}
+Plots = {
+    CDFRanks.__name__: CDFRanks,
+    CoverageFraction.__name__: CoverageFraction,
+    Ranks.__name__: Ranks,
+    TARP.__name__: TARP,
+}
5 changes: 2 additions & 3 deletions src/plots/local_two_sample.py
@@ -127,7 +127,6 @@ def probability_intensity(self, subplot, features, n_bins=20):
         colors = plt.get_cmap(self.colorway)
 
         weights = np.empty((n_bins, n_bins)) * np.nan
-        print(weights)
         for i in range(n_bins):
             for j in range(n_bins):
                 local_and = np.logical_and(eval_bins_dim_1==i, eval_bins_dim_2==j)
@@ -151,7 +150,7 @@ def probability_intensity(self, subplot, features, n_bins=20):
                     edgecolor="none",
                 )
                 subplot.add_patch(rect)
-
+
     def _plot(self,
         use_intensity_plot:bool=True,
@@ -254,4 +253,4 @@ def __call__(self, **plot_args) -> None:
         except NotImplementedError:
             pass
 
-        self._plot(**plot_args)
+        self._plot(**plot_args)
11 changes: 2 additions & 9 deletions src/plots/plot.py
@@ -27,6 +27,7 @@ def __init__(
 
         self.model = model
         self._common_settings()
+        self._plot_settings()
         self.plot_name = self._plot_name()
 
     def _plot_name(self):
@@ -76,14 +77,6 @@ def _finish(self):
         plt.cla()
 
     def __call__(self, **plot_args) -> None:
-        try:
-            self._data_setup()
-        except NotImplementedError:
-            pass
-        try:
-            self._plot_settings()
-        except NotImplementedError:
-            pass
-
+        self._data_setup()
         self._plot(**plot_args)
         self._finish()
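The net effect on the base class: `_plot_settings()` runs once at construction, and `__call__` no longer silences `NotImplementedError` from missing hooks. A minimal sketch of the resulting template-method flow — the hook names mirror the diff, but the bodies are assumptions:

```python
import matplotlib.pyplot as plt

class Display:  # hypothetical minimal analogue of the plot base class
    def __init__(self, model):
        self.model = model
        self._common_settings()
        self._plot_settings()  # now invoked once, at construction

    def _common_settings(self):
        plt.rcParams["axes.spines.top"] = False

    def _plot_settings(self):
        pass  # subclass hook; override to customize styling

    def _data_setup(self):
        raise NotImplementedError  # subclasses must provide data

    def _plot(self, **kwargs):
        raise NotImplementedError

    def __call__(self, **plot_args):
        self._data_setup()  # errors now propagate (no blanket except)
        self._plot(**plot_args)

class ExampleRanks(Display):  # hypothetical subclass
    def _data_setup(self):
        self.values = [1, 2, 3]

    def _plot(self):
        plt.plot(self.values)

ExampleRanks(model=None)()  # runs _data_setup, then _plot
```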
10 changes: 1 addition & 9 deletions src/utils/defaults.py
@@ -10,10 +10,7 @@
         "data_engine": "H5Data",
-        "prior":"normal",
-        "prior_kwargs": None,
-        "simulator_kwargs": None,
-        "prior": "normal",
+        "prior_kwargs":{}
 
         "simulator_kwargs": None,
     },
     "plots_common": {
         "axis_spines": False,
@@ -29,7 +26,6 @@
     "CDFRanks": {},
     "Ranks": {"num_bins": None},
     "CoverageFraction": {},
-    "LocalTwoSampleTest":{},
     "TARP": {
         "coverage_sigma": 3  # How many sigma to show coverage over
     },
@@ -43,9 +39,5 @@
     "metrics": {
         "AllSBC": {},
         "CoverageFraction": {},
-        "LocalTwoSampleTest":{
-            "linear_classifier":"MLP",
-            "classifier_kwargs":{"alpha":0, "max_iter":2500}
-        }
     },
 }
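For context, a defaults dict like this is typically overlaid with user-supplied config so user keys win and defaults fill the gaps; a sketch under that assumption — `get_section` here is illustrative, not necessarily the repo's implementation:

```python
Defaults = {
    "metrics": {
        "AllSBC": {},
        "CoverageFraction": {},
    },
}

def get_section(user_config: dict, name: str) -> dict:
    # Shallow merge: defaults first, user config overrides.
    return {**Defaults.get(name, {}), **user_config.get(name, {})}

merged = get_section({"metrics": {"CoverageFraction": {"n_bins": 20}}}, "metrics")
print(merged)  # {'AllSBC': {}, 'CoverageFraction': {'n_bins': 20}}
```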
11 changes: 0 additions & 11 deletions src/utils/plotting_utils.py

This file was deleted.

4 changes: 2 additions & 2 deletions tests/conftest.py
@@ -9,8 +9,8 @@
 
 
 class MockSimulator(Simulator):
-    def generate_context(self, n_samples=None) -> np.ndarray:
-        return np.linspace(0, 100, 101)
+    def generate_context(self, n_samples: int) -> np.ndarray:
+        return np.linspace(0, 100, n_samples)
 
     def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray:
         thetas = np.atleast_2d(theta)
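With this change, callers must pass `n_samples` explicitly instead of always getting 101 points. A self-contained usage sketch — the `simulate` body is elided above, so the linear form below is an assumption based on the simulator in tests/test_evaluate.py:

```python
import numpy as np

class Simulator:  # stand-in for the repo's base class
    pass

class MockSimulator(Simulator):
    def generate_context(self, n_samples: int) -> np.ndarray:
        return np.linspace(0, 100, n_samples)

    def simulate(self, theta: np.ndarray, context_samples: np.ndarray) -> np.ndarray:
        thetas = np.atleast_2d(theta)
        # Assumed linear form y = m*x + b; the diff elides the real body.
        return np.array([m * context_samples + b for m, b in thetas])

sim = MockSimulator()
x = sim.generate_context(n_samples=101)  # previously fixed at 101
y = sim.simulate(np.array([[2.0, 1.0]]), x)
print(y.shape)  # (1, 101)
```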
185 changes: 185 additions & 0 deletions tests/test_evaluate.py
@@ -0,0 +1,185 @@
import sys
import pytest
import torch
import numpy as np
import sbi
import os

# flake8: noqa
#sys.path.append("..")
from scripts.evaluate import Diagnose_static, Diagnose_generative
from scripts.io import ModelLoader
#from src.scripts import evaluate


"""
Test the evaluate module
"""


@pytest.fixture
def diagnose_static_instance():
    return Diagnose_static()


@pytest.fixture
def diagnose_generative_instance():
    return Diagnose_generative()


@pytest.fixture
def posterior_generative_sbi_model():
    # create a temporary directory for the saved model
    #dir = "savedmodels/sbi/"
    #os.makedirs(dir)

    # now save the model
    low_bounds = torch.tensor([0, -10])
    high_bounds = torch.tensor([10, 10])

    prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds)

    posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000)

    # Provide the posterior to the tests
    yield prior, posterior

    # Teardown: Remove the temporary directory and its contents
    #shutil.rmtree(dataset_dir)


@pytest.fixture
def setup_plot_dir():
    # create a temporary directory for the saved plots
    dir = "tests/plots/"
    os.makedirs(dir, exist_ok=True)
    yield dir

def simulator(thetas):  # , percent_errors):
    # convert to numpy array (if tensor):
    thetas = np.atleast_2d(thetas)
    # Check if the input has the correct shape
    if thetas.shape[1] != 2:
        raise ValueError(
            "Input tensor must have shape (n, 2) \
            where n is the number of parameter sets."
        )

    # Unpack the parameters
    if thetas.shape[0] == 1:
        # If there's only one set of parameters, extract them directly
        m, b = thetas[0, 0], thetas[0, 1]
    else:
        # If there are multiple sets of parameters, extract them for each row
        m, b = thetas[:, 0], thetas[:, 1]
    x = np.linspace(0, 100, 101)
    rs = np.random.RandomState()  # 2147483648
    # I'm thinking sigma could actually be a function of x
    # if we want to get fancy down the road
    # Generate random noise (epsilon) based
    # on a normal distribution with mean 0 and standard deviation sigma
    sigma = 5
    ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0]))

    # Initialize an empty array to store the results for each set of parameters
    y = np.zeros((len(x), thetas.shape[0]))
    for i in range(thetas.shape[0]):
        m, b = thetas[i, 0], thetas[i, 1]
        y[:, i] = m * x + b + ε[:, i]
    return torch.Tensor(y.T)


def test_generate_sbc_samples(diagnose_generative_instance,
                              posterior_generative_sbi_model):
    # Mock data
    #low_bounds = torch.tensor([0, -10])
    #high_bounds = torch.tensor([10, 10])

    #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds)
    prior, posterior = posterior_generative_sbi_model
    #inference_instance # provide a mock posterior object
    simulator_test = simulator  # provide a mock simulator function
    num_sbc_runs = 1000
    num_posterior_samples = 1000

    # Generate SBC samples
    thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples(
        prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples
    )

    # Add assertions based on the expected behavior of the method

def test_run_all_sbc(diagnose_generative_instance,
                     posterior_generative_sbi_model,
                     setup_plot_dir):
    labels_list = ["$m$", "$b$"]
    colorlist = ["#9C92A3", "#0F5257"]

    prior, posterior = posterior_generative_sbi_model
    simulator_test = simulator  # provide a mock simulator function

    save_path = setup_plot_dir

    diagnose_generative_instance.run_all_sbc(
        prior,
        posterior,
        simulator_test,
        labels_list,
        colorlist,
        num_sbc_runs=1_000,
        num_posterior_samples=1_000,
        samples_per_inference=1_000,
        plot=False,
        save=True,
        path=save_path,
    )
    # Check if PDF files were saved
    assert os.path.exists(save_path), f"No 'plots' folder found at {save_path}"

    # List all files in the directory
    files_in_directory = os.listdir(save_path)

    # Check if at least one PDF file is present
    pdf_files = [file for file in files_in_directory if file.endswith(".pdf")]
    assert pdf_files, "No PDF files found in the 'plots' folder"

    # We expect the pdfs to exist in the directory
    expected_pdf_files = ["sbc_ranks.pdf", "sbc_ranks_cdf.pdf", "coverage.pdf"]
    for expected_file in expected_pdf_files:
        assert (
            expected_file in pdf_files
        ), f"Expected PDF file '{expected_file}' not found"


"""
def test_sbc_statistics(diagnose_instance):
# Mock data
ranks = # provide mock ranks
thetas = # provide mock thetas
dap_samples = # provide mock dap_samples
num_posterior_samples = 1000
# Calculate SBC statistics
check_stats = diagnose_instance.sbc_statistics(
ranks, thetas, dap_samples, num_posterior_samples
)
# Add assertions based on the expected behavior of the method
def test_plot_1d_ranks(diagnose_instance):
# Mock data
ranks = # provide mock ranks
num_posterior_samples = 1000
labels_list = # provide mock labels_list
colorlist = # provide mock colorlist
# Plot 1D ranks
diagnose_instance.plot_1d_ranks(
ranks, num_posterior_samples, labels_list,
colorlist, plot=False, save=False
)
"""