From 3c0b7a672f7a9fecde60911df1ab55a78f0a13ac Mon Sep 17 00:00:00 2001 From: Roland Stevenson Date: Tue, 28 Nov 2023 10:36:51 +0100 Subject: [PATCH] minimal fix to resolve #707 --- causalml/metrics/visualize.py | 12 ++++++++++++ tests/test_visualize.py | 14 ++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 tests/test_visualize.py diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py index e6b9b226..12534adf 100644 --- a/causalml/metrics/visualize.py +++ b/causalml/metrics/visualize.py @@ -84,6 +84,10 @@ def get_cumlift( or treatment_effect_col in df.columns ) + assert not ( + (df[[outcome_col, treatment_col, treatment_effect_col]].isnull().values.any()) + ) + df = df.copy() np.random.seed(random_seed) random_cols = [] @@ -219,6 +223,10 @@ def get_qini( or treatment_effect_col in df.columns ) + assert not ( + (df[[outcome_col, treatment_col, treatment_effect_col]].isnull().values.any()) + ) + df = df.copy() np.random.seed(random_seed) random_cols = [] @@ -315,6 +323,8 @@ def get_tmlegain( or p_col in df.columns ) + assert not ((df[[outcome_col, treatment_col, p_col]].isnull().values.any())) + inference_col = [x for x in inference_col if x in df.columns] # Initialize TMLE @@ -421,6 +431,8 @@ def get_tmleqini( or p_col in df.columns ) + assert not ((df[[outcome_col, treatment_col, p_col]].isnull().values.any())) + inference_col = [x for x in inference_col if x in df.columns] # Initialize TMLE diff --git a/tests/test_visualize.py b/tests/test_visualize.py new file mode 100644 index 00000000..f3bc0669 --- /dev/null +++ b/tests/test_visualize.py @@ -0,0 +1,14 @@ +import pandas as pd +import numpy as np +import pytest +from causalml.metrics.visualize import get_cumlift + + +def test_visualize_get_cumlift_errors_on_nan(): + df = pd.DataFrame( + [[0, np.nan, 0.5], [1, np.nan, 0.1], [1, 1, 0.4], [0, 1, 0.3], [1, 1, 0.2]], + columns=["w", "y", "pred"], + ) + + with pytest.raises(Exception): + get_cumlift(df)