diff --git a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py index 303549d30..ae618840d 100644 --- a/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py +++ b/rastervision_core/rastervision/core/data/crs_transformer/rasterio_crs_transformer.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import logging from pyproj import Transformer @@ -17,6 +17,38 @@ log = logging.getLogger(__name__) +def pyproj_wrapper( + func: Callable[..., tuple[Any, Any]], + from_crs: str, + to_crs: str, +) -> Callable[..., tuple[Any, Any]]: + # For some transformations, pyproj attempts to download transformation + # grids from the internet for improved accuracy when + # Transformer.transform() is called. If it fails to connect to the + # internet, it silently returns (inf, inf) and silently modifies its + # behavior to not access the internet on subsequent calls, causing + # them to succeed (though possibly with a loss of accuracy). See + # https://github.com/pyproj4/pyproj/issues/705 for details. + # + # The below workaround forces an error to be raised by setting + # errcheck=True and ignoring the first error. + def _wrapper(*args, **kwargs): + try: + return func(*args, **kwargs, errcheck=True) + except ProjError as e: + log.debug(f'pyproj: {e}') + if 'network' in str(e).lower(): + log.warning( + 'pyproj tried and failed to connect to the internet to ' + 'download transformation grids for the transformation from' + f'\n{from_crs}\nto\n{to_crs}.\nSee ' + 'https://github.com/pyproj4/pyproj/issues/705 for details.' + ) + return func(*args, **kwargs, errcheck=True) + + return _wrapper + + class RasterioCRSTransformer(CRSTransformer): """Transformer for a RasterioRasterSource.""" @@ -40,10 +72,14 @@ def __init__(self, self.map2image = lambda *args, **kws: args[:2] self.image2map = lambda *args, **kws: args[:2] else: - self.map2image = Transformer.from_crs( + self._map2image = Transformer.from_crs( map_crs, image_crs, always_xy=True).transform - self.image2map = Transformer.from_crs( + self._image2map = Transformer.from_crs( image_crs, map_crs, always_xy=True).transform + self.map2image = pyproj_wrapper(self._map2image, map_crs, + image_crs) + self.image2map = pyproj_wrapper(self._image2map, image_crs, + map_crs) self.round_pixels = round_pixels @@ -82,29 +118,7 @@ def _map_to_pixel( Returns: (x, y) tuple in pixel coordinates """ - # For some transformations, pyproj attempts to download transformation - # grids from the internet for improved accuracy when - # Transformer.transform() is called. If it fails to connect to the - # internet, it silently returns (inf, inf) and silently modifies its - # behavior to not access the internet on subsequent calls, causing - # them to succeed (though possibly with a loss of accuracy). See - # https://github.com/pyproj4/pyproj/issues/705 for details. - # - # The below workaround forces an error to be raised by setting - # errcheck=True and ignoring the first error. - try: - image_point = self.map2image(*map_point, errcheck=True) - except ProjError as e: - log.debug(f'pyproj: {e}') - if 'network' in str(e).lower(): - log.warning( - 'pyproj tried and failed to connect to the internet to ' - 'download transformation grids for the transformation from\n' - f'{self.map_crs}\nto\n{self.image_crs}.\nSee ' - 'https://github.com/pyproj4/pyproj/issues/705 for details.' - ) - image_point = self.map2image(*map_point, errcheck=True) - + image_point = self.map2image(*map_point) x, y = image_point if self.round_pixels: row, col = rowcol(self.transform, x, y) @@ -129,29 +143,7 @@ def _pixel_to_map( col = col.astype(int) if isinstance(col, np.ndarray) else int(col) row = row.astype(int) if isinstance(row, np.ndarray) else int(row) image_point = xy(self.transform, row, col, offset='center') - - # For some transformations, pyproj attempts to download transformation - # grids from the internet for improved accuracy when - # Transformer.transform() is called. If it fails to connect to the - # internet, it silently returns (inf, inf) and silently modifies its - # behavior to not access the internet on subsequent calls, causing - # them to succeed (though possibly with a loss of accuracy). See - # https://github.com/pyproj4/pyproj/issues/705 for details. - # - # The below workaround forces an error to be raised by setting - # errcheck=True and ignoring the first error. - try: - map_point = self.image2map(*image_point, errcheck=True) - except ProjError as e: - log.debug(f'pyproj: {e}') - if 'network' in str(e).lower(): - log.warning( - 'pyproj tried and failed to connect to the internet to ' - 'download transformation grids for the transformation from' - f'\n{self.image_crs}\nto\n{self.map_crs}.\nSee ' - 'https://github.com/pyproj4/pyproj/issues/705 for details.' - ) - map_point = self.image2map(*image_point, errcheck=True) + map_point = self.image2map(*image_point) return map_point @classmethod