Skip to content

Commit

Permalink
added easier imports to do module
Browse files Browse the repository at this point in the history
  • Loading branch information
bradendubois committed Jan 18, 2022
1 parent f9e69f9 commit 8817652
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 17 deletions.
2 changes: 2 additions & 0 deletions do/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from .deconfounding.API import API as Deconfounding
from .identification.API import API as Identification

from .core.Expression import Expression

class API(Core, Deconfounding, Identification):

def __init__(self):
Expand Down
5 changes: 4 additions & 1 deletion do/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
__all__ = ["API"]
from .API import API

from .core.Expression import Expression
from .core.Variables import Intervention, Outcome, Variable
12 changes: 11 additions & 1 deletion do/core/API.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from pathlib import Path
from typing import Union

from .Expression import Expression
from .Inference import inference, validate
from .Model import Model
from .Model import Model, from_dict, from_path


class API:
Expand All @@ -10,3 +13,10 @@ def validate(self, model: Model) -> bool:

def probability(self, query: Expression, model: Model) -> float:
return inference(query, model)

def instantiate_model(self, model_target: Union[str, Path, dict]) -> Model:

if isinstance(model_target, dict):
return from_dict(model_target)

return from_path(Path(model_target) if isinstance(model_target, str) else model_target)
19 changes: 10 additions & 9 deletions do/core/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,22 @@ def all_variables(self) -> Collection[Variable]:
return self._v.values()


def from_json(path: str) -> Model:
with Path(path).open() as f:
data = json_load(f)
def from_dict(data: dict) -> Model:
return parse_model(data)


def from_yaml(path: str) -> Model:
with Path(path).open() as f:
data = yaml_load(f)
return parse_model(data)
def from_path(p: Path) -> Model:
if not p.exists() or not p.is_file():
raise FileNotFoundError

if p.suffix == ".json":
return parse_model(json_load(p.read_text()))

def from_dict(data: dict) -> Model:
return parse_model(data)
elif p.suffix in [".yml", ".yaml"]:
return parse_model(yaml_load(p.read_text()))

else:
raise Exception(f"Unknown extension for {p}")

def parse_model(data: dict) -> Model:

Expand Down
8 changes: 4 additions & 4 deletions tests/identification/test_Identification.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,19 @@
"""

def test_NoDeconfounding_Pearl34():
assert within_precision(api.probability(Expression(Outcome("Xj", "xj")), pearl34), api.identification({Outcome("Xj", "xj")}, [], pearl34))
assert within_precision(api.probability(Expression(Outcome("Xj", "xj")), pearl34), api.identification({Outcome("Xj", "xj")}, [], pearl34, False))


def test_NoDeconfounding_Melanoma():
assert within_precision(api.probability(Expression(Outcome("Y", "y")), melanoma), api.identification({Outcome("Y", "y")}, [], melanoma))
assert within_precision(api.probability(Expression(Outcome("Y", "y")), melanoma), api.identification({Outcome("Y", "y")}, [], melanoma, False))


def test_p34():
assert within_precision(api.identification({Outcome("Xj", "xj")}, {Intervention("Xi", "xi")}, pearl34), api.treat(Expression(Outcome("Xj", "xj")), [Intervention("Xi", "xi")], pearl34))
assert within_precision(api.identification({Outcome("Xj", "xj")}, {Intervention("Xi", "xi")}, pearl34, False), api.treat(Expression(Outcome("Xj", "xj")), [Intervention("Xi", "xi")], pearl34))


def test_melanoma():
assert within_precision(api.identification({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma), api.treat(Expression(Outcome("Y", "y")), [Intervention("X", "x")], melanoma))
assert within_precision(api.identification({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma, False), api.treat(Expression(Outcome("Y", "y")), [Intervention("X", "x")], melanoma))

def test_proof():
print(api.proof({Outcome("Y", "y")}, {Intervention("X", "x")}, melanoma))
Expand Down
4 changes: 2 additions & 2 deletions tests/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from do.API import API
from do.core.Inference import validate
from do.core.Model import from_yaml
from do.core.Model import from_path


api = API()
Expand All @@ -13,7 +13,7 @@

models = dict()
for file in model_path.iterdir():
models[file.name] = from_yaml(file.absolute())
models[file.name] = from_path(file)

# verify all the models as correct
#for name, model in models.items():
Expand Down

0 comments on commit 8817652

Please sign in to comment.