Skip to content

Commit

Permalink
add MultiRasterSource.from_stac() constructor (#2156)
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH authored Jun 5, 2024
1 parent 4a613e8 commit b74e30e
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 15 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
from typing import Optional, Sequence, List, Tuple
from typing import TYPE_CHECKING, Optional, Sequence, Self, Tuple
from pydantic import conint

import numpy as np
from pystac import Item

from rastervision.core.box import Box
from rastervision.core.data.raster_source import RasterSource
from rastervision.core.data.crs_transformer import CRSTransformer
from rastervision.core.data.raster_source import RasterSource, RasterioSource
from rastervision.core.data.raster_source.stac_config import subset_assets
from rastervision.core.data.utils import all_equal

if TYPE_CHECKING:
from rastervision.core.data import RasterTransformer, CRSTransformer


class MultiRasterSource(RasterSource):
"""Merge multiple ``RasterSources`` by concatenating along channel dim."""
Expand Down Expand Up @@ -69,6 +73,83 @@ def __init__(self,

self.validate_raster_sources()

@classmethod
def from_stac(
cls,
item: Item,
assets: list[str] | None,
primary_source_idx: conint(ge=0) = 0,
raster_transformers: list['RasterTransformer'] = [],
force_same_dtype: bool = False,
channel_order: Sequence[int] | None = None,
bbox: Box | tuple[int, int, int, int] | None = None,
bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
allow_streaming: bool = False) -> Self:
"""Construct a ``MultiRasterSource`` from a STAC Item.
This creates a :class:`.RasterioSource` for each asset and puts all
the raster sources together into a ``MultiRasterSource``. If ``assets``
is not specified, all the assets in the STAC item are used.
Only assets that are readable by rasterio are supported.
Args:
item: STAC Item.
assets: List of names of assets to use. If ``None``, all assets
present in the item will be used. Defaults to ``None``.
primary_source_idx (0 <= int < len(raster_sources)): Index of the
raster source whose CRS, dtype, and other attributes will
override those of the other raster sources.
raster_transformers: RasterTransformers to use to transform chips
after they are read.
force_same_dtype: If true, force all sub-chips to have the
same dtype as the primary_source_idx-th sub-chip. No careful
conversion is done, just a quick cast. Use with caution.
channel_order: List of indices of channels to extract from raw
imagery. Can be a subset of the available channels. If None,
all channels available in the image will be read.
Defaults to None.
bbox: User-specified crop of the extent. Can be :class:`.Box` or
(ymin, xmin, ymax, xmax) tuple. If None, the full extent
available in the source file is used. Mutually exclusive with
``bbox_map_coords``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be
:class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for
cropping the raster source so that only part of the raster is
read from. Mutually exclusive with ``bbox``.
Defaults to ``None``.
allow_streaming: Passed to :class:`.RasterioSource`. If ``False``,
assets will be downloaded. Defaults to ``True``.
"""
if bbox is not None and bbox_map_coords is not None:
raise ValueError('Specify either bbox or bbox_map_coords, '
'but not both.')

if assets is not None:
item = subset_assets(item, assets)

uris = [asset.href for asset in item.assets.values()]
raster_sources = [
RasterioSource(uri, allow_streaming=allow_streaming)
for uri in uris
]

crs_transformer = raster_sources[primary_source_idx].crs_transformer
if bbox_map_coords is not None:
bbox_map_coords = Box(*bbox_map_coords)
bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
elif bbox is not None:
bbox = Box(*bbox)

raster_source = MultiRasterSource(
raster_sources,
primary_source_idx=primary_source_idx,
raster_transformers=raster_transformers,
channel_order=channel_order,
force_same_dtype=force_same_dtype,
bbox=bbox)
return raster_source

def validate_raster_sources(self) -> None:
"""Validate sub-``RasterSources``.
Expand Down Expand Up @@ -101,13 +182,13 @@ def dtype(self) -> np.dtype:
return self.primary_source.dtype

@property
def crs_transformer(self) -> CRSTransformer:
def crs_transformer(self) -> 'CRSTransformer':
return self.primary_source.crs_transformer

def _get_sub_chips(self,
window: Box,
out_shape: Optional[Tuple[int, int]] = None
) -> List[np.ndarray]:
) -> list[np.ndarray]:
"""Return chips from sub raster sources as a list.
If all extents are identical, simply retrieves chips from each sub
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,13 @@ def __init__(self,

self.validate_raster_sources()

@classmethod
def from_stac(cls, *args, **kwargs):
"""Not implemented for ``TemporalMultiRasterSource``."""
raise NotImplementedError(
'Create raster sources by calling MultiRasterSource.from_stac() '
'on each Item and then pass them to TemporalMultiRasterSource.')

def _get_chip(self,
window: Box,
out_shape: Optional[Tuple[int, int]] = None) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,11 +95,11 @@ def __init__(self,
@classmethod
def from_stac(
cls,
item_or_item_collection: Union['Item', 'ItemCollection'],
raster_transformers: List['RasterTransformer'] = [],
channel_order: Optional[Sequence[int]] = None,
bbox: Optional[Box] = None,
bbox_map_coords: Optional[Box] = None,
item_or_item_collection: 'Item | ItemCollection',
raster_transformers: list['RasterTransformer'] = [],
channel_order: Sequence[int] | None = None,
bbox: Box | tuple[int, int, int, int] | None = None,
bbox_map_coords: Box | tuple[int, int, int, int] | None = None,
temporal: bool = False,
allow_streaming: bool = False,
stackstac_args: dict = dict(rescale=False)) -> 'XarraySource':
Expand All @@ -113,13 +113,15 @@ def from_stac(
imagery. Can be a subset of the available channels. If None,
all channels available in the image will be read.
Defaults to None.
bbox: User-specified crop of the extent. If None, the full extent
bbox: User-specified crop of the extent. Can be :class:`.Box` or
(ymin, xmin, ymax, xmax) tuple. If None, the full extent
available in the source file is used. Mutually exclusive with
``bbox_map_coords``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords of the
form (ymin, xmin, ymax, xmax). Useful for cropping the raster
source so that only part of the raster is read from. Mutually
exclusive with ``bbox``. Defaults to ``None``.
bbox_map_coords: User-specified bbox in EPSG:4326 coords. Can be
:class:`.Box` or (ymin, xmin, ymax, xmax) tuple. Useful for
cropping the raster source so that only part of the raster is
read from. Mutually exclusive with ``bbox``.
Defaults to ``None``.
temporal: If True, data_array is expected to have a "time"
dimension and the chips returned will be of shape (T, H, W, C).
allow_streaming: If False, load the entire DataArray into memory.
Expand Down
37 changes: 37 additions & 0 deletions tests/core/data/raster_source/test_multi_raster_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
from xarray import DataArray
from pystac import Item

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.box import Box
Expand Down Expand Up @@ -84,6 +85,12 @@ def test_build_temporal(self):


class TestMultiRasterSource(unittest.TestCase):
def assertNoError(self, fn: Callable, msg: str = ''):
try:
fn()
except Exception:
self.fail(msg)

def setUp(self):
self.tmp_dir_obj = get_tmp_dir()
self.tmp_dir = self.tmp_dir_obj.name
Expand Down Expand Up @@ -218,6 +225,36 @@ def test_temporal_sub_raster_sources(self):
chip_expected[..., 4:] *= np.arange(4, dtype=np.uint8)
np.testing.assert_array_equal(chip, chip_expected)

def test_from_stac(self):
item = Item.from_file(data_file_path('stac/item.json'))

# avoid reading actual remote files
mock_raster_uri = data_file_path('ones.tif')
item.assets['red'].__setattr__('href', mock_raster_uri)
item.assets['green'].__setattr__('href', mock_raster_uri)

# test bbox
bbox = Box(ymin=0, xmin=0, ymax=100, xmax=100)
rs = MultiRasterSource.from_stac(
item, assets=['red', 'green'], bbox=bbox)
self.assertEqual(rs.bbox, bbox)

# test bbox_map_coords
bbox_map_coords = Box(
ymin=29.978710, xmin=31.134949, ymax=29.977309, xmax=31.136567)
rs = MultiRasterSource.from_stac(
item, assets=['red', 'green'], bbox_map_coords=bbox_map_coords)
self.assertEqual(rs.bbox, Box(ymin=50, xmin=50, ymax=206, xmax=206))

# test error if both bbox and bbox_map_coords specified
args = dict(
item=item,
assets=['red', 'green'],
bbox=bbox,
bbox_map_coords=bbox_map_coords)
self.assertRaises(ValueError,
lambda: MultiRasterSource.from_stac(**args))


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def test_getitem(self):
], dtype=dtype)
np.testing.assert_array_equal(chip, chip_expected)

def test_from_stac(self):
self.assertRaises(NotImplementedError,
TemporalMultiRasterSource.from_stac)


if __name__ == '__main__':
unittest.main()

0 comments on commit b74e30e

Please sign in to comment.