Skip to content

Commit

Permalink
Support single level reprojection
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrjones committed Jan 11, 2024
1 parent 5aee191 commit 708296f
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 75 deletions.
3 changes: 2 additions & 1 deletion ndpyramid/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa

from .core import pyramid_coarsen, pyramid_reproject
from .coarsen import pyramid_coarsen
from .reproject import pyramid_reproject
from .regrid import pyramid_regrid
from ._version import __version__
50 changes: 50 additions & 0 deletions ndpyramid/coarsen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations # noqa: F401

import datatree as dt
import xarray as xr

from .utils import get_version, multiscales_template


def pyramid_coarsen(
ds: xr.Dataset, *, factors: list[int], dims: list[str], **kwargs
) -> dt.DataTree:
"""Create a multiscale pyramid via coarsening of a dataset by given factors
Parameters
----------
ds : xarray.Dataset
The dataset to coarsen.
factors : list[int]
The factors to coarsen by.
dims : list[str]
The dimensions to coarsen.
kwargs : dict
Additional keyword arguments to pass to xarray.Dataset.coarsen.
"""

# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']

attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(len(factors))],
type='reduce',
method='pyramid_coarsen',
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}

# pyramid data
for key, factor in enumerate(factors):
# merge dictionary via union operator
kwargs |= {d: factor for d in dims}
plevels[str(key)] = ds.coarsen(**kwargs).mean()

plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)
145 changes: 71 additions & 74 deletions ndpyramid/core.py → ndpyramid/reproject.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,53 +5,90 @@

import datatree as dt
import xarray as xr
from rasterio.warp import Resampling

from .common import Projection
from .utils import add_metadata_and_zarr_encoding, get_version, multiscales_template


def pyramid_coarsen(
ds: xr.Dataset, *, factors: list[int], dims: list[str], **kwargs
) -> dt.DataTree:
"""Create a multiscale pyramid via coarsening of a dataset by given factors
Parameters
----------
ds : xarray.Dataset
The dataset to coarsen.
factors : list[int]
The factors to coarsen by.
dims : list[str]
The dimensions to coarsen.
kwargs : dict
Additional keyword arguments to pass to xarray.Dataset.coarsen.
"""

def _define_spec(
levels: int,
pixels_per_tile: int
):
# multiscales spec
save_kwargs = locals()
del save_kwargs['ds']

attrs = {
save_kwargs = {'levels': levels, 'pixels_per_tile': pixels_per_tile}
return {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(len(factors))],
datasets=[{'path': str(i)} for i in range(levels)],
type='reduce',
method='pyramid_coarsen',
method='pyramid_reproject',
version=get_version(),
kwargs=save_kwargs,
)
}

# set up pyramid
plevels = {}
def _da_reproject(da, *, dim, crs, resampling, transform):
return da.rio.reproject(
crs,
resampling=resampling,
shape=(dim, dim),
transform=transform,
)

# pyramid data
for key, factor in enumerate(factors):
# merge dictionary via union operator
kwargs |= {d: factor for d in dims}
plevels[str(key)] = ds.coarsen(**kwargs).mean()
def level_reproject(
ds: xr.Dataset,
*,
projection_model: Projection,
level: int,
pixels_per_tile: int,
resampling_dict: dict,
extra_dim: str = None,
) -> xr.Dataset:

plevels['/'] = xr.Dataset(attrs=attrs)
return dt.DataTree.from_dict(plevels)
"""Create a level of a multiscale pyramid of a dataset via reprojection.
Parameters
----------
ds : xarray.Dataset
The dataset to create a multiscale pyramid of.
projection : Projection
The projection model to use.
level : int
The level of the pyramid to create.
pixels_per_tile : int, optional
Number of pixels per tile
resampling : dict
Rasterio warp resampling method to use. Keys are variable names and values are warp resampling methods.
extra_dim : str, optional
The name of the extra dimension to iterate over. Default is None.
Returns
-------
xr.Dataset
The multiscale pyramid level.
"""

dim = 2**level * pixels_per_tile
dst_transform = projection_model.transform(dim=dim)

# create the data array for each level
ds_level = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if len(da.shape) == 4:
# if extra_dim is not specified, raise an error
if extra_dim is None:
raise ValueError("must specify 'extra_dim' to iterate over 4d data")
da_all = []
for index in ds[extra_dim]:
# reproject each index of the 4th dimension
da_reprojected = _da_reproject(da.sel({extra_dim: index}), dim=dim, crs=projection_model._crs, resampling=Resampling[resampling_dict[k]], transform=dst_transform)
da_all.append(da_reprojected)
ds_level[k] = xr.concat(da_all, ds[extra_dim])
else:
# if the data array is not 4D, just reproject it
ds_level[k] = _da_reproject(da, dim=dim, crs=projection_model._crs, resampling=Resampling[resampling_dict[k]], transform=dst_transform)
return ds_level


def pyramid_reproject(
Expand Down Expand Up @@ -93,20 +130,7 @@ def pyramid_reproject(
"""

import rioxarray # noqa: F401
from rasterio.warp import Resampling

# multiscales spec
save_kwargs = {'levels': levels, 'pixels_per_tile': pixels_per_tile}
attrs = {
'multiscales': multiscales_template(
datasets=[{'path': str(i)} for i in range(levels)],
type='reduce',
method='pyramid_reproject',
version=get_version(),
kwargs=save_kwargs,
)
}
attrs = _define_spec(levels, pixels_per_tile)

# Convert resampling from string to dictionary if necessary
if isinstance(resampling, str):
Expand All @@ -121,34 +145,7 @@ def pyramid_reproject(

# pyramid data
for level in range(levels):
lkey = str(level)
dim = 2**level * pixels_per_tile
dst_transform = projection_model.transform(dim=dim)

def reproject(da, var):
return da.rio.reproject(
projection_model._crs,
resampling=Resampling[resampling_dict[var]],
shape=(dim, dim),
transform=dst_transform,
)

# create the data array for each level
plevels[lkey] = xr.Dataset(attrs=ds.attrs)
for k, da in ds.items():
if len(da.shape) == 4:
# if extra_dim is not specified, raise an error
if extra_dim is None:
raise ValueError("must specify 'extra_dim' to iterate over 4d data")
da_all = []
for index in ds[extra_dim]:
# reproject each index of the 4th dimension
da_reprojected = reproject(da.sel({extra_dim: index}), k)
da_all.append(da_reprojected)
plevels[lkey][k] = xr.concat(da_all, ds[extra_dim])
else:
# if the data array is not 4D, just reproject it
plevels[lkey][k] = reproject(da, k)
plevels[str(level)] = level_reproject(ds, projection_model=projection_model, level=level, pixels_per_tile=pixels_per_tile, resampling_dict=resampling_dict, extra_dim=extra_dim)

# create the final multiscale pyramid
plevels['/'] = xr.Dataset(attrs=attrs)
Expand Down

0 comments on commit 708296f

Please sign in to comment.