From ae6c28ed42a826f3d7a3e310855d59175fa1cc22 Mon Sep 17 00:00:00 2001 From: ras44 <9282281+ras44@users.noreply.github.com> Date: Fri, 1 Dec 2023 20:39:58 +0100 Subject: [PATCH] minimal fix to resolve #707 (#720) * minimal fix to resolve #707 * lint * remove .bool() * operate on series not df --------- Co-authored-by: ras44 --- causalml/metrics/visualize.py | 31 ++++++++++++++++++------------- tests/test_visualize.py | 14 ++++++++++++++ 2 files changed, 32 insertions(+), 13 deletions(-) create mode 100644 tests/test_visualize.py diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py index e6b9b226..011fc521 100644 --- a/causalml/metrics/visualize.py +++ b/causalml/metrics/visualize.py @@ -77,11 +77,13 @@ def get_cumlift( Returns: (pandas.DataFrame): average uplifts of model estimates in cumulative population """ - assert ( - (outcome_col in df.columns) - and (treatment_col in df.columns) - or treatment_effect_col in df.columns + (outcome_col in df.columns and df[outcome_col].notnull().all()) + and (treatment_col in df.columns and df[treatment_col].notnull().all()) + or ( + treatment_effect_col in df.columns + and df[treatment_effect_col].notnull().all() + ) ) df = df.copy() @@ -214,9 +216,12 @@ def get_qini( (pandas.DataFrame): cumulative gains of model estimates in population """ assert ( - (outcome_col in df.columns) - and (treatment_col in df.columns) - or treatment_effect_col in df.columns + (outcome_col in df.columns and df[outcome_col].notnull().all()) + and (treatment_col in df.columns and df[treatment_col].notnull().all()) + or ( + treatment_effect_col in df.columns + and df[treatment_effect_col].notnull().all() + ) ) df = df.copy() @@ -310,9 +315,9 @@ def get_tmlegain( (pandas.DataFrame): cumulative gains of model estimates based of TMLE """ assert ( - (outcome_col in df.columns) - and (treatment_col in df.columns) - or p_col in df.columns + (outcome_col in df.columns and df[outcome_col].notnull().all()) + and (treatment_col in df.columns and df[treatment_col].notnull().all()) + or (p_col in df.columns and df[p_col].notnull().all()) ) inference_col = [x for x in inference_col if x in df.columns] @@ -416,9 +421,9 @@ def get_tmleqini( (pandas.DataFrame): cumulative gains of model estimates based of TMLE """ assert ( - (outcome_col in df.columns) - and (treatment_col in df.columns) - or p_col in df.columns + (outcome_col in df.columns and df[outcome_col].notnull().all()) + and (treatment_col in df.columns and df[treatment_col].notnull().all()) + or (p_col in df.columns and df[p_col].notnull().all()) ) inference_col = [x for x in inference_col if x in df.columns] diff --git a/tests/test_visualize.py b/tests/test_visualize.py new file mode 100644 index 00000000..f3bc0669 --- /dev/null +++ b/tests/test_visualize.py @@ -0,0 +1,14 @@ +import pandas as pd +import numpy as np +import pytest +from causalml.metrics.visualize import get_cumlift + + +def test_visualize_get_cumlift_errors_on_nan(): + df = pd.DataFrame( + [[0, np.nan, 0.5], [1, np.nan, 0.1], [1, 1, 0.4], [0, 1, 0.3], [1, 1, 0.2]], + columns=["w", "y", "pred"], + ) + + with pytest.raises(Exception): + get_cumlift(df)