Skip to content

Commit

Permalink
fix plot_std_diffs, add bal_tol, condense to one plot (#723)
Browse files Browse the repository at this point in the history
Co-authored-by: Roland Stevenson <[email protected]>
  • Loading branch information
ras44 and rolandrmgservices authored Dec 8, 2023
1 parent aa1bc90 commit f200829
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions causalml/metrics/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def qini_score(
return (qini.sum(axis=0) - qini[RANDOM_COL].sum()) / qini.shape[0]


def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p"):
def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p", bal_tol=0.1):
"""Plot covariate balances (standardized differences between the treatment and the control)
before and after weighting the sample using the inverse probability of treatment weights.
Expand All @@ -865,40 +865,42 @@ def plot_ps_diagnostics(df, covariate_col, treatment_col="w", p_col="p"):
IPTW = get_simple_iptw(W, PS)

diffs_pre = get_std_diffs(X, W, weighted=False)
num_unbal_pre = (np.abs(diffs_pre) > 0.1).sum()[0]
num_unbal_pre = (np.abs(diffs_pre) > bal_tol).sum()[0]

diffs_post = get_std_diffs(X, W, IPTW, weighted=True)
num_unbal_post = (np.abs(diffs_post) > 0.1).sum()[0]
num_unbal_post = (np.abs(diffs_post) > bal_tol).sum()[0]

diff_plot = _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post)
diff_plot = _plot_std_diffs(
diffs_pre, num_unbal_pre, diffs_post, num_unbal_post, bal_tol=bal_tol
)

return diff_plot


def _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10), sharex=True, sharey=True)
def _plot_std_diffs(diffs_pre, num_unbal_pre, diffs_post, num_unbal_post, bal_tol=0.1):
fig, ax1 = plt.subplots()

color = "#EA2566"

sns.stripplot(diffs_pre.iloc[:, 0], diffs_pre.index, ax=ax1)
ax1.set_xlabel(
"Before. Number of unbalanced covariates: {num_unbal}".format(
num_unbal=num_unbal_pre
),
fontsize=14,
sds_pre = pd.DataFrame(
{"std_diff": diffs_pre[0], "covariate": diffs_pre.index, "prepost": "pre"}
)
ax1.axvline(x=-0.1, ymin=0, ymax=1, color=color, linestyle="--")
ax1.axvline(x=0.1, ymin=0, ymax=1, color=color, linestyle="--")
sds_post = pd.DataFrame(
{"std_diff": diffs_post[0], "covariate": diffs_post.index, "prepost": "post"}
)

sds = pd.concat([sds_pre, sds_post], ignore_index=True)

sns.stripplot(diffs_post.iloc[:, 0], diffs_post.index, ax=ax2)
ax2.set_xlabel(
"After. Number of unbalanced covariates: {num_unbal}".format(
num_unbal=num_unbal_post
sns.stripplot(data=sds, x="std_diff", y="covariate", hue="prepost", ax=ax1)

ax1.set_xlabel(
"Pre/Post Number of unbalanced covariates: {num_unbal_pre}/{num_unbal_post}".format(
num_unbal_pre=num_unbal_pre, num_unbal_post=num_unbal_post
),
fontsize=14,
)
ax2.axvline(x=-0.1, ymin=0, ymax=1, color=color, linestyle="--")
ax2.axvline(x=0.1, ymin=0, ymax=1, color=color, linestyle="--")
ax1.axvline(x=-bal_tol, ymin=0, ymax=1, color=color, linestyle="--", lw=2)
ax1.axvline(x=bal_tol, ymin=0, ymax=1, color=color, linestyle="--", lw=2)

fig.suptitle("Standardized differences in means", fontsize=16)

Expand Down

1 comment on commit f200829

@ImanEmtiazi728
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I want to install causalml package in jupyter notebook, I receive the below error;
Why?

ERROR: Exception:
Traceback (most recent call last):
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\cli\base_command.py", line 173, in _main
status = self.run(options, args)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\cli\req_command.py", line 203, in wrapper
return func(self, options, args)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\commands\install.py", line 315, in run
requirement_set = resolver.resolve(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\resolution\resolvelib\resolver.py", line 94, in resolve
result = self._result = resolver.resolve(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\resolvelib\resolvers.py", line 472, in resolve
state = resolution.resolve(requirements, max_rounds=max_rounds)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\resolvelib\resolvers.py", line 341, in resolve
self._add_to_criteria(self.state.criteria, r, parent=None)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\resolvelib\resolvers.py", line 172, in _add_to_criteria
if not criterion.candidates:
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\resolvelib\structs.py", line 151, in bool
return bool(self._sequence)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\resolution\resolvelib\found_candidates.py", line 140, in bool
return any(self)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\resolution\resolvelib\found_candidates.py", line 128, in
return (c for c in iterator if id(c) not in self._incompatible_ids)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\resolution\resolvelib\found_candidates.py", line 29, in _iter_built
for version, func in infos:
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\resolution\resolvelib\factory.py", line 272, in iter_index_candidate_infos
result = self._finder.find_best_candidate(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\package_finder.py", line 851, in find_best_candidate
candidates = self.find_all_candidates(project_name)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\package_finder.py", line 798, in find_all_candidates
page_candidates = list(page_candidates_it)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\sources.py", line 134, in page_candidates
yield from self._candidates_from_page(self._link)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\package_finder.py", line 758, in process_project_url
html_page = self._link_collector.fetch_page(project_url)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\collector.py", line 490, in fetch_page
return _get_html_page(location, session=self.session)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\collector.py", line 400, in _get_html_page
resp = _get_html_response(url, session=session)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\index\collector.py", line 115, in _get_html_response
resp = session.get(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\requests\sessions.py", line 555, in get
return self.request('GET', url, **kwargs)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_internal\network\session.py", line 454, in request
return super().request(method, url, *args, **kwargs)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\requests\sessions.py", line 542, in request
resp = self.send(prep, **send_kwargs)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\requests\sessions.py", line 655, in send
r = adapter.send(request, **kwargs)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\cachecontrol\adapter.py", line 53, in send
resp = super(CacheControlAdapter, self).send(request, **kw)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\requests\adapters.py", line 439, in send
resp = conn.urlopen(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\connectionpool.py", line 696, in urlopen
self._prepare_proxy(conn)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\connectionpool.py", line 964, in _prepare_proxy
conn.connect()
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\connection.py", line 359, in connect
conn = self._connect_tls_proxy(hostname, conn)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\connection.py", line 500, in connect_tls_proxy
return ssl_wrap_socket(
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\util\ssl.py", line 453, in ssl_wrap_socket
ssl_sock = ssl_wrap_socket_impl(sock, context, tls_in_tls)
File "C:\Users\Iman\anaconda3\lib\site-packages\pip_vendor\urllib3\util\ssl.py", line 495, in _ssl_wrap_socket_impl
return ssl_context.wrap_socket(sock)
File "C:\Users\Iman\anaconda3\lib\ssl.py", line 500, in wrap_socket
return self.sslsocket_class._create(
File "C:\Users\Iman\anaconda3\lib\ssl.py", line 997, in _create
raise ValueError("check_hostname requires server_hostname")
ValueError: check_hostname requires server_hostname

Please sign in to comment.