From 71e49097ea3e4315585636e7e32168b586e31062 Mon Sep 17 00:00:00 2001 From: Roland Stevenson Date: Mon, 11 Dec 2023 13:48:18 +0100 Subject: [PATCH] add optional tf pytest marker, option test_dragonnet --- tests/conftest.py | 20 ++++++++++++++++++++ tests/test_dragonnet.py | 3 ++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index d9532fac..fca2f627 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,3 +66,23 @@ def _generate_data(): return data yield _generate_data + +def pytest_addoption(parser): + parser.addoption( + "--runtf", action="store_true", default=False, help="run tf tests" + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "tf: mark test as tf to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runtf"): + # --runtf given in cli: do not skip tf tests + return + skip_tf = pytest.mark.skip(reason="need --runtf option to run") + for item in items: + if "tf" in item.keywords: + item.add_marker(skip_tf) + diff --git a/tests/test_dragonnet.py b/tests/test_dragonnet.py index 1d4e613d..98210231 100644 --- a/tests/test_dragonnet.py +++ b/tests/test_dragonnet.py @@ -1,8 +1,9 @@ from causalml.inference.tf import DragonNet from causalml.dataset.regression import simulate_nuisance_and_easy_treatment import shutil +import pytest - +@pytest.mark.tf def test_save_load_dragonnet(): y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000)