Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to plotting functions #978

Merged
merged 4 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions ogcore/output_plots.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
from mpl_toolkits.mplot3d import Axes3D
import matplotlib
from ogcore.constants import (
Expand Down Expand Up @@ -391,7 +392,7 @@ def ss_3Dplot(
data = (reform_ss[var] - base_ss[var]).T
elif plot_type == "pct_diff":
data = ((reform_ss[var] - base_ss[var]) / base_ss[var]).T
cmap1 = matplotlib.cm.get_cmap("jet")
cmap1 = matplotlib.colormaps.get_cmap("jet")
X, Y = np.meshgrid(domain, Jgrid)
fig5, ax5 = plt.subplots(subplot_kw={"projection": "3d"})
ax5.set_xlabel(r"age-$s$")
Expand Down Expand Up @@ -652,7 +653,7 @@ def ability_bar_ss(
plt.ylabel(r"Percentage Change in " + VAR_LABELS[var])
if plot_title:
plt.title(plot_title, fontsize=15)
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
# plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
Expand Down Expand Up @@ -1199,14 +1200,17 @@ def inequality_plot(
plt.title(plot_title, fontsize=15)
vals = ax1.get_yticks()
if plot_type == "pct_diff":
ax1.set_yticklabels(["{:,.2%}".format(x) for x in vals])
ticks_loc = ax1.get_yticks().tolist()
ax1.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax1.set_yticklabels(["{:,.2%}".format(x) for x in ticks_loc])
plt.xlim(
(
base_params.start_year - 1,
base_params.start_year + num_years_to_plot,
)
)
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if plot_type == "levels":
plt.legend(loc=9, bbox_to_anchor=(0.5, -0.15), ncol=2)
if path:
fig_path1 = os.path.join(path)
plt.savefig(fig_path1, bbox_inches="tight", dpi=300)
Expand Down
21 changes: 13 additions & 8 deletions ogcore/parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import matplotlib.pyplot as plt
import matplotlib
import matplotlib.ticker as mticker
from ogcore.constants import GROUP_LABELS
from ogcore import utils, txfunc
from ogcore.constants import DEFAULT_START_YEAR, VAR_LABELS
Expand Down Expand Up @@ -107,8 +108,9 @@ def plot_mort_rates(
plt.ylabel(r"Mortality Rates $\rho_{s}$")
plt.legend(loc="upper left")
title = "Mortality Rates"
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.0%}".format(x) for x in ticks_loc])
if include_title:
plt.title(title)
if path is None:
Expand Down Expand Up @@ -150,8 +152,9 @@ def plot_pop_growth(
plt.plot(year_vec, p.g_n[start_index : start_index + num_years_to_plot])
plt.xlabel(r"Year $t$")
plt.ylabel(r"Population Growth Rate $g_{n, t}$")
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.2%}".format(x) for x in vals])
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.2%}".format(x) for x in ticks_loc])
if include_title:
plt.title("Population Growth Rates")
if path is None:
Expand Down Expand Up @@ -485,9 +488,11 @@ def plot_g_n(p_list, label_list=[""], include_title=False, path=None):
plt.plot(years, p.g_n[: p.T], label=label_list[i])
plt.xlabel(r"Year $s$ (model periods)")
plt.ylabel(r"Population Growth Rate $g_{n,t}$")
plt.legend(loc="upper right")
vals = ax.get_yticks()
ax.set_yticklabels(["{:,.0%}".format(x) for x in vals])
if label_list[0] != "":
plt.legend(loc="upper right")
ticks_loc = ax.get_yticks().tolist()
ax.yaxis.set_major_locator(mticker.FixedLocator(ticks_loc))
ax.set_yticklabels(["{:,.0%}".format(x) for x in ticks_loc])
if include_title:
plt.title("Population Growth Rates")
if path is None:
Expand Down Expand Up @@ -972,7 +977,7 @@ def plot_income_data(
t = -1
J = abil_midp.shape[0]
abil_mesh, age_mesh = np.meshgrid(abil_midp, ages)
cmap1 = matplotlib.cm.get_cmap("summer")
cmap1 = matplotlib.colormaps["summer"]
if path:
# Make sure that directory is created
utils.mkdirs(path)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_output_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import numpy as np
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from ogcore import utils, output_plots, constants


Expand Down Expand Up @@ -166,6 +167,7 @@ def test_plot_aggregates(
plot_title=plot_title,
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -217,6 +219,7 @@ def test_plot_industry_aggregates(
plot_title=plot_title,
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -300,6 +303,7 @@ def test_plot_gdp_ratio(
plot_title=plot_title,
)
assert fig
plt.close()


def test_plot_gdp_ratio_save_fig(tmpdir):
Expand Down Expand Up @@ -327,6 +331,7 @@ def test_ability_bar():
plot_title=" Test Plot Title",
)
assert fig
plt.close()


def test_ability_bar_save_fig(tmpdir):
Expand All @@ -353,6 +358,7 @@ def test_ability_bar_ss():
plot_title=" Test Plot Title",
)
assert fig
plt.close()


data_for_plot = np.ones(80) * 0.3
Expand All @@ -374,6 +380,7 @@ def test_ss_profiles(by_j, plot_data):
plot_title=" Test Plot Title",
)
assert fig
plt.close()


def test_ss_profiles_save_fig(tmpdir):
Expand All @@ -398,6 +405,7 @@ def test_tpi_profiles(by_j):
plot_title=" Test Plot Title",
)
assert fig
plt.close()


test_data = [
Expand Down Expand Up @@ -454,6 +462,7 @@ def test_ss_3Dplot(
plot_title=plot_title,
)
assert fig
plt.close()


def test_ss_3Dplot_save_fig(tmpdir):
Expand Down Expand Up @@ -540,6 +549,7 @@ def test_inequality_plot(
plot_type=plot_type,
)
assert fig
plt.close()


def test_inequality_plot_save_fig(tmpdir):
Expand Down
17 changes: 17 additions & 0 deletions tests/test_parameter_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import scipy.interpolate as si
import matplotlib.image as mpimg
from ogcore import utils, parameter_plots, Specifications
Expand Down Expand Up @@ -74,13 +75,15 @@ def test_plot_imm_rates_save_fig(tmpdir):
def test_plot_mort_rates():
fig = parameter_plots.plot_mort_rates([base_params], include_title=True)
assert fig
plt.close()


def test_plot_surv_rates():
fig = parameter_plots.plot_mort_rates(
[base_params], survival_rates=True, include_title=True
)
assert fig
plt.close()


def test_plot_mort_rates_save_fig(tmpdir):
Expand All @@ -104,6 +107,7 @@ def test_plot_pop_growth():
base_params, start_year=int(base_params.start_year), include_title=True
)
assert fig
plt.close()


def test_plot_pop_growth_rates_save_fig(tmpdir):
Expand All @@ -119,6 +123,7 @@ def test_plot_ability_profiles():
p = Specifications()
fig = parameter_plots.plot_ability_profiles(p, p2=p, include_title=True)
assert fig
plt.close()


def test_plot_log_ability_profiles():
Expand All @@ -127,6 +132,7 @@ def test_plot_log_ability_profiles():
p, p2=p, log_scale=True, include_title=True
)
assert fig
plt.close()


def test_plot_ability_profiles_save_fig(tmpdir):
Expand All @@ -144,6 +150,7 @@ def test_plot_elliptical_u():
)
assert fig1
assert fig2
plt.close()


