Skip to content

Commit

Permalink
Merge pull request #1980 from AdeelH/xarray_source_config
Browse files Browse the repository at this point in the history
Add `XarraySourceConfig` to allow specifying an `XarraySource` from STAC Items
  • Loading branch information
AdeelH authored Nov 6, 2023
2 parents a7d56de + 1afb81d commit 885f779
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from rastervision.core.data.raster_source.multi_raster_source import *
from rastervision.core.data.raster_source.multi_raster_source_config import *
from rastervision.core.data.raster_source.xarray_source import *
from rastervision.core.data.raster_source.xarray_source_config import *
from rastervision.core.data.raster_source.temporal_multi_raster_source import *
from rastervision.core.data.raster_source.stac_config import *

__all__ = [
RasterSource.__name__,
Expand All @@ -21,5 +23,8 @@
MultiRasterSource.__name__,
MultiRasterSourceConfig.__name__,
XarraySource.__name__,
XarraySourceConfig.__name__,
TemporalMultiRasterSource.__name__,
STACItemConfig.__name__,
STACItemCollectionConfig.__name__,
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from typing import TYPE_CHECKING, List, Optional

from rastervision.pipeline.config import (Config, Field, register_config)
from rastervision.pipeline.file_system.utils import file_to_json

if TYPE_CHECKING:
from pystac import Item, ItemCollection


@register_config('stac_item')
class STACItemConfig(Config):
    """Specify a raster via a STAC Item.

    The Item is read from a JSON file at ``uri`` and can optionally be
    restricted to a subset of its assets.
    """

    uri: str = Field(..., description='URI to a JSON-serialized STAC Item.')
    assets: Optional[List[str]] = Field(
        None,
        description=
        'Subset of assets to use. This should be a list of asset keys')

    def build(self) -> 'Item':
        """Load the STAC Item and apply the asset subset, if any.

        Returns:
            The ``pystac.Item`` parsed from the JSON at ``uri``. If ``assets``
            is set, the item's assets are reduced to those keys.
        """
        # Imported lazily so that pystac is only needed when building.
        from pystac import Item

        item_dict = file_to_json(self.uri)
        stac_item = Item.from_dict(item_dict)
        if self.assets is None:
            return stac_item
        return subset_assets(stac_item, self.assets)


@register_config('stac_item_collection')
class STACItemCollectionConfig(Config):
    """Specify a raster via a STAC ItemCollection.

    The ItemCollection is read from a JSON file at ``uri``; each of its Items
    can optionally be restricted to a subset of its assets.
    """

    uri: str = Field(
        ..., description='URI to a JSON-serialized STAC ItemCollection.')
    assets: Optional[List[str]] = Field(
        None,
        description=
        'Subset of assets to use. This should be a list of asset keys')

    def build(self) -> 'ItemCollection':
        """Load the STAC ItemCollection and apply the asset subset, if any.

        Returns:
            The ``pystac.ItemCollection`` parsed from the JSON at ``uri``. If
            ``assets`` is set, a new ItemCollection is returned in which every
            Item only has the requested assets.
        """
        # Imported lazily so that pystac is only needed when building.
        from pystac import ItemCollection

        items = ItemCollection.from_dict(file_to_json(self.uri))
        if self.assets is None:
            return items

        subsetted = [subset_assets(item, self.assets) for item in items]
        # Re-wrap the Items, carrying over the collection-level extra fields
        # so that subsetting assets does not silently discard them.
        return ItemCollection(
            subsetted, extra_fields=dict(items.extra_fields))


def subset_assets(item: 'Item', assets: List[str]) -> 'Item':
    """Subset the assets in a STAC Item.

    The item is modified in place: its ``assets`` attribute is replaced by a
    new dict holding only the requested keys, in the order given.

    Args:
        item: Object with an ``assets`` mapping (a STAC Item).
        assets: Asset keys to keep.

    Returns:
        The same ``item`` object, with its assets subsetted.

    Raises:
        KeyError: If any of the requested keys is not among the item's
            assets. The item is left unmodified in that case.
    """
    src_assets = item.assets
    # Validate up front so the error names every missing key and shows what
    # is available, instead of a bare KeyError on the first miss.
    missing = [k for k in assets if k not in src_assets]
    if missing:
        raise KeyError(f'Asset key(s) {missing} not found in STAC Item. '
                       f'Available asset keys: {list(src_assets)}.')
    item.assets = {k: src_assets[k] for k in assets}
    return item
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Optional, Tuple, Union
import logging

from rastervision.pipeline.config import Field, register_config
from rastervision.core.box import Box
from rastervision.core.data.raster_source.raster_source_config import (
RasterSourceConfig)
from rastervision.core.data.crs_transformer import RasterioCRSTransformer
from rastervision.core.data.raster_source.stac_config import (
STACItemConfig, STACItemCollectionConfig)
from rastervision.core.data.raster_source.xarray_source import (XarraySource)

log = logging.getLogger(__name__)


@register_config('xarray_source')
class XarraySourceConfig(RasterSourceConfig):
    """Configure an :class:`.XarraySource`.

    The underlying ``DataArray`` is assembled with ``stackstac.stack()`` from
    the STAC Item or ItemCollection described by ``stac``.
    """

    stac: Union[STACItemConfig, STACItemCollectionConfig] = Field(
        ...,
        description='STAC Item or ItemCollection to build the DataArray from.')
    allow_streaming: bool = Field(
        True,
        description='If False, load the entire DataArray into memory. '
        'Defaults to True.')
    bbox_map_coords: Optional[Tuple[float, float, float, float]] = Field(
        None,
        description='Optional user-specified bbox in EPSG:4326 coords of the '
        'form (ymin, xmin, ymax, xmax). Useful for cropping the raster source '
        'so that only part of the raster is read from. This is ignored if '
        'bbox is also specified. Defaults to None.')
    temporal: bool = Field(
        False, description='Whether the data is a time-series.')

    def build(self,
              tmp_dir: Optional[str] = None,
              use_transformers: bool = True) -> XarraySource:
        """Build the :class:`.XarraySource` described by this config.

        Args:
            tmp_dir: Not used by this method.
            use_transformers: If True, build and attach the configured raster
                transformers; otherwise attach none. Defaults to True.

        Returns:
            The constructed XarraySource.

        Raises:
            ValueError: If ``temporal`` is False but the stacked DataArray has
                more than one entry along its ``time`` dimension.
        """
        # Imported lazily so that stackstac is only needed when building.
        import stackstac

        item_or_item_collection = self.stac.build()
        data_array = stackstac.stack(item_or_item_collection)

        if not self.temporal and 'time' in data_array.dims:
            if len(data_array.time) > 1:
                raise ValueError('temporal=False but len(data_array.time) > 1')
            # Non-temporal source: drop the singleton time dimension.
            data_array = data_array.isel(time=0)

        if not self.allow_streaming:
            from humanize import naturalsize
            log.info('Loading the full DataArray into memory '
                     f'({naturalsize(data_array.nbytes)}).')
            data_array.load()

        crs_transformer = RasterioCRSTransformer(
            transform=data_array.transform, image_crs=data_array.crs)
        raster_transformers = ([rt.build() for rt in self.transformers]
                               if use_transformers else [])

        if self.bbox is not None:
            # bbox takes precedence; only report bbox_map_coords as ignored
            # when the user actually supplied it.
            if self.bbox_map_coords is not None:
                log.info('Using bbox and ignoring bbox_map_coords.')
            bbox = Box(*self.bbox)
        elif self.bbox_map_coords is not None:
            # Convert the map-coordinate bbox to pixel coordinates.
            bbox_map_coords = Box(*self.bbox_map_coords)
            bbox = crs_transformer.map_to_pixel(bbox_map_coords).normalize()
        else:
            bbox = None

        raster_source = XarraySource(
            data_array,
            crs_transformer=crs_transformer,
            raster_transformers=raster_transformers,
            channel_order=self.channel_order,
            bbox=bbox,
            temporal=self.temporal)
        return raster_source
4 changes: 3 additions & 1 deletion rastervision_core/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ numpy==1.25.0
pillow==10.0.1
pyproj==3.4.0
rasterio==1.3.7
pystac==1.6.1
pystac==1.9.0
scikit-learn==1.2.2
scipy==1.10.1
opencv-python-headless==4.6.0.66
tqdm==4.65.0
xarray==2023.2.0
scikit-image==0.21.0
boto3==1.28.8
stackstac==0.5.0
humanize==4.8.0
34 changes: 34 additions & 0 deletions tests/core/data/raster_source/test_stac_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import unittest

from pystac import Item, ItemCollection

from rastervision.core.data.raster_source import (STACItemConfig,
STACItemCollectionConfig)

from tests import data_file_path


class TestSTACItemConfig(unittest.TestCase):
    """Tests for building a STAC Item from a :class:`STACItemConfig`."""

    def test_build(self):
        # Restricting to the 'red' asset should leave exactly that one asset.
        item_uri = data_file_path('stac/item.json')
        stac_item = STACItemConfig(uri=item_uri, assets=['red']).build()
        self.assertIsInstance(stac_item, Item)
        self.assertEqual(len(stac_item.assets), 1)
        self.assertIn('red', stac_item.assets)


class TestSTACItemCollectionConfig(unittest.TestCase):
    """Tests for building a STAC ItemCollection from a
    :class:`STACItemCollectionConfig`."""

    def test_build(self):
        # Every item in the built collection should only keep the 'red' asset.
        coll_uri = data_file_path('stac/item_collection.json')
        config = STACItemCollectionConfig(uri=coll_uri, assets=['red'])
        item_coll = config.build()
        self.assertIsInstance(item_coll, ItemCollection)
        self.assertEqual(len(item_coll), 3)
        for stac_item in item_coll:
            self.assertEqual(len(stac_item.assets), 1)
            self.assertIn('red', stac_item.assets)


# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()
46 changes: 44 additions & 2 deletions tests/core/data/raster_source/test_xarray_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,50 @@

from rastervision.core.box import Box
from rastervision.core.data.crs_transformer import IdentityCRSTransformer
from rastervision.core.data.raster_source import (ChannelOrderError,
XarraySource)
from rastervision.core.data.raster_source import (
ChannelOrderError, XarraySource, XarraySourceConfig, STACItemConfig,
STACItemCollectionConfig)

from tests import data_file_path


class TestXarraySourceConfig(unittest.TestCase):
    """Tests for building an :class:`XarraySource` from an
    :class:`XarraySourceConfig`."""

    def setUp(self):
        # Map-coordinate bbox shared by all tests in this class.
        bbox = Box(
            ymin=48.8155755, xmin=2.224122, ymax=48.902156, xmax=2.4697602)
        self.bbox_map_coords = tuple(bbox)

    def test_build_with_item(self):
        # A single Item should build as a non-temporal source.
        item_uri = data_file_path('stac/item.json')
        cfg = XarraySourceConfig(
            stac=STACItemConfig(uri=item_uri, assets=['red']),
            bbox_map_coords=self.bbox_map_coords,
            allow_streaming=True,
            temporal=False,
        )
        rs = cfg.build()
        self.assertIsInstance(rs, XarraySource)
        self.assertFalse(rs.temporal)

    def test_build_with_item_collection(self):
        item_coll_uri = data_file_path('stac/item_collection.json')

        # The collection spans multiple time steps, so temporal=False must
        # be rejected.
        cfg = XarraySourceConfig(
            stac=STACItemCollectionConfig(uri=item_coll_uri, assets=['red']),
            bbox_map_coords=self.bbox_map_coords,
            allow_streaming=True,
            temporal=False,
        )
        with self.assertRaises(ValueError):
            cfg.build()

        # With temporal=True the same collection builds successfully.
        cfg = XarraySourceConfig(
            stac=STACItemCollectionConfig(uri=item_coll_uri, assets=['red']),
            bbox_map_coords=self.bbox_map_coords,
            allow_streaming=True,
            temporal=True,
        )
        rs = cfg.build()
        self.assertIsInstance(rs, XarraySource)
        self.assertTrue(rs.temporal)


class TestXarraySource(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions tests/data_files/stac/item.json

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions tests/data_files/stac/item_collection.json

Large diffs are not rendered by default.

0 comments on commit 885f779

Please sign in to comment.