Fixing paths for testing
voetberg committed Jun 24, 2024
1 parent 7b0efb9 commit a1505fe
Showing 5 changed files with 47 additions and 55 deletions.

1 change: 1 addition & 0 deletions src/deepdiagnostics/metrics/__init__.py

@@ -3,6 +3,7 @@
 from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST
 
 Metrics = {
+    "": lambda **kwargs: None,
     CoverageFraction.__name__: CoverageFraction,
     AllSBC.__name__: AllSBC,
     "LC2ST": LC2ST
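
The empty-string key registers an explicit no-op, so a run configured with no metrics still resolves every requested name to a callable. A minimal sketch of the lookup this enables (illustrative only, not the package's actual runner):

    from deepdiagnostics.metrics import Metrics

    # "" resolves to the no-op lambda, so "run no metrics" needs no
    # special-casing by the caller; the lambda accepts and ignores any kwargs.
    noop = Metrics[""]
    assert noop(model=None, data=None) is None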

30 changes: 15 additions & 15 deletions src/deepdiagnostics/metrics/coverage_fraction.py

@@ -1,5 +1,4 @@
 import numpy as np
-from torch import tensor
 from tqdm import tqdm
 from typing import Any, Sequence
 
@@ -42,7 +41,7 @@ def _collect_data_params(self):
 
     def _run_model_inference(self, samples_per_inference, y_inference):
        samples = self.model.sample_posterior(samples_per_inference, y_inference)
-        return samples
+        return samples.numpy()
 
     def calculate(self) -> tuple[Sequence, Sequence]:
         """
@@ -52,19 +51,21 @@ def calculate(self) -> tuple[Sequence, Sequence]:
             tuple[Sequence, Sequence]: A tuple of the samples tested (M samples, Samples per inference, N parameters) and the coverage over those samples.
         """
         all_samples = np.empty(
-            (len(self.context), self.samples_per_inference, np.shape(self.thetas)[1])
+            (self.number_simulations, self.samples_per_inference, np.shape(self.thetas)[1])
         )
         count_array = []
-        iterator = enumerate(self.context)
+        iterator = range(self.number_simulations)
         if self.use_progress_bar:
             iterator = tqdm(
                 iterator,
                 desc="Sampling from the posterior for each observation",
                 unit=" observation",
             )
-        for y_sample_index, y_sample in iterator:
-            samples = self._run_model_inference(self.samples_per_inference, y_sample)
-            all_samples[y_sample_index] = samples
+        for sample_index in iterator:
+            context_sample = self.context[self.data.rng.integers(0, len(self.context))]
+            samples = self._run_model_inference(self.samples_per_inference, context_sample)
+
+            all_samples[sample_index] = samples
 
             count_vector = []
             # step through the percentile list
@@ -75,23 +76,22 @@
                 # find the percentile for the posterior for this observation
                 # this is n_params dimensional
                 # the units are in parameter space
-                confidence_lower = tensor(
-                    np.percentile(samples.cpu(), percentile_lower, axis=0)
-                )
-                confidence_upper = tensor(
-                    np.percentile(samples.cpu(), percentile_upper, axis=0)
-                )
+                confidence_lower = np.percentile(samples, percentile_lower, axis=0)
+
+                confidence_upper = np.percentile(samples, percentile_upper, axis=0)
+
 
                 # this is asking if the true parameter value
                 # is contained between the
                 # upper and lower confidence intervals
                 # checks separately for each side of the 50th percentile
 
                 count = np.logical_and(
-                    confidence_upper - self.thetas[y_sample_index, :] > 0,
-                    self.thetas[y_sample_index, :] - confidence_lower > 0,
+                    confidence_upper - self.thetas[sample_index, :].numpy() > 0,
+                    self.thetas[sample_index, :].numpy() - confidence_lower > 0,
                 )
                 count_vector.append(count)
 
             # each time the above is > 0, adds a count
             count_array.append(count_vector)
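
calculate() now runs a fixed number of simulations, pairing each draw with a randomly selected context rather than walking the contexts in order, and it keeps the percentile math in NumPy instead of round-tripping through torch tensors. A standalone sketch of the core containment check on synthetic data, assuming (as the comments suggest) symmetric percentile bounds around the median:

    import numpy as np

    rng = np.random.default_rng(0)
    samples = rng.normal(size=(10, 2))  # (samples_per_inference, n_parameters)
    theta_true = np.zeros(2)            # true parameters for this simulation
    percentile = 95
    lower = np.percentile(samples, 50 - percentile / 2, axis=0)  # 2.5th
    upper = np.percentile(samples, 50 + percentile / 2, axis=0)  # 97.5th

    # True per parameter when the truth falls inside the central interval
    covered = np.logical_and(upper - theta_true > 0, theta_true - lower > 0)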

22 changes: 12 additions & 10 deletions tests/conftest.py

@@ -5,11 +5,11 @@
 import numpy as np
 from deepbench.astro_object import StarObject
 
-from data import H5Data
-from data.simulator import Simulator
-from models import SBIModel
-from utils.config import get_item
-from utils.register import register_simulator
+from deepdiagnostics.data import H5Data
+from deepdiagnostics.data.simulator import Simulator
+from deepdiagnostics.models import SBIModel
+from deepdiagnostics.utils.config import get_item
+from deepdiagnostics.utils.register import register_simulator
 
 
 class MockSimulator(Simulator):
@@ -67,7 +67,7 @@ def simulate(self, theta, context_samples: np.ndarray):
         return np.array(generated_stars)
 
 @pytest.fixture(autouse=True)
-def setUp():
+def setUp(result_output):
     register_simulator("MockSimulator", MockSimulator)
     register_simulator("Mock2DSimulator", Mock2DSimulator)
     yield
@@ -76,10 +76,9 @@ def setUp():
     sim_paths = f"{simulator_config_path.strip('/')}/simulators.json"
     os.remove(sim_paths)
 
-    out_dir = get_item("common", "out_dir", raise_exception=False)
     os.makedirs("resources/test_results/", exist_ok=True)
-    shutil.copytree(out_dir, "resources/test_results/", dirs_exist_ok=True)
-    shutil.rmtree(out_dir)
+    shutil.copytree(result_output, "resources/test_results/", dirs_exist_ok=True)
+    shutil.rmtree(result_output)
 
 @pytest.fixture
 def model_path():
@@ -92,7 +91,10 @@ def data_path():
 
 @pytest.fixture
 def result_output():
-    return "./temp_results/"
+    path = "./temp_results/"
+    if not os.path.exists(path):
+        os.makedirs(path)
+    return path
 
 @pytest.fixture
 def simulator_name():
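
Because setUp now requests result_output as an argument, pytest builds that fixture first, and teardown archives the same directory the tests wrote into instead of re-deriving a path from the config. A self-contained illustration of that ordering guarantee (the fixture bodies here are stand-ins, not the real conftest):

    import pytest

    @pytest.fixture
    def result_output(tmp_path):
        return tmp_path / "temp_results"  # stand-in for "./temp_results/"

    @pytest.fixture(autouse=True)
    def setUp(result_output):
        # Requesting result_output forces it to be created before this
        # fixture runs, and keeps the same path visible at teardown.
        yield
        assert result_output.name == "temp_results"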

14 changes: 8 additions & 6 deletions tests/test_client.py

@@ -11,6 +11,10 @@ def test_parser_args(model_path, data_path, simulator_name):
         data_path,
         "--simulator",
         simulator_name,
+        "--metrics",
+        "",
+        "--plots",
+        ""
     ]
     process = subprocess.run(command)
     exit_code = process.returncode
@@ -20,31 +24,29 @@
 
 def test_parser_config(config_factory, model_path, data_path, simulator_name):
     config_path = config_factory(
-        model_path=model_path, data_path=data_path, simulator=simulator_name
+        model_path=model_path, data_path=data_path, simulator=simulator_name, metrics=[], plots=[]
     )
     command = ["diagnose", "--config", config_path]
     process = subprocess.run(command)
     exit_code = process.returncode
     assert exit_code == 0
 
 
-def test_main_no_methods(config_factory, model_path, data_path, simulator_name):
-    out_dir = "./test_out_dir/"
+def test_main_no_methods(config_factory, model_path, data_path, simulator_name, result_output):
     config_path = config_factory(
         model_path=model_path,
         data_path=data_path,
         simulator=simulator_name,
         plots=[],
-        metrics=[],
-        out_dir=out_dir,
+        metrics=[]
     )
     command = ["diagnose", "--config", config_path]
     process = subprocess.run(command)
    exit_code = process.returncode
     assert exit_code == 0
 
     # There should be nothing at the outpath
-    assert os.listdir(out_dir) == []
+    assert os.listdir(result_output) == []
 
 
 def test_main_missing_config():
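
test_main_no_methods now asserts against the shared result_output fixture, which also explains why that fixture must create its directory up front: os.listdir raises FileNotFoundError on a missing path, so an empty-directory assertion only means "the run wrote nothing" when the directory is guaranteed to exist beforehand. A small demonstration:

    import os
    import tempfile

    with tempfile.TemporaryDirectory() as base:
        target = os.path.join(base, "temp_results")
        os.makedirs(target, exist_ok=True)  # mirrors the fixture's makedirs
        assert os.listdir(target) == []     # empty, rather than FileNotFoundError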

35 changes: 11 additions & 24 deletions tests/test_metrics.py

@@ -1,13 +1,11 @@
 import os
 import pytest
 
-from utils.defaults import Defaults
-from utils.config import Config
-from metrics import (
-    Metrics,
+from deepdiagnostics.utils.config import Config
+from deepdiagnostics.metrics import (
     CoverageFraction,
     AllSBC,
-    LocalTwoSampleTest
+    LC2ST
 )
 
 @pytest.fixture
@@ -17,41 +15,30 @@ def metric_config(config_factory):
         "samples_per_inference": 10,
         "percentiles": [95],
     }
-    config = config_factory(metrics_settings=metrics_settings)
-    Config(config)
+    return config_factory(metrics_settings=metrics_settings)
 
 
-def test_all_defaults(metric_config, mock_model, mock_data):
-    """
-    Ensures each metric has a default set of parameters and is included in the defaults list
-    Ensures each test can initialize, regardless of the veracity of the output
-    """
-
-    for metric_name, metric_obj in Metrics.items():
-        assert metric_name in Defaults["metrics"]
-        metric_obj(mock_model, mock_data)
-
-
 def test_coverage_fraction(metric_config, mock_model, mock_data):
+    Config(metric_config)
     coverage_fraction = CoverageFraction(mock_model, mock_data, save=True)
     _, coverage = coverage_fraction.calculate()
     assert coverage_fraction.output.all() is not None
 
-    # TODO Shape of coverage
-    assert coverage.shape
+    assert coverage.shape == (1, 2) # One percentile over 2 dimensions of theta.
 
     coverage_fraction = CoverageFraction(mock_model, mock_data, save=True)
     coverage_fraction()
     assert os.path.exists(f"{coverage_fraction.out_dir}/diagnostic_metrics.json")
 
-
-def test_all_sbc(metric_config, mock_model, mock_data):
+def test_all_sbc(metric_config, mock_model, mock_data):
+    Config(metric_config)
     all_sbc = AllSBC(mock_model, mock_data, save=True)
     all_sbc()
     assert all_sbc.output is not None
     assert os.path.exists(f"{all_sbc.out_dir}/diagnostic_metrics.json")
 
-def test_lc2st(metric_config, mock_model, mock_data):
-    lc2st = LocalTwoSampleTest(mock_model, mock_data, save=True)
+def test_lc2st(metric_config, mock_model, mock_data):
+    Config(metric_config)
+    lc2st = LC2ST(mock_model, mock_data, save=True)
     lc2st()
     assert lc2st.output is not None
     assert os.path.exists(f"{lc2st.out_dir}/diagnostic_metrics.json")
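
The new shape assertion encodes the bookkeeping behind CoverageFraction: coverage is reported per percentile and per theta dimension, so percentiles=[95] over a two-parameter theta yields (1, 2). A toy reconstruction of that logic on synthetic arrays (not the class's actual code):

    import numpy as np

    rng = np.random.default_rng(0)
    percentiles = [95]
    # (n_simulations, samples_per_inference, n_parameters)
    samples = rng.normal(size=(100, 10, 2))
    thetas = rng.normal(size=(100, 2))

    coverage = np.empty((len(percentiles), 2))
    for i, p in enumerate(percentiles):
        lo = np.percentile(samples, 50 - p / 2, axis=1)  # (100, 2)
        hi = np.percentile(samples, 50 + p / 2, axis=1)
        coverage[i] = ((thetas > lo) & (thetas < hi)).mean(axis=0)

    assert coverage.shape == (1, 2)  # rows: percentiles, columns: theta dims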