def test_plot_elliptical_u_save_fig(tmpdir):
Expand All @@ -157,6 +164,7 @@ def test_plot_chi_n():
p = Specifications()
fig = parameter_plots.plot_chi_n([p], include_title=True)
assert fig
plt.close()


def test_plot_chi_n_save_fig(tmpdir):
Expand All @@ -177,6 +185,7 @@ def test_plot_population(years_to_plot):
base_params, years_to_plot=years_to_plot, include_title=True
)
assert fig
plt.close()


def test_plot_population_save_fig(tmpdir):
Expand Down Expand Up @@ -215,6 +224,7 @@ def test_plot_fert_rates():
fert_rates = np.random.uniform(size=totpers).reshape((1, totpers))
fig = parameter_plots.plot_fert_rates([fert_rates], include_title=True)
assert fig
plt.close()


def test_plot_fert_rates_save_fig(tmpdir):
Expand Down Expand Up @@ -258,6 +268,7 @@ def test_plot_g_n():
p = Specifications()
fig = parameter_plots.plot_g_n([p], include_title=True)
assert fig
plt.close()


def test_plot_g_n_savefig(tmpdir):
Expand All @@ -276,6 +287,7 @@ def test_plot_mort_rates_data():
path=None,
)
assert fig
plt.close()


def test_plot_mort_rates_data_save_fig(tmpdir):
Expand All @@ -300,6 +312,7 @@ def test_plot_omega_fixed():
age_per_EpS, omega_SS_orig, omega_SSfx, E, S
)
assert fig
plt.close()


def test_plot_omega_fixed_save_fig(tmpdir):
Expand All @@ -326,6 +339,7 @@ def test_plot_imm_fixed():
age_per_EpS, imm_rates_orig, imm_rates_adj, E, S
)
assert fig
plt.close()


def test_plot_imm_fixed_save_fig(tmpdir):
Expand Down Expand Up @@ -360,6 +374,7 @@ def test_plot_population_path():
S,
)
assert fig
plt.close()


def test_plot_population_path_save_fig(tmpdir):
Expand Down Expand Up @@ -398,6 +413,7 @@ def test_plot_income_data():
fig = parameter_plots.plot_income_data(ages, abil_midp, abil_pcts, emat)

assert fig
plt.close()


def test_plot_income_data_save_fig(tmpdir):
Expand Down Expand Up @@ -481,6 +497,7 @@ def test_plot_2D_taxfunc(
)

assert fig
plt.close()
else:
assert True

Expand Down
Loading