Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Example nb n bugfix #80

Merged
merged 2 commits into from
Jun 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
681 changes: 558 additions & 123 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ def parser():
# List of metrics (cannot supply specific kwargs)
parser.add_argument(
"--metrics",
nargs="?",
default=list(Defaults["metrics"].keys()),
nargs="+",
default=[],
choices=Metrics.keys(),
help="List of metrics to run. To not run any, supply `--metrics `"
help="List of metrics to run."
)

# List of plots
parser.add_argument(
"--plots",
nargs="?",
default=list(Defaults["plots"].keys()),
nargs="+",
default=[],
choices=Plots.keys(),
help="List of plots to run. To not run any, supply `--plots `"
help="List of plots to run."

)

Expand Down Expand Up @@ -109,7 +109,7 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()
Metrics[metrics_name](model, data, save=True)(**metrics_args)

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
Expand Down
8 changes: 7 additions & 1 deletion src/deepdiagnostics/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,14 @@
from deepdiagnostics.metrics.coverage_fraction import CoverageFraction
from deepdiagnostics.metrics.local_two_sample import LocalTwoSampleTest as LC2ST

def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Metrics = {
"": lambda **kwargs: None,
"": void,
CoverageFraction.__name__: CoverageFraction,
AllSBC.__name__: AllSBC,
"LC2ST": LC2ST
Expand Down
8 changes: 7 additions & 1 deletion src/deepdiagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,14 @@
from deepdiagnostics.plots.parity import Parity
from deepdiagnostics.plots.predictive_prior_check import PriorPC

def void(*args, **kwargs):
def void2(*args, **kwargs):
return None
return void2


Plots = {
"": lambda **kwargs: None,
"": void,
CDFRanks.__name__: CDFRanks,
CoverageFraction.__name__: CoverageFraction,
Ranks.__name__: Ranks,
Expand Down
5 changes: 2 additions & 3 deletions src/deepdiagnostics/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ def __init__(
"plots_common", "default_colorway", raise_exception=False
)

if save:
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)

if self.out_dir is not None:
if self.out_dir is not None and self.save:
if not os.path.exists(os.path.dirname(self.out_dir)):
os.makedirs(os.path.dirname(self.out_dir))

Expand Down
2 changes: 1 addition & 1 deletion src/deepdiagnostics/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"common": {
"out_dir": "./DeepDiagnosticsResources/results/",
"temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml",
"sim_location": "deepdiagnosticsResources/simulators",
"sim_location": "./DeepDiagnosticsResources/simulators",
"random_seed": 42,
},
"model": {"model_engine": "SBIModel"},
Expand Down
Loading