Skip to content

Commit

Permalink
Update input arguments of XGBoost to be compatible with the latest AP…
Browse files Browse the repository at this point in the history
…Is (#788)
  • Loading branch information
jeongyoonlee authored Aug 1, 2024
1 parent 084a6d0 commit 129b5d9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 11 deletions.
24 changes: 17 additions & 7 deletions causalml/inference/meta/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,15 +555,28 @@ def __init__(
self.test_size = test_size
self.early_stopping_rounds = early_stopping_rounds

super().__init__(
outcome_learner=XGBRegressor(random_state=random_state, *args, **kwargs),
effect_learner=XGBRegressor(
effect_learner = XGBRegressor(
objective=self.effect_learner_objective,
n_estimators=self.effect_learner_n_estimators,
eval_metric=self.effect_learner_eval_metric,
early_stopping_rounds=self.early_stopping_rounds,
random_state=random_state,
*args,
**kwargs,
)
else:
effect_learner = XGBRegressor(
objective=self.effect_learner_objective,
n_estimators=self.effect_learner_n_estimators,
eval_metric=self.effect_learner_eval_metric,
random_state=random_state,
*args,
**kwargs,
),
)

super().__init__(
outcome_learner=XGBRegressor(random_state=random_state, *args, **kwargs),
effect_learner=effect_learner,
)

def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
Expand Down Expand Up @@ -665,8 +678,6 @@ def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
sample_weight_eval_set=[
sample_weight_test_filt * ((w_test - p_test_filt) ** 2)
],
eval_metric=self.effect_learner_eval_metric,
early_stopping_rounds=self.early_stopping_rounds,
verbose=verbose,
)

Expand All @@ -675,7 +686,6 @@ def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
X_filt,
(y_filt - yhat_filt) / (w - p_filt),
sample_weight=sample_weight_filt * ((w - p_filt) ** 2),
eval_metric=self.effect_learner_eval_metric,
)

diff_c = y_filt[w == 0] - yhat_filt[w == 0]
Expand Down
4 changes: 3 additions & 1 deletion causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,9 @@ def plot_qini(
def plot_tmlegain(
df,
inference_col,
learner=LGBMRegressor(num_leaves=64, learning_rate=0.05, n_estimators=300),
learner=LGBMRegressor(
num_leaves=64, learning_rate=0.05, n_estimators=300, verbose=-1
),
outcome_col="y",
treatment_col="w",
p_col="tau",
Expand Down
9 changes: 6 additions & 3 deletions causalml/propensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,11 @@ class GradientBoostedPropensityModel(PropensityModel):
"""

def __init__(self, early_stop=False, clip_bounds=(1e-3, 1 - 1e-3), **model_kwargs):
self.early_stop = early_stop

super(GradientBoostedPropensityModel, self).__init__(
clip_bounds, **model_kwargs
)
self.early_stop = early_stop

@property
def _model(self):
Expand All @@ -131,9 +132,12 @@ def _model(self):
}
kwargs.update(self.model_kwargs)

if self.early_stop:
kwargs.update({"early_stopping_rounds": 10})

return xgb.XGBClassifier(**kwargs)

def fit(self, X, y, early_stopping_rounds=10, stop_val_size=0.2):
def fit(self, X, y, stop_val_size=0.2):
"""
Fit a propensity model.
Expand All @@ -151,7 +155,6 @@ def fit(self, X, y, early_stopping_rounds=10, stop_val_size=0.2):
X_train,
y_train,
eval_set=[(X_val, y_val)],
early_stopping_rounds=early_stopping_rounds,
)
else:
super(GradientBoostedPropensityModel, self).fit(X, y)
Expand Down

0 comments on commit 129b5d9

Please sign in to comment.