diff --git a/requirements.txt b/requirements.txt index 7caa979..e141049 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +matplotlib pydantic xarray toolz diff --git a/setup.cfg b/setup.cfg index b560680..1cebec8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,7 +8,7 @@ extend-ignore = E203,E501,E402,W605 [isort] known_first_party=xcollection -known_third_party=pkg_resources,pydantic,pytest,setuptools,toolz,xarray +known_third_party=matplotlib,pkg_resources,pydantic,pytest,setuptools,toolz,xarray multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 diff --git a/xcollection/main.py b/xcollection/main.py index 279848c..349124a 100644 --- a/xcollection/main.py +++ b/xcollection/main.py @@ -4,6 +4,7 @@ import pydantic import toolz import xarray as xr +from matplotlib import pyplot as plt def _rpartial(func, *args, **kwargs): @@ -26,6 +27,21 @@ def _validate_input(value): return value +def _make_plot(var, isel_dict, key, ds): + # Are all dims in isel_dict in ds? + if len(set(isel_dict.keys()) - set(ds.dims)) > 0: + return False + + # Is ds a dataset? If so, does it contain var? + if type(ds) == xr.Dataset: + return var in ds.variables + + # Is ds a DataArray? If so, is its name var? + if type(ds) == xr.DataArray: + return var == ds.name + return False + + class Config: validate_assignment = True arbitrary_types_allowed = True @@ -164,3 +180,47 @@ def map( func = _rpartial(func, *args, **kwargs) return type(self)(datasets=toolz.valmap(func, self.datasets)) + + def plot( + self, + var: typing.Union[str, list], + isel_dict: typing.Dict[str, int] = {}, + figsize: typing.Union[list, tuple] = None, + *args: typing.Sequence[typing.Any], + **kwargs: typing.Dict[str, typing.Any], + ): + """For each Dataset that contains {var}, plot ds[var].isel(isel_dict) + For each DataArray for {var}, plot da.isel(isel_dict) + + Parameters + ---------- + var : string or list + Name of variable to plot + isel_dict : dict + Arguments to pass to Dataset.isel() or DataArray.isel() + (Will be deprecated when Collections.isel() exists) + figsize : list or tuple + Argument will be passed to plt.figure() + args + Positional arguments to pass data.plot() + kwargs + Additional keyword arguments to pass as keywords arguments to data.plot() + + Returns + ------- + dict + A dictionary {key: fig} + For each key in Collection that contains var, fig = ds[var].plot() + """ + + return_dict = {} + for key, data in self.items(): + if _make_plot(var, isel_dict, key, data): + fig = plt.figure(figsize=figsize) + axes = fig.add_subplot() + if type(data) == xr.Dataset: + data.isel(isel_dict)[var].plot(ax=axes, *args, **kwargs) + if type(data) == xr.DataArray: + data.isel(isel_dict).plot(ax=axes, *args, **kwargs) + return_dict[key] = (fig, axes) + return return_dict