Skip to content

Commit

Permalink
add optional tf pytest marker, option test_dragonnet
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandrmgservices committed Dec 11, 2023
1 parent e30dd10 commit 71e4909
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
20 changes: 20 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,23 @@ def _generate_data():
return data

yield _generate_data

def pytest_addoption(parser):
parser.addoption(
"--runtf", action="store_true", default=False, help="run tf tests"
)


def pytest_configure(config):
config.addinivalue_line("markers", "tf: mark test as tf to run")


def pytest_collection_modifyitems(config, items):
if config.getoption("--runtf"):
# --runtf given in cli: do not skip tf tests
return
skip_tf = pytest.mark.skip(reason="need --runtf option to run")
for item in items:
if "tf" in item.keywords:
item.add_marker(skip_tf)

3 changes: 2 additions & 1 deletion tests/test_dragonnet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from causalml.inference.tf import DragonNet
from causalml.dataset.regression import simulate_nuisance_and_easy_treatment
import shutil
import pytest


@pytest.mark.tf
def test_save_load_dragonnet():
y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)

Expand Down

0 comments on commit 71e4909

Please sign in to comment.