Skip to content

Commit

Permalink
Refactor plots and introduce --figs
Browse files Browse the repository at this point in the history
  • Loading branch information
jochenklar committed Nov 23, 2023
1 parent fe98811 commit 572c9b6
Show file tree
Hide file tree
Showing 46 changed files with 283 additions and 238 deletions.
8 changes: 5 additions & 3 deletions isimip_qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ def get_parser():
parser.add_argument('-p', '--periods', dest='periods', default=None,
help='Extract only specific periods (comma seperated, format: YYYY_YYYY)')

parser.add_argument('-g', '--grid', type=int, dest='grid', default=0, choices=[0, 1, 2],
help='Number of dimensions of the plot grid [default: 0, i.e. one plot]')
parser.add_argument('-f', '--force', dest='force', action='store_true', default=False,
help='Always run extractions')
parser.add_argument('-l', '--load', dest='load', action='store_true', default=False,
Expand All @@ -56,6 +54,10 @@ def get_parser():
help='Treat these placeholders as primary and plot them in color [default: all]')
parser.add_argument('--gridarea', dest='gridarea', default=None,
help='Use a CDO gridarea file instead of computing the gridarea when computing means')
parser.add_argument('--grid', type=int, dest='grid', default=0, choices=[0, 1, 2],
help='Number of dimensions of the plot grid [default: 0, i.e. no grid]')
parser.add_argument('--figs', type=int, dest='figs', default=0,
help='Number of placeholders which generate seperate figures [default: 0]')

parser.add_argument('--ymin', type=float, dest='ymin', default=None,
help='Fixed minimal y value for plots.')
Expand Down Expand Up @@ -152,5 +154,5 @@ def main():
):
plot = plot_class(extraction_class, datasets, region, period,
path=settings.PATHS[0].stem, dimensions=settings.PLACEHOLDERS,
grid=settings.GRID)
grid=settings.GRID, figs=settings.FIGS)
plot.create()
82 changes: 45 additions & 37 deletions isimip_qa/mixins/plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import sys
from itertools import chain, product
from pathlib import Path

Expand All @@ -18,7 +19,7 @@ def write(self, fig, path):

logger.info(f'write {path}')
try:
fig.savefig(path, bbox_inches='tight')
fig.savefig(path)
except ValueError as e:
logger.error(f'could not save {path} ({e})')
plt.close()
Expand All @@ -37,25 +38,32 @@ class GridPlotMixin:
]
linestyles = ['solid', 'dashed', 'dashdot', 'dotted']
markers = ['.', '*', 'D', 's']
max_dimensions = sys.maxsize

def __init__(self, *args, path=None, dimensions=None, grid=2, **kwargs):
def __init__(self, *args, path=None, dimensions=None, grid=0, figs=0, **kwargs):
self.path = Path(path) if path else None
self.dimensions = dimensions
self.grid = grid
self.figs = figs

if self.dimensions:
self.keys = list(self.dimensions.keys())
self.values = list(self.dimensions.values())
self.permutations = list(product(*self.values))
self.dimensions_keys = list(self.dimensions.keys())
self.dimensions_values = list(self.dimensions.values())
self.dimensions_len = len(self.dimensions_keys)
self.permutations = list(product(*self.dimensions_values))
self.styles = self.get_styles()
self.figs = max(self.figs, self.dimensions_len - self.grid - self.max_dimensions)

super().__init__(*args, **kwargs)

def get_figure(self, nrows, ncols, ratio=1):
fig, axs = plt.subplots(nrows, ncols, squeeze=False, figsize=(6 * ratio * ncols, 6 * nrows))
fig, axs = plt.subplots(nrows, ncols, squeeze=False, figsize=(6 * ratio * ncols, 6 * nrows),
constrained_layout=True)
for ax in chain.from_iterable(axs):
ax.tick_params(bottom=False, labelbottom=False, left=False, labelleft=False)
return fig, axs

def get_path(self, ifig):
def get_figure_path(self, ifig=0):
if self.path is None:
return None

Expand All @@ -68,18 +76,18 @@ def get_path(self, ifig):
if self.dimensions:
placeholders = {}

