From 498fb1aef99ceed2f7bcf0f46da3bb00866c2936 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 31 Jan 2024 16:55:50 -0500 Subject: [PATCH] wip --- .../rastervision/core/__init__.py | 2 +- .../core/rv_pipeline/rv_pipeline_config.py | 19 ++++++++- .../semantic_segmentation_config.py | 39 +++++++++++++------ .../rastervision/pipeline/config.py | 24 +++++++----- .../rastervision/pipeline/registry.py | 11 ++++++ .../rastervision/pytorch_learner/__init__.py | 9 ++++- .../pytorch_learner/learner_config.py | 9 ++++- 7 files changed, 85 insertions(+), 28 deletions(-) diff --git a/rastervision_core/rastervision/core/__init__.py b/rastervision_core/rastervision/core/__init__.py index e8506dbb00..2365cb0f09 100644 --- a/rastervision_core/rastervision/core/__init__.py +++ b/rastervision_core/rastervision/core/__init__.py @@ -2,7 +2,7 @@ def register_plugin(registry): - registry.set_plugin_version('rastervision.core', 10) + registry.set_plugin_version('rastervision.core', 11) from rastervision.core.cli import predict, predict_scene registry.add_plugin_command(predict) registry.add_plugin_command(predict_scene) diff --git a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py index 9e54538e70..6b3f5e7e14 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py +++ b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline_config.py @@ -8,7 +8,8 @@ from rastervision.core.backend import BackendConfig from rastervision.core.evaluation import EvaluatorConfig from rastervision.core.analyzer import AnalyzerConfig -from rastervision.core.rv_pipeline.chip_options import ChipOptions +from rastervision.core.rv_pipeline.chip_options import (ChipOptions, + WindowSamplingConfig) from rastervision.pipeline.config import (Config, Field, register_config) if TYPE_CHECKING: @@ -21,7 +22,21 @@ class PredictOptions(Config): pass -@register_config('rv_pipeline') +def rv_pipeline_config_upgrader(cfg_dict: dict, version: int) -> dict: + if version == 10: + train_chip_sz = cfg_dict.pop('train_chip_sz', 300) + nodata_threshold = cfg_dict.pop('chip_nodata_threshold') + if 'chip_options' not in cfg_dict: + cfg_dict['chip_options'] = ChipOptions( + sampling=WindowSamplingConfig(size=train_chip_sz), + nodata_threshold=nodata_threshold) + else: + cfg_dict['chip_options']['sampling']['size'] = train_chip_sz + cfg_dict['chip_options']['nodata_threshold'] = nodata_threshold + return cfg_dict + + +@register_config('rv_pipeline', upgrader=rv_pipeline_config_upgrader) class RVPipelineConfig(PipelineConfig): """Configure an :class:`.RVPipeline`.""" diff --git a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py index 6d390a10ee..df69ab27a6 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py +++ b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation_config.py @@ -5,25 +5,28 @@ from rastervision.pipeline.config import (ConfigError, register_config, Field, validator) -from rastervision.core.rv_pipeline.rv_pipeline_config import ( - ChipOptions, PredictOptions, RVPipelineConfig) +from rastervision.core.rv_pipeline.rv_pipeline_config import (PredictOptions, + RVPipelineConfig) +from rastervision.core.rv_pipeline.chip_options import (ChipOptions, + WindowSamplingConfig) from rastervision.core.data import SemanticSegmentationLabelStoreConfig from rastervision.core.evaluation import SemanticSegmentationEvaluatorConfig -def ss_config_upgrader(cfg_dict: dict, version: int) -> dict: - if version < 1: - try: - # removed in version 1 - del cfg_dict['channel_display_groups'] - del cfg_dict['img_format'] - del cfg_dict['label_format'] - except KeyError: - pass +def ss_chip_options_upgrader(cfg_dict: dict, version: int) -> dict: + if version == 10: + sampling = WindowSamplingConfig( + method=cfg_dict.pop('window_method', None), + size=300, + stride=cfg_dict.pop('stride', None), + max_windows=cfg_dict.pop('chips_per_scene', None), + ) + cfg_dict['sampling'] = sampling.dict() return cfg_dict -@register_config('semantic_segmentation_chip_options') +@register_config( + 'semantic_segmentation_chip_options', upgrader=ss_chip_options_upgrader) class SemanticSegmentationChipOptions(ChipOptions): """Chipping options for semantic segmentation.""" target_class_ids: Optional[List[int]] = Field( @@ -99,6 +102,18 @@ def validate_crop_sz(cls, return crop_sz +def ss_config_upgrader(cfg_dict: dict, version: int) -> dict: + if version == 0: + try: + # removed in version 1 + del cfg_dict['channel_display_groups'] + del cfg_dict['img_format'] + del cfg_dict['label_format'] + except KeyError: + pass + return cfg_dict + + @register_config('semantic_segmentation', upgrader=ss_config_upgrader) class SemanticSegmentationConfig(RVPipelineConfig): """Configure a :class:`.SemanticSegmentation` pipeline.""" diff --git a/rastervision_pipeline/rastervision/pipeline/config.py b/rastervision_pipeline/rastervision/pipeline/config.py index 891146d4cb..ecd6dca909 100644 --- a/rastervision_pipeline/rastervision/pipeline/config.py +++ b/rastervision_pipeline/rastervision/pipeline/config.py @@ -182,16 +182,20 @@ def _upgrade_config(x: Union[dict, List[dict]], plugin_versions: Dict[str, int] for k, v in x.items(): new_x[k] = _upgrade_config(v, plugin_versions) type_hint = new_x.get('type_hint') - if type_hint is not None: - type_hint_lineage = registry.get_type_hint_lineage(type_hint) - for th in type_hint_lineage: - plugin = registry.get_plugin(th) - old_version = plugin_versions[plugin] - curr_version = registry.get_plugin_version(plugin) - upgrader = registry.get_upgrader(th) - if upgrader: - for version in range(old_version, curr_version): - new_x = upgrader(new_x, version) + if type_hint is None: + return new_x + if type_hint in registry.renamed_type_hints: + type_hint = registry.renamed_type_hints[type_hint] + new_x['type_hint'] = type_hint + type_hint_lineage = registry.get_type_hint_lineage(type_hint) + for th in type_hint_lineage: + plugin = registry.get_plugin(th) + old_version = plugin_versions[plugin] + curr_version = registry.get_plugin_version(plugin) + upgrader = registry.get_upgrader(th) + if upgrader: + for version in range(old_version, curr_version): + new_x = upgrader(new_x, version) return new_x elif isinstance(x, list): return [_upgrade_config(v, plugin_versions) for v in x] diff --git a/rastervision_pipeline/rastervision/pipeline/registry.py b/rastervision_pipeline/rastervision/pipeline/registry.py index 5293e1c877..dcc65a0023 100644 --- a/rastervision_pipeline/rastervision/pipeline/registry.py +++ b/rastervision_pipeline/rastervision/pipeline/registry.py @@ -27,6 +27,7 @@ def __init__(self): self.type_hint_to_lineage = {} self.type_hint_to_plugin = {} self.type_hint_to_upgrader = {} + self.renamed_type_hints = {} def add_plugin_command(self, cmd: Command): """Add a click command contributed by a plugin.""" @@ -56,6 +57,16 @@ def set_plugin_version(self, plugin: str, version: int): """ self.plugin_versions[plugin] = version + def register_renamed_type_hints(self, type_hint_old: str, + type_hint_new: str): + """Register renamed type_hints. + + Args: + type_hint_old: Old type hint. + type_hint_new: New type hint. + """ + self.renamed_type_hints[type_hint_old] = type_hint_new + def get_type_hint_lineage(self, type_hint: str) -> List[str]: """Get the lineage for a type hint. diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/__init__.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/__init__.py index c7cd8456f5..1595ffe65a 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/__init__.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/__init__.py @@ -1,8 +1,13 @@ # flake8: noqa +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from rastervision.pipeline import Registry -def register_plugin(registry): - registry.set_plugin_version('rastervision.pytorch_learner', 5) + +def register_plugin(registry: 'Registry'): + registry.set_plugin_version('rastervision.pytorch_learner', 6) + registry.register_renamed_type_hints('geo_data_window', 'window_sampling') import rastervision.pipeline diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py index cc8f7c3bfa..5c32ba4941 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/learner_config.py @@ -330,6 +330,7 @@ def solver_config_upgrader(cfg_dict: dict, version: int) -> dict: # removed in version 5 cfg_dict.pop('test_batch_sz', None) cfg_dict.pop('test_num_epochs', None) + cfg_dict.pop('overfit_num_steps', None) return cfg_dict @@ -1176,7 +1177,13 @@ def unzip_data(self, zip_uris: List[str], unzip_dir: str) -> List[str]: return data_dirs -@register_config('geo_data') +def geo_data_config_upgrader(cfg_dict: dict, version: int) -> dict: + if version == 5: + cfg_dict['sampling'] = cfg_dict.pop('window_opts') + return cfg_dict + + +@register_config('geo_data', upgrader=geo_data_config_upgrader) class GeoDataConfig(DataConfig): """Configure :class:`GeoDatasets <.GeoDataset>`.