Skip to content

Commit

Permalink
remove .bool()
Browse files Browse the repository at this point in the history
  • Loading branch information
rolandrmgservices committed Dec 1, 2023
1 parent 1354106 commit 53adee1
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ def get_cumlift(
(pandas.DataFrame): average uplifts of model estimates in cumulative population
"""
assert (
(outcome_col in df.columns and df[[outcome_col]].notnull().all().bool())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all().bool())
(outcome_col in df.columns and df[[outcome_col]].notnull().all())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all())
or (
treatment_effect_col in df.columns
and df[[treatment_effect_col]].notnull().all().bool()
and df[[treatment_effect_col]].notnull().all()
)
)

Expand Down Expand Up @@ -216,11 +216,11 @@ def get_qini(
(pandas.DataFrame): cumulative gains of model estimates in population
"""
assert (
(outcome_col in df.columns and df[[outcome_col]].notnull().all().bool())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all().bool())
(outcome_col in df.columns and df[[outcome_col]].notnull().all())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all())
or (
treatment_effect_col in df.columns
and df[[treatment_effect_col]].notnull().all().bool()
and df[[treatment_effect_col]].notnull().all()
)
)

Expand Down Expand Up @@ -315,9 +315,9 @@ def get_tmlegain(
(pandas.DataFrame): cumulative gains of model estimates based of TMLE
"""
assert (
(outcome_col in df.columns and df[[outcome_col]].notnull().all().bool())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all().bool())
or (p_col in df.columns and df[[p_col]].notnull().all().bool())
(outcome_col in df.columns and df[[outcome_col]].notnull().all())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all())
or (p_col in df.columns and df[[p_col]].notnull().all())
)

inference_col = [x for x in inference_col if x in df.columns]
Expand Down Expand Up @@ -421,9 +421,9 @@ def get_tmleqini(
(pandas.DataFrame): cumulative gains of model estimates based of TMLE
"""
assert (
(outcome_col in df.columns and df[[outcome_col]].notnull().all().bool())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all().bool())
or (p_col in df.columns and df[[p_col]].notnull().all().bool())
(outcome_col in df.columns and df[[outcome_col]].notnull().all())
and (treatment_col in df.columns and df[[treatment_col]].notnull().all())
or (p_col in df.columns and df[[p_col]].notnull().all())
)

inference_col = [x for x in inference_col if x in df.columns]
Expand Down

0 comments on commit 53adee1

Please sign in to comment.