diff --git a/sup3r/preprocessing/data_handlers/__init__.py b/sup3r/preprocessing/data_handlers/__init__.py index 1216ac9f7..73cfa97d4 100644 --- a/sup3r/preprocessing/data_handlers/__init__.py +++ b/sup3r/preprocessing/data_handlers/__init__.py @@ -1,10 +1,10 @@ """Composite objects built from loaders, rasterizers, and derivers.""" -from .exo import ExoData, ExoDataHandler, SingleExoDataStep -from .factory import ( +from .base import ( DailyDataHandler, DataHandler, DataHandlerH5SolarCC, DataHandlerH5WindCC, ) +from .exo import ExoData, ExoDataHandler, SingleExoDataStep from .nc_cc import DataHandlerNCforCC, DataHandlerNCforCCwithPowerLaw diff --git a/sup3r/preprocessing/data_handlers/factory.py b/sup3r/preprocessing/data_handlers/base.py similarity index 81% rename from sup3r/preprocessing/data_handlers/factory.py rename to sup3r/preprocessing/data_handlers/base.py index fa22a2923..8ebe322be 100644 --- a/sup3r/preprocessing/data_handlers/factory.py +++ b/sup3r/preprocessing/data_handlers/base.py @@ -1,7 +1,7 @@ -"""DataHandler objects, which are built through composition of +"""DataHandler objects, which inherit from the +:class:`~sup3r.preprocessing.derivers.Deriver` class and compose :class:`~sup3r.preprocessing.rasterizers.Rasterizer`, :class:`~sup3r.preprocessing.loaders.Loader`, -:class:`~sup3r.preprocessing.derivers.Deriver`, and :class:`~sup3r.preprocessing.cachers.Cacher` classes. TODO: If ``.data`` is a ``Sup3rDataset`` with more than one member only the @@ -12,7 +12,7 @@ import logging from typing import Callable, Dict, Optional, Union -from rex import MultiFileNSRDBX +from rex import MultiFileNSRDBX, MultiFileWindX from sup3r.preprocessing.base import ( Sup3rDataset, @@ -38,8 +38,7 @@ class DataHandler(Deriver): """Base DataHandler. Composes :class:`~sup3r.preprocessing.rasterizers.Rasterizer`, - :class:`~sup3r.preprocessing.loaders.Loader`, - :class:`~sup3r.preprocessing.derivers.Deriver`, and + :class:`~sup3r.preprocessing.loaders.Loader`, and :class:`~sup3r.preprocessing.cachers.Cacher` classes.""" @log_args @@ -197,7 +196,7 @@ def __init__( ) if cache_kwargs is not None and 'cache_pattern' in cache_kwargs: - _ = Cacher(data=self.data, cache_kwargs=cache_kwargs) + self.cacher = Cacher(data=self.data, cache_kwargs=cache_kwargs) self._deriver_hook() def _rasterizer_hook(self): @@ -313,74 +312,17 @@ def _deriver_hook(self): self.data = Sup3rDataset(daily=daily_data, hourly=hourly_data) -def DataHandlerFactory(cls, BaseLoader=None, FeatureRegistry=None, name=None): - """Build composite objects that load from file_paths, rasterize a specified - region, derive new features, and cache derived data. - - Parameters - ---------- - BaseLoader : Callable - Optional base loader update. The default for H5 is MultiFileWindX and - for NETCDF the default is xarray - FeatureRegistry : Dict[str, DerivedFeature] - Dictionary of compute methods for features. This is used to look up how - to derive features that are not contained in the raw loaded data. - name : str - Optional class name, used to resolve `repr(Class)` and distinguish - partially initialized DataHandlers with different FeatureRegistrys - """ - - class FactoryDataHandler(cls): - """FactoryDataHandler object. Is a partially initialized instance with - `BaseLoader`, `FeatureRegistry`, and `name` set.""" - - FEATURE_REGISTRY = FeatureRegistry or None - BASE_LOADER = BaseLoader or None - __name__ = name or 'FactoryDataHandler' - - def __init__(self, file_paths, features='all', **kwargs): - """ - Parameters - ---------- - file_paths : str | list | pathlib.Path - file_paths input to LoaderClass - features : list | str - Features to load and / or derive. If 'all' then all available - raw features will be loaded. Specify explicit feature names for - derivations. - kwargs : dict - kwargs for parent class, except for FeatureRegistry and - BaseLoader - """ - - if 'FeatureRegistry' in kwargs: - self.FEATURE_REGISTRY.update(kwargs.pop('FeatureRegistry')) +class DataHandlerH5SolarCC(DailyDataHandler): + """Extended ``DailyDataHandler`` specifically for handling H5 data for + SolarCC applications""" - super().__init__( - file_paths, - features=features, - BaseLoader=self.BASE_LOADER, - FeatureRegistry=self.FEATURE_REGISTRY, - **kwargs, - ) + BASE_LOADER = MultiFileNSRDBX + FEATURE_REGISTRY = RegistryH5SolarCC - _signature_objs = (cls,) - _skip_params = ('FeatureRegistry', 'BaseLoader') - return FactoryDataHandler +class DataHandlerH5WindCC(DailyDataHandler): + """Extended ``DailyDataHandler`` specifically for handling H5 data for + WindCC applications""" - -DataHandlerH5SolarCC = DataHandlerFactory( - DailyDataHandler, - BaseLoader=MultiFileNSRDBX, - FeatureRegistry=RegistryH5SolarCC, - name='DataHandlerH5SolarCC', -) - - -DataHandlerH5WindCC = DataHandlerFactory( - DailyDataHandler, - BaseLoader=MultiFileNSRDBX, - FeatureRegistry=RegistryH5WindCC, - name='DataHandlerH5WindCC', -) + BASE_LOADER = MultiFileWindX + FEATURE_REGISTRY = RegistryH5WindCC diff --git a/sup3r/preprocessing/data_handlers/nc_cc.py b/sup3r/preprocessing/data_handlers/nc_cc.py index 395197da5..1187a75fb 100644 --- a/sup3r/preprocessing/data_handlers/nc_cc.py +++ b/sup3r/preprocessing/data_handlers/nc_cc.py @@ -16,18 +16,17 @@ from sup3r.preprocessing.names import Dimension from sup3r.preprocessing.utilities import log_args -from .factory import DataHandler, DataHandlerFactory +from .base import DataHandler logger = logging.getLogger(__name__) -BaseNCforCC = DataHandlerFactory(DataHandler, FeatureRegistry=RegistryNCforCC) - - -class DataHandlerNCforCC(BaseNCforCC): +class DataHandlerNCforCC(DataHandler): """Extended NETCDF data handler. This implements a rasterizer hook to add "clearsky_ghi" to the rasterized data if "clearsky_ghi" is requested.""" + FEATURE_REGISTRY = RegistryNCforCC + @log_args def __init__( self, @@ -78,8 +77,7 @@ def __init__( self._cs_ghi_scale = 1 super().__init__(file_paths=file_paths, features=features, **kwargs) - _signature_objs = (__init__, BaseNCforCC) - _skip_params = ('name', 'FeatureRegistry') + _signature_objs = (__init__, DataHandler) def _rasterizer_hook(self): """Rasterizer hook implementation to add 'clearsky_ghi' data to diff --git a/sup3r/preprocessing/derivers/base.py b/sup3r/preprocessing/derivers/base.py index 9a7670528..3c2ce0675 100644 --- a/sup3r/preprocessing/derivers/base.py +++ b/sup3r/preprocessing/derivers/base.py @@ -84,9 +84,9 @@ def _check_registry(self, feature) -> Union[Type[DerivedFeature], None]: keys. Return the corresponding value if found.""" if feature.lower() in self.FEATURE_REGISTRY: return self.FEATURE_REGISTRY[feature.lower()] - for pattern in self.FEATURE_REGISTRY: + for pattern, method in self.FEATURE_REGISTRY.items(): if re.match(pattern.lower(), feature.lower()): - return self.FEATURE_REGISTRY[pattern] + return method return None def _get_inputs(self, feature): diff --git a/tests/docs/test_doc_automation.py b/tests/docs/test_doc_automation.py index 0bd5d154d..a296b1b33 100644 --- a/tests/docs/test_doc_automation.py +++ b/tests/docs/test_doc_automation.py @@ -115,6 +115,8 @@ def test_nc_for_cc_sig(): 'chunks', 'interp_kwargs', 'nan_method_kwargs', + 'BaseLoader', + 'FeatureRegistry', ] sig = signature(DataHandlerNCforCC) init_sig = signature(DataHandlerNCforCC.__init__) diff --git a/tests/training/test_train_exo_cc.py b/tests/training/test_train_exo_cc.py index 2888b1e62..30225c8d4 100644 --- a/tests/training/test_train_exo_cc.py +++ b/tests/training/test_train_exo_cc.py @@ -9,7 +9,7 @@ from sup3r.models import Sup3rGan from sup3r.preprocessing.batch_handlers.factory import BatchHandlerCC -from sup3r.preprocessing.data_handlers.factory import DataHandlerH5WindCC +from sup3r.preprocessing.data_handlers.base import DataHandlerH5WindCC from sup3r.preprocessing.utilities import lowered from sup3r.utilities.utilities import RANDOM_GENERATOR