Skip to content

Commit

Permalink
updated notebook to match new package structure
Browse files Browse the repository at this point in the history
  • Loading branch information
voetberg committed Jun 25, 2024
1 parent 0505569 commit fbb8d4a
Show file tree
Hide file tree
Showing 4 changed files with 404 additions and 138 deletions.
521 changes: 394 additions & 127 deletions notebooks/example.ipynb

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions src/deepdiagnostics/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,19 @@ def parser():
# List of metrics (cannot supply specific kwargs)
parser.add_argument(
"--metrics",
nargs="?",
default=list(Defaults["metrics"].keys()),
nargs="+",
default=[],
choices=Metrics.keys(),
help="List of metrics to run. To not run any, supply `--metrics `"
help="List of metrics to run."
)

# List of plots
parser.add_argument(
"--plots",
nargs="?",
default=list(Defaults["plots"].keys()),
nargs="+",
default=[],
choices=Plots.keys(),
help="List of plots to run. To not run any, supply `--plots `"
help="List of plots to run."

)

Expand Down Expand Up @@ -109,7 +109,7 @@ def main():
plots = config.get_section("plots", raise_exception=False)

for metrics_name, metrics_args in metrics.items():
Metrics[metrics_name](model, data, **metrics_args)()
Metrics[metrics_name](model, data, save=True)(**metrics_args)

for plot_name, plot_args in plots.items():
Plots[plot_name](model, data, save=True, show=False, out_dir=out_dir)(
Expand Down
5 changes: 2 additions & 3 deletions src/deepdiagnostics/plots/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ def __init__(
"plots_common", "default_colorway", raise_exception=False
)

if save:
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)
self.out_dir = out_dir if out_dir is not None else get_item("common", "out_dir", raise_exception=False)

if self.out_dir is not None:
if self.out_dir is not None and self.save:
if not os.path.exists(os.path.dirname(self.out_dir)):
os.makedirs(os.path.dirname(self.out_dir))

Expand Down
2 changes: 1 addition & 1 deletion src/deepdiagnostics/utils/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"common": {
"out_dir": "./DeepDiagnosticsResources/results/",
"temp_config": "./DeepDiagnosticsResources/temp/temp_config.yml",
"sim_location": "deepdiagnosticsResources/simulators",
"sim_location": "./DeepDiagnosticsResources/simulators",
"random_seed": 42,
},
"model": {"model_engine": "SBIModel"},
Expand Down

0 comments on commit fbb8d4a

Please sign in to comment.