for j, key in enumerate(self.keys):
if ifig is None or j < self.grid:
primary_values = [value for value in self.values[j] if value in settings.PRIMARY]
for j, key in enumerate(self.dimensions_keys):
if (j < self.grid) or (j < self.dimensions_len - self.figs):
# for the first dimensions, which are not figure dimensions, combine the values
primary_values = [value for value in self.dimensions_values[j] if value in settings.PRIMARY]
if primary_values:
values_strings = primary_values
elif len(self.values[j]) < 10:
values_strings = self.values[j]
elif len(self.dimensions_values[j]) < 10:
values_strings = self.dimensions_values[j]
else:
values_strings = ['various']
else:
# this works because for j > self.grid, the permutations
# only repeat with a "period" of nfig
# for the last self.figs dimensions, which generate seperate figures, take seperate values
values_strings = [self.permutations[ifig][j]]

placeholders[key] = '+'.join(values_strings).lower()
Expand All @@ -102,21 +110,19 @@ def get_path(self, ifig):

return settings.PLOTS_PATH / self.path.with_name(stem).with_suffix(suffix)

def get_grid(self, figs=False):
grid = [1, 1, 1] if figs else [1, 1]
def get_grid(self):
grid = [1, 1, 1]

if self.dimensions:
for j, key in enumerate(self.keys):
for j, key in enumerate(self.dimensions_keys):
ndim = len(self.dimensions[key])

if j < self.grid:
# the grid dimensions generate rows and columns
grid[j] = ndim

if figs:
if j == self.grid:
grid[-1] = len(self.values[j])
elif j > self.grid:
grid[-1] *= len(self.values[j])
elif (j >= self.dimensions_len - self.figs):
# the last self.figs dimensions multiply the number of seperate figures
grid[-1] *= len(self.dimensions_values[j])

return reversed(grid)

Expand All @@ -126,16 +132,25 @@ def get_grid_indexes(self, i):
if self.dimensions:
permutation = self.permutations[i]

for j, key in enumerate(self.keys):
for j, key in enumerate(self.dimensions_keys):
value = permutation[j]
value_index = self.dimensions[key].index(value)

if j < self.grid:
# the first dimensions indicate the column and the row
grid_indexes[j] = value_index
elif j == len(self.keys) - 1:
grid_indexes[-1] += value_index
else:
grid_indexes[-1] += value_index * len(self.values[j+1])
elif self.figs > 0:
# the figure index is computed like this:
# lets assume self.figs = 3 and i3,i2,i1 are the indexes of dimensions d3,d2,d1
# the figure index is i1 + i2 * len(d1) + i3 * len(d1) * len(d2)
if j == self.dimensions_len - 1:
# the last value_index just adds to the figure index
grid_indexes[-1] += value_index
elif j >= self.dimensions_len - self.figs:
# the other value_indexes need be multiplied by the lenghts of the dimensions to the right
for values in self.dimensions_values[j+1:]:
value_index *= len(values)
grid_indexes[-1] += value_index

return reversed(grid_indexes)

Expand All @@ -160,7 +175,7 @@ def get_primary(self, i):
if self.dimensions and settings.PRIMARY:
permutation = self.permutations[i]

for j, key in enumerate(self.keys):
for j, key in enumerate(self.dimensions_keys):
if permutation[j] in settings.PRIMARY:
return True

Expand All @@ -185,7 +200,6 @@ def get_subplots(self):
continue

try:

attrs = self.get_attrs(dataset)
except ExtractionNotFound:
attrs = {}
Expand All @@ -198,15 +212,13 @@ def get_subplots(self):
var=var,
label=self.get_label(index),
title=self.get_title(index),
full_title=self.get_full_title(index),
color=self.get_color(index),
linestyle=self.get_linestyle(index),
marker=self.get_marker(index),
ifig=ifig,
irow=irow,
icol=icol,
primary=self.get_primary(index),
path=self.get_path(ifig)
primary=self.get_primary(index)
)

subplots.append(subplot)
Expand All @@ -219,10 +231,6 @@ def get_df(self, dataset):
def get_attrs(self, dataset):
raise NotImplementedError

def get_full_title(self, i):
if self.dimensions:
return ' '.join(self.permutations[i])

def get_title(self, i):
if self.dimensions:
return ' '.join(self.permutations[i][:self.grid])
Expand Down
59 changes: 32 additions & 27 deletions isimip_qa/plots/cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,39 @@ def get_attrs(self, dataset):
def create(self):
logger.info(f'plot {self.region.specifier} {self.extraction_class.specifier} {self.specifier}')

