diff --git a/causalml/inference/meta/rlearner.py b/causalml/inference/meta/rlearner.py
index 915efdbc..dacfa64e 100644
--- a/causalml/inference/meta/rlearner.py
+++ b/causalml/inference/meta/rlearner.py
@@ -555,15 +555,28 @@ def __init__(
             self.test_size = test_size
             self.early_stopping_rounds = early_stopping_rounds

-        super().__init__(
-            outcome_learner=XGBRegressor(random_state=random_state, *args, **kwargs),
-            effect_learner=XGBRegressor(
+            effect_learner = XGBRegressor(
+                objective=self.effect_learner_objective,
+                n_estimators=self.effect_learner_n_estimators,
+                eval_metric=self.effect_learner_eval_metric,
+                early_stopping_rounds=self.early_stopping_rounds,
+                random_state=random_state,
+                *args,
+                **kwargs,
+            )
+        else:
+            effect_learner = XGBRegressor(
                 objective=self.effect_learner_objective,
                 n_estimators=self.effect_learner_n_estimators,
+                eval_metric=self.effect_learner_eval_metric,
                 random_state=random_state,
                 *args,
                 **kwargs,
-            ),
+            )
+
+        super().__init__(
+            outcome_learner=XGBRegressor(random_state=random_state, *args, **kwargs),
+            effect_learner=effect_learner,
         )

     def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
@@ -665,8 +678,6 @@ def fit(self, X, treatment, y, p=None, sample_weight=None, verbose=True):
                 sample_weight_eval_set=[
                     sample_weight_test_filt * ((w_test - p_test_filt) ** 2)
                 ],
-                eval_metric=self.effect_learner_eval_metric,
-                early_stopping_rounds=self.early_stopping_rounds,
                 verbose=verbose,
             )

@@ -675,7 +686,6 @@
                 X_filt,
                 (y_filt - yhat_filt) / (w - p_filt),
                 sample_weight=sample_weight_filt * ((w - p_filt) ** 2),
-                eval_metric=self.effect_learner_eval_metric,
             )

         diff_c = y_filt[w == 0] - yhat_filt[w == 0]
diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py
index 7dd13980..8d6b4a61 100644
--- a/causalml/metrics/visualize.py
+++ b/causalml/metrics/visualize.py
@@ -640,7 +640,9 @@ def plot_qini(
 def plot_tmlegain(
     df,
     inference_col,
-    learner=LGBMRegressor(num_leaves=64, learning_rate=0.05, n_estimators=300),
+    learner=LGBMRegressor(
+        num_leaves=64, learning_rate=0.05, n_estimators=300, verbose=-1
+    ),
     outcome_col="y",
     treatment_col="w",
     p_col="tau",
diff --git a/causalml/propensity.py b/causalml/propensity.py
index 0d35179d..f3aee9e3 100644
--- a/causalml/propensity.py
+++ b/causalml/propensity.py
@@ -113,10 +113,11 @@ class GradientBoostedPropensityModel(PropensityModel):
     """

     def __init__(self, early_stop=False, clip_bounds=(1e-3, 1 - 1e-3), **model_kwargs):
+        self.early_stop = early_stop
+
         super(GradientBoostedPropensityModel, self).__init__(
             clip_bounds, **model_kwargs
         )
-        self.early_stop = early_stop

     @property
     def _model(self):
@@ -131,9 +132,12 @@ def _model(self):
         }
         kwargs.update(self.model_kwargs)

+        if self.early_stop:
+            kwargs.update({"early_stopping_rounds": 10})
+
         return xgb.XGBClassifier(**kwargs)

-    def fit(self, X, y, early_stopping_rounds=10, stop_val_size=0.2):
+    def fit(self, X, y, stop_val_size=0.2):
         """
         Fit a propensity model.

@@ -151,7 +155,6 @@ def fit(self, X, y, early_stopping_rounds=10, stop_val_size=0.2):
                 X_train,
                 y_train,
                 eval_set=[(X_val, y_val)],
-                early_stopping_rounds=early_stopping_rounds,
             )
         else:
             super(GradientBoostedPropensityModel, self).fit(X, y)
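Note on the `rlearner.py` hunks: XGBoost 1.6 moved `eval_metric` and `early_stopping_rounds` from `fit()` to the estimator constructor, and newer releases drop the `fit()` arguments entirely, which is why the diff builds `effect_learner` up front and passes only the eval set at fit time. A minimal sketch of the constructor-based convention being migrated to (synthetic data and parameter values are illustrative, not taken from the PR):

```python
import numpy as np
from sklearn.model_selection import train_test_split
from xgboost import XGBRegressor

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = X[:, 0] + rng.normal(size=500)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=0)

# XGBoost >= 1.6: early-stopping configuration lives on the estimator ...
model = XGBRegressor(
    n_estimators=500,
    eval_metric="rmse",
    early_stopping_rounds=30,
    random_state=0,
)
# ... while fit() only supplies the validation set to monitor.
model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
print(model.best_iteration)  # stops well before n_estimators on easy data
```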
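The `visualize.py` hunk only quiets the default TMLE learner: `verbose=-1` is LightGBM's constructor-level switch for suppressing info and warning output during training. A quick standalone check of the flag (toy data, not from the PR):

```python
import numpy as np
from lightgbm import LGBMRegressor

X = np.random.rand(200, 3)
y = np.random.rand(200)

# verbose=-1 silences LightGBM's per-fit info/warning messages
learner = LGBMRegressor(num_leaves=64, learning_rate=0.05, n_estimators=300, verbose=-1)
learner.fit(X, y)
```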
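The reordering in `propensity.py`'s `__init__` is load-bearing: `PropensityModel.__init__` assigns `self.model = self._model`, and `_model` now reads `self.early_stop`, so the attribute must exist before the `super().__init__()` call or construction raises `AttributeError`. A stripped-down illustration of the constraint (class names here are placeholders, not causalml's):

```python
class Base:
    def __init__(self, **model_kwargs):
        self.model_kwargs = model_kwargs
        self.model = self._model  # evaluates subclass state during construction


class EarlyStoppingModel(Base):
    def __init__(self, early_stop=False, **model_kwargs):
        self.early_stop = early_stop  # must be set BEFORE Base.__init__ runs
        super().__init__(**model_kwargs)

    @property
    def _model(self):
        kwargs = dict(self.model_kwargs)
        if self.early_stop:
            kwargs.update({"early_stopping_rounds": 10})
        return kwargs  # causalml builds a classifier from kwargs; a dict suffices here


print(EarlyStoppingModel(early_stop=True).model)
# {'early_stopping_rounds': 10}
```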