Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
AdeelH committed Jan 31, 2024
1 parent 7af1198 commit 498fb1a
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 28 deletions.
2 changes: 1 addition & 1 deletion rastervision_core/rastervision/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


def register_plugin(registry):
registry.set_plugin_version('rastervision.core', 10)
registry.set_plugin_version('rastervision.core', 11)
from rastervision.core.cli import predict, predict_scene
registry.add_plugin_command(predict)
registry.add_plugin_command(predict_scene)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from rastervision.core.backend import BackendConfig
from rastervision.core.evaluation import EvaluatorConfig
from rastervision.core.analyzer import AnalyzerConfig
from rastervision.core.rv_pipeline.chip_options import ChipOptions
from rastervision.core.rv_pipeline.chip_options import (ChipOptions,
WindowSamplingConfig)
from rastervision.pipeline.config import (Config, Field, register_config)

if TYPE_CHECKING:
Expand All @@ -21,7 +22,21 @@ class PredictOptions(Config):
pass


@register_config('rv_pipeline')
def rv_pipeline_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version == 10:
train_chip_sz = cfg_dict.pop('train_chip_sz', 300)
nodata_threshold = cfg_dict.pop('chip_nodata_threshold')
if 'chip_options' not in cfg_dict:
cfg_dict['chip_options'] = ChipOptions(
sampling=WindowSamplingConfig(size=train_chip_sz),
nodata_threshold=nodata_threshold)
else:
cfg_dict['chip_options']['sampling']['size'] = train_chip_sz
cfg_dict['chip_options']['nodata_threshold'] = nodata_threshold
return cfg_dict


@register_config('rv_pipeline', upgrader=rv_pipeline_config_upgrader)
class RVPipelineConfig(PipelineConfig):
"""Configure an :class:`.RVPipeline`."""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,28 @@

from rastervision.pipeline.config import (ConfigError, register_config, Field,
validator)
from rastervision.core.rv_pipeline.rv_pipeline_config import (
ChipOptions, PredictOptions, RVPipelineConfig)
from rastervision.core.rv_pipeline.rv_pipeline_config import (PredictOptions,
RVPipelineConfig)
from rastervision.core.rv_pipeline.chip_options import (ChipOptions,
WindowSamplingConfig)
from rastervision.core.data import SemanticSegmentationLabelStoreConfig
from rastervision.core.evaluation import SemanticSegmentationEvaluatorConfig


def ss_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version < 1:
try:
# removed in version 1
del cfg_dict['channel_display_groups']
del cfg_dict['img_format']
del cfg_dict['label_format']
except KeyError:
pass
def ss_chip_options_upgrader(cfg_dict: dict, version: int) -> dict:
if version == 10:
sampling = WindowSamplingConfig(
method=cfg_dict.pop('window_method', None),
size=300,
stride=cfg_dict.pop('stride', None),
max_windows=cfg_dict.pop('chips_per_scene', None),
)
cfg_dict['sampling'] = sampling.dict()
return cfg_dict


@register_config('semantic_segmentation_chip_options')
@register_config(
'semantic_segmentation_chip_options', upgrader=ss_chip_options_upgrader)
class SemanticSegmentationChipOptions(ChipOptions):
"""Chipping options for semantic segmentation."""
target_class_ids: Optional[List[int]] = Field(
Expand Down Expand Up @@ -99,6 +102,18 @@ def validate_crop_sz(cls,
return crop_sz


def ss_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version == 0:
try:
# removed in version 1
del cfg_dict['channel_display_groups']
del cfg_dict['img_format']
del cfg_dict['label_format']
except KeyError:
pass
return cfg_dict


@register_config('semantic_segmentation', upgrader=ss_config_upgrader)
class SemanticSegmentationConfig(RVPipelineConfig):
"""Configure a :class:`.SemanticSegmentation` pipeline."""
Expand Down
24 changes: 14 additions & 10 deletions rastervision_pipeline/rastervision/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,20 @@ def _upgrade_config(x: Union[dict, List[dict]], plugin_versions: Dict[str, int]
for k, v in x.items():
new_x[k] = _upgrade_config(v, plugin_versions)
type_hint = new_x.get('type_hint')
if type_hint is not None:
type_hint_lineage = registry.get_type_hint_lineage(type_hint)
for th in type_hint_lineage:
plugin = registry.get_plugin(th)
old_version = plugin_versions[plugin]
curr_version = registry.get_plugin_version(plugin)
upgrader = registry.get_upgrader(th)
if upgrader:
for version in range(old_version, curr_version):
new_x = upgrader(new_x, version)
if type_hint is None:
return new_x
if type_hint in registry.renamed_type_hints:
type_hint = registry.renamed_type_hints[type_hint]
new_x['type_hint'] = type_hint
type_hint_lineage = registry.get_type_hint_lineage(type_hint)
for th in type_hint_lineage:
plugin = registry.get_plugin(th)
old_version = plugin_versions[plugin]
curr_version = registry.get_plugin_version(plugin)
upgrader = registry.get_upgrader(th)
if upgrader:
for version in range(old_version, curr_version):
new_x = upgrader(new_x, version)
return new_x
elif isinstance(x, list):
return [_upgrade_config(v, plugin_versions) for v in x]
Expand Down
11 changes: 11 additions & 0 deletions rastervision_pipeline/rastervision/pipeline/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self):
self.type_hint_to_lineage = {}
self.type_hint_to_plugin = {}
self.type_hint_to_upgrader = {}
self.renamed_type_hints = {}

def add_plugin_command(self, cmd: Command):
"""Add a click command contributed by a plugin."""
Expand Down Expand Up @@ -56,6 +57,16 @@ def set_plugin_version(self, plugin: str, version: int):
"""
self.plugin_versions[plugin] = version

def register_renamed_type_hints(self, type_hint_old: str,
type_hint_new: str):
"""Register renamed type_hints.
Args:
type_hint_old: Old type hint.
type_hint_new: New type hint.
"""
self.renamed_type_hints[type_hint_old] = type_hint_new

def get_type_hint_lineage(self, type_hint: str) -> List[str]:
"""Get the lineage for a type hint.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# flake8: noqa
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from rastervision.pipeline import Registry

def register_plugin(registry):
registry.set_plugin_version('rastervision.pytorch_learner', 5)

def register_plugin(registry: 'Registry'):
registry.set_plugin_version('rastervision.pytorch_learner', 6)
registry.register_renamed_type_hints('geo_data_window', 'window_sampling')


import rastervision.pipeline
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def solver_config_upgrader(cfg_dict: dict, version: int) -> dict:
# removed in version 5
cfg_dict.pop('test_batch_sz', None)
cfg_dict.pop('test_num_epochs', None)
cfg_dict.pop('overfit_num_steps', None)
return cfg_dict


Expand Down Expand Up @@ -1176,7 +1177,13 @@ def unzip_data(self, zip_uris: List[str], unzip_dir: str) -> List[str]:
return data_dirs


@register_config('geo_data')
def geo_data_config_upgrader(cfg_dict: dict, version: int) -> dict:
if version == 5:
cfg_dict['sampling'] = cfg_dict.pop('window_opts')
return cfg_dict


@register_config('geo_data', upgrader=geo_data_config_upgrader)
class GeoDataConfig(DataConfig):
"""Configure :class:`GeoDatasets <.GeoDataset>`.
Expand Down

0 comments on commit 498fb1a

Please sign in to comment.