nfigs, nrows, ncols = self.get_grid()
subplots = self.get_subplots()
if subplots:
nrows, ncols = self.get_grid()
fig, axs = self.get_figure(nrows, ncols)

for sp in subplots:
ax = axs.item(sp.irow, sp.icol)

ymin = self.get_ymin(sp, subplots)
ymax = self.get_ymax(sp, subplots)

if sp.primary:
ax.step(sp.df.index, sp.df[sp.var], where='mid', color=sp.color,
linestyle=sp.linestyle, label=sp.label, zorder=10)
if sp.label:
ax.legend(loc='lower left').set_zorder(20)
else:
ax.step(sp.df.index, sp.df[sp.var], where='mid', color='grey', zorder=0)

ax.set_title(sp.title)
ax.set_xlabel(sp.attrs.get('standard_name'))
ax.set_ylabel('CDF')
ax.set_ylim(ymin, ymax)
ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True)

if self.path:
self.write(fig, sp.path)
else:
self.show()
for ifig in range(nfigs):
fig_path = self.get_figure_path(ifig)
fig_subplots = [sp for sp in subplots if sp.ifig == ifig]

if fig_subplots:
fig, axs = self.get_figure(nrows, ncols)

for sp in fig_subplots:
ax = axs.item(sp.irow, sp.icol)

ymin = self.get_ymin(sp, subplots)
ymax = self.get_ymax(sp, subplots)

if sp.primary:
ax.step(sp.df.index, sp.df[sp.var], where='mid', color=sp.color,
linestyle=sp.linestyle, label=sp.label, zorder=10)
if sp.label:
ax.legend(loc='lower left').set_zorder(20)
else:
ax.step(sp.df.index, sp.df[sp.var], where='mid', color='grey', zorder=0)

ax.set_title(sp.title)
ax.set_xlabel(sp.attrs.get('standard_name'))
ax.set_ylabel('CDF')
ax.set_ylim(ymin, ymax)
ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True)

if fig_path:
self.write(fig, fig_path)
else:
self.show()
else:
logger.info('nothing to plot')
57 changes: 31 additions & 26 deletions isimip_qa/plots/daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,38 @@ def get_attrs(self, dataset):
def create(self):
logger.info(f'plot {self.region.specifier} {self.extraction_class.specifier} {self.specifier}')

nfigs, nrows, ncols = self.get_grid()
subplots = self.get_subplots()
if subplots:
nrows, ncols = self.get_grid()
fig, axs = self.get_figure(nrows, ncols)

for sp in subplots:
ax = axs.item(sp.irow, sp.icol)

ymin = self.get_ymin(sp, subplots)
ymax = self.get_ymax(sp, subplots)

if sp.primary:
ax.plot(sp.df.index, sp.df[sp.var], label=sp.label, zorder=10)
if sp.label:
ax.legend(loc='lower left').set_zorder(20)
else:
ax.plot(sp.df.index, sp.df[sp.var], color='grey', zorder=0)

ax.set_title(sp.title)
ax.set_xlabel('date')
ax.set_ylabel(f'{sp.var} [{sp.attrs.get("units")}]')
ax.set_ylim(ymin, ymax)
ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True)

if self.path:
self.write(fig, sp.path)
else:
self.show()
for ifig in range(nfigs):
fig_path = self.get_figure_path(ifig)
fig_subplots = [sp for sp in subplots if sp.ifig == ifig]

if fig_subplots:
fig, axs = self.get_figure(nrows, ncols)

for sp in fig_subplots:
ax = axs.item(sp.irow, sp.icol)

ymin = self.get_ymin(sp, subplots)
ymax = self.get_ymax(sp, subplots)

if sp.primary:
ax.plot(sp.df.index, sp.df[sp.var], label=sp.label, zorder=10)
if sp.label:
ax.legend(loc='lower left').set_zorder(20)
else:
ax.plot(sp.df.index, sp.df[sp.var], color='grey', zorder=0)

ax.set_title(sp.title)
ax.set_xlabel('date')
ax.set_ylabel(f'{sp.var} [{sp.attrs.get("units")}]')
ax.set_ylim(ymin, ymax)
ax.tick_params(bottom=True, labelbottom=True, left=True, labelleft=True)

if fig_path:
self.write(fig, fig_path)
else:
self.show()
else:
logger.info('nothing to plot')
Loading

0 comments on commit 572c9b6

Please sign in to comment.