From 8817652d5f7acb4bbe80c6f0f370102202f77fbd Mon Sep 17 00:00:00 2001 From: Braden Dubois Date: Mon, 17 Jan 2022 19:21:12 -0600 Subject: [PATCH] added easier imports to do module --- do/API.py | 2 ++ do/__init__.py | 5 ++++- do/core/API.py | 12 +++++++++++- do/core/Model.py | 19 ++++++++++--------- tests/identification/test_Identification.py | 8 ++++---- tests/source.py | 4 ++-- 6 files changed, 33 insertions(+), 17 deletions(-) diff --git a/do/API.py b/do/API.py index 8edca6f..9c580d5 100644 --- a/do/API.py +++ b/do/API.py @@ -2,6 +2,8 @@ from .deconfounding.API import API as Deconfounding from .identification.API import API as Identification +from .core.Expression import Expression + class API(Core, Deconfounding, Identification): def __init__(self): diff --git a/do/__init__.py b/do/__init__.py index fd881ab..6458fad 100644 --- a/do/__init__.py +++ b/do/__init__.py @@ -1 +1,4 @@ -__all__ = ["API"] +from .API import API + +from .core.Expression import Expression +from .core.Variables import Intervention, Outcome, Variable diff --git a/do/core/API.py b/do/core/API.py index b0e9cb0..04695f5 100644 --- a/do/core/API.py +++ b/do/core/API.py @@ -1,6 +1,9 @@ +from pathlib import Path +from typing import Union + from .Expression import Expression from .Inference import inference, validate -from .Model import Model +from .Model import Model, from_dict, from_path class API: @@ -10,3 +13,10 @@ def validate(self, model: Model) -> bool: def probability(self, query: Expression, model: Model) -> float: return inference(query, model) + + def instantiate_model(self, model_target: Union[str, Path, dict]) -> Model: + + if isinstance(model_target, dict): + return from_dict(model_target) + + return from_path(Path(model_target) if isinstance(model_target, str) else model_target) diff --git a/do/core/Model.py b/do/core/Model.py index 868d379..2b7fc1d 100755 --- a/do/core/Model.py +++ b/do/core/Model.py @@ -36,21 +36,22 @@ def all_variables(self) -> Collection[Variable]: return self._v.values() -def from_json(path: str) -> Model: - with Path(path).open() as f: - data = json_load(f) +def from_dict(data: dict) -> Model: return parse_model(data) -def from_yaml(path: str) -> Model: - with Path(path).open() as f: - data = yaml_load(f) - return parse_model(data) +def from_path(p: Path) -> Model: + if not p.exists() or not p.is_file(): + raise FileNotFoundError + if p.suffix == ".json": + return parse_model(json_load(p.read_text())) -def from_dict(data: dict) -> Model: - return parse_model(data) + elif p.suffix in [".yml", ".yaml"]: + return parse_model(yaml_load(p.read_text())) + else: + raise Exception(f"Unknown extension for {p}") def parse_model(data: dict) -> Model: diff --git a/tests/identification/test_Identification.py b/tests/identification/test_Identification.py index 98eafae..1e6720a 100644 --- a/tests/identification/test_Identification.py +++ b/tests/identification/test_Identification.py @@ -18,19 +18,19 @@ """ def test_NoDeconfounding_Pearl34(): - assert within_precision(api.probability(Expression(Outcome("Xj", "xj")), pearl34), api.identification({Outcome("Xj", "xj")}, [], pearl34)) + assert within_precision(api.probability(Expression(Outcome("Xj", "xj")), pearl34), api.identification({Outcome("Xj", "xj")}, [], pearl34, False)) def test_NoDeconfounding_Melanoma(): - assert within_precision(api.probability(Expression(Outcome("Y", "y")), melanoma), api.identification({Outcome("Y", "y")}, [], melanoma)) + assert within_precision(api.probability(Expression(Outcome("Y", "y")), melanoma), api.identification({Outcome("Y", "y")}, [], melanoma, False)) def test_p34(): - assert within_precision(api.identification({Outcome("Xj", "xj")}, {Intervention("Xi", "xi")}, pearl34), api.treat(Expression(Outcome("Xj", "xj")), [Intervention("Xi", "xi")], pearl34)) + assert within_precision(api.identification({Outcome("Xj", "xj")}, {Intervention("Xi", "xi")}, pearl34, False), api.treat(Expression(Outcome("Xj", "xj")), [Intervention("Xi", "xi")], pearl34)) def test_melanoma(): - assert within_precision(api.identification({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma), api.treat(Expression(Outcome("Y", "y")), [Intervention("X", "x")], melanoma)) + assert within_precision(api.identification({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma, False), api.treat(Expression(Outcome("Y", "y")), [Intervention("X", "x")], melanoma)) def test_proof(): print(api.proof({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma)) diff --git a/tests/source.py b/tests/source.py index 10b6078..eef2ba7 100644 --- a/tests/source.py +++ b/tests/source.py @@ -2,7 +2,7 @@ from do.API import API from do.core.Inference import validate -from do.core.Model import from_yaml +from do.core.Model import from_path api = API() @@ -13,7 +13,7 @@ models = dict() for file in model_path.iterdir(): - models[file.name] = from_yaml(file.absolute()) + models[file.name] = from_path(file) # verify all the models as correct #for name, model in models.items():