diff --git a/causalml/metrics/visualize.py b/causalml/metrics/visualize.py index 011fc521..39fe4bf7 100644 --- a/causalml/metrics/visualize.py +++ b/causalml/metrics/visualize.py @@ -848,7 +848,7 @@ def qini_score( return (qini.sum(axis=0) - qini[RANDOM_COL].sum()) / qini.shape[0] -def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p"): +def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p", bal_tol=0.1): """Plot covariate balances (standardized differences between the treatment and the control) before and after weighting the sample using the inverse probability of treatment weights. @@ -865,40 +865,42 @@ def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p"): IPTW = get_simple_iptw(W, PS) diffs_pre = get_std_diffs(X, W, weighted=False) - num_unbal_pre = (np.abs(diffs_pre) > 0.1).sum()[0] + num_unbal_pre = (np.abs(diffs_pre) > bal_tol).sum()[0] diffs_post = get_std_diffs(X, W, IPTW, weighted=True) - num_unbal_post = (np.abs(diffs_post) > 0.1).sum()[0] + num_unbal_post = (np.abs(diffs_post) > bal_tol).sum()[0] - diff_plot = _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post) + diff_plot = _plot_std_diffs( + diffs_pre, num_unbal_pre, diffs_post, num_unbal_post, bal_tol=bal_tol + ) return diff_plot -def _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post): - fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10), sharex=True, sharey=True) +def _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post, bal_tol=0.1): + fig, ax1 = plt.subplots() color = "#EA2566" - sns.stripplot(diffs_pre.iloc[:, 0], diffs_pre.index, ax=ax1) - ax1.set_xlabel( - "Before. Number of unbalanced covariates: {num_unbal}".format( - num_unbal=num_unbal_pre - ), - fontsize=14, + sds_pre = pd.DataFrame( + {"std_diff": diffs_pre[0], "covariate": diffs_pre.index, "prepost": "pre"} ) - ax1.axvline(x=-0.1, ymin=0, ymax=1, color=color, linestyle="--") - ax1.axvline(x=0.1, ymin=0, ymax=1, color=color, linestyle="--") + sds_post = pd.DataFrame( + {"std_diff": diffs_post[0], "covariate": diffs_post.index, "prepost": "post"} + ) + + sds = pd.concat([sds_pre, sds_post], ignore_index=True) - sns.stripplot(diffs_post.iloc[:, 0], diffs_post.index, ax=ax2) - ax2.set_xlabel( - "After. Number of unbalanced covariates: {num_unbal}".format( - num_unbal=num_unbal_post + sns.stripplot(data=sds, x="std_diff", y="covariate", hue="prepost", ax=ax1) + + ax1.set_xlabel( + "Pre/Post Number of unbalanced covariates: {num_unbal_pre}/{num_unbal_post}".format( + num_unbal_pre=num_unbal_pre, num_unbal_post=num_unbal_post ), fontsize=14, ) - ax2.axvline(x=-0.1, ymin=0, ymax=1, color=color, linestyle="--") - ax2.axvline(x=0.1, ymin=0, ymax=1, color=color, linestyle="--") + ax1.axvline(x=-bal_tol, ymin=0, ymax=1, color=color, linestyle="--", lw=2) + ax1.axvline(x=bal_tol, ymin=0, ymax=1, color=color, linestyle="--", lw=2) fig.suptitle("Standardized differences in means", fontsize=16)