diff --git a/docs/framework/cli.rst b/docs/framework/cli.rst index f2c2a36f4..5029805ab 100644 --- a/docs/framework/cli.rst +++ b/docs/framework/cli.rst @@ -13,7 +13,7 @@ It has a main command, with some top level options, and several subcommands. > rastervision --help - Usage: rastervision [OPTIONS] COMMAND [ARGS]... + Usage: python -m rastervision.pipeline.cli [OPTIONS] COMMAND [ARGS]... The main click command. @@ -26,17 +26,18 @@ It has a main command, with some top level options, and several subcommands. --help Show this message and exit. Commands: - predict Use a model bundle to predict on new images. - run Run sequence of commands within pipeline(s). - run_command Run an individual command within a pipeline. + predict Use a model bundle to predict on new images. + predict_scene Use a model bundle to predict on a new scene. + run Run sequence of commands within pipeline(s). + run_command Run an individual command within a pipeline. Subcommands ------------ .. _run cli command: -run -^^^ +``run`` +^^^^^^^ Run is the main interface into running pipelines. @@ -69,15 +70,15 @@ Some specific parameters to call out: .. _split cli option: --\\-splits -~~~~~~~~~~ +``--splits`` +~~~~~~~~~~~~ Use ``-s N`` or ``--splits N``, where ``N`` is the number of splits to create, to parallelize commands that can be split into parallelizable chunks. See :ref:`parallelizing commands` for more information. .. _run_command cli command: -run_command -^^^^^^^^^^^ +``run_command`` +^^^^^^^^^^^^^^^ The ``run_command`` is used to run a specific command from a serialized ``PipelineConfig`` JSON file. This is likely only interesting to people writing :ref:`custom runners `. @@ -86,7 +87,8 @@ This is likely only interesting to people writing :ref:`custom runners > rastervision run_command --help - Usage: rastervision run_command [OPTIONS] CFG_JSON_URI COMMAND + Usage: python -m rastervision.pipeline.cli run_command [OPTIONS] CFG_JSON_URI + COMMAND Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI. @@ -99,8 +101,8 @@ This is likely only interesting to people writing :ref:`custom runners .. _predict cli command: -predict -^^^^^^^ +``predict`` +^^^^^^^^^^^ Use ``predict`` to make predictions on new imagery given a :ref:`model bundle `. @@ -108,15 +110,44 @@ Use ``predict`` to make predictions on new imagery given a :ref:`model bundle rastervision predict --help - Usage: rastervision predict [OPTIONS] MODEL_BUNDLE IMAGE_URI LABEL_URI + Usage: python -m rastervision.pipeline.cli predict [OPTIONS] MODEL_BUNDLE + IMAGE_URI LABEL_URI - Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store - the prediction output at LABEL_URI. + Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store the + prediction output at LABEL_URI. Options: - -a, --update-stats Run an analysis on this individual image, as - opposed to using any analysis like statistics that - exist in the prediction package - --channel-order TEXT List of indices comprising channel_order. Example: - 2 1 0 - --help Show this message and exit. + -a, --update-stats Run an analysis on this individual image, as opposed + to using any analysis like statistics that exist in + the prediction package + --channel-order LIST List of indices comprising channel_order. Example: 2 1 + 0 + --scene-group TEXT Name of the scene group whose stats will be used by + the StatsTransformer. Requires the stats for this + scene group to be present inside the bundle. + --help Show this message and exit. + + +``predict_scene`` +^^^^^^^^^^^^^^^^^ + +Similar to ``predict`` but allows greater control by allowing the user to specify a full :class:`.SceneConfig` and :class:`.PredictOptions`. + +.. code-block:: console + + > rastervision predict_scene --help + + Usage: python -m rastervision.pipeline.cli predict_scene [OPTIONS] + MODEL_BUNDLE_URI + SCENE_CONFIG_URI + + Use a model-bundle to make predictions on a scene. + + MODEL_BUNDLE_URI URI to a serialized Raster Vision model-bundle. + SCENE_CONFIG_URI URI to a serialized Raster Vision SceneConfig. + + Options: + --predict_options_uri TEXT Optional URI to serialized Raster Vision + PredictOptions config. + --help Show this message and exit. + diff --git a/rastervision_core/rastervision/core/__init__.py b/rastervision_core/rastervision/core/__init__.py index f80bdeb18..e8506dbb0 100644 --- a/rastervision_core/rastervision/core/__init__.py +++ b/rastervision_core/rastervision/core/__init__.py @@ -3,8 +3,9 @@ def register_plugin(registry): registry.set_plugin_version('rastervision.core', 10) - from rastervision.core.cli import predict + from rastervision.core.cli import predict, predict_scene registry.add_plugin_command(predict) + registry.add_plugin_command(predict_scene) import rastervision.pipeline diff --git a/rastervision_core/rastervision/core/backend/backend.py b/rastervision_core/rastervision/core/backend/backend.py index 2052aa3d0..73cf750ff 100644 --- a/rastervision_core/rastervision/core/backend/backend.py +++ b/rastervision_core/rastervision/core/backend/backend.py @@ -1,5 +1,5 @@ +from typing import TYPE_CHECKING, Optional from abc import ABC, abstractmethod -from typing import TYPE_CHECKING from contextlib import AbstractContextManager if TYPE_CHECKING: @@ -42,8 +42,12 @@ def train(self): pass @abstractmethod - def load_model(self): - """Load the model in preparation for one or more prediction calls.""" + def load_model(self, uri: Optional[str] = None): + """Load the model in preparation for one or more prediction calls. + + Args: + uri: Optional URI to load the model from. + """ pass @abstractmethod diff --git a/rastervision_core/rastervision/core/cli.py b/rastervision_core/rastervision/core/cli.py index b744947b7..b75b8e2df 100644 --- a/rastervision_core/rastervision/core/cli.py +++ b/rastervision_core/rastervision/core/cli.py @@ -2,7 +2,7 @@ import click from rastervision.pipeline.file_system import get_tmp_dir -from rastervision.core.predictor import Predictor +from rastervision.core.predictor import Predictor, ScenePredictor # https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click @@ -80,3 +80,26 @@ def predict(model_bundle: str, predictor = Predictor(model_bundle, tmp_dir, update_stats, channel_order, scene_group) predictor.predict([image_uri], label_uri) + + +@click.command( + 'predict_scene', + short_help='Use a model bundle to predict on a new scene.') +@click.argument('model_bundle_uri') +@click.argument('scene_config_uri') +@click.option( + '--predict_options_uri', + type=str, + default=None, + help='Optional URI to serialized Raster Vision PredictOptions config.') +def predict_scene(model_bundle_uri: str, + scene_config_uri: str, + predict_options_uri: Optional[str] = None): + """Use a model-bundle to make predictions on a scene. + + \b + MODEL_BUNDLE_URI URI to a serialized Raster Vision model-bundle. + SCENE_CONFIG_URI URI to a serialized Raster Vision SceneConfig. + """ + predictor = ScenePredictor(model_bundle_uri, predict_options_uri) + predictor.predict(scene_config_uri) diff --git a/rastervision_core/rastervision/core/predictor.py b/rastervision_core/rastervision/core/predictor.py index 34efedf6d..b433de510 100644 --- a/rastervision_core/rastervision/core/predictor.py +++ b/rastervision_core/rastervision/core/predictor.py @@ -4,8 +4,8 @@ from rastervision.pipeline import rv_config_ as rv_config from rastervision.pipeline.config import (build_config, upgrade_config) -from rastervision.pipeline.file_system.utils import (download_if_needed, - file_to_json, unzip) +from rastervision.pipeline.file_system.utils import ( + download_if_needed, file_to_json, get_tmp_dir, unzip) from rastervision.core.data.raster_source import ChannelOrderError from rastervision.core.data import (SemanticSegmentationLabelStoreConfig, PolygonVectorOutputConfig, @@ -13,7 +13,7 @@ from rastervision.core.analyzer import StatsAnalyzerConfig if TYPE_CHECKING: - from rastervision.core.rv_pipeline import RVPipelineConfig + from rastervision.core.rv_pipeline import RVPipeline, RVPipelineConfig from rastervision.core.data import SceneConfig log = logging.getLogger(__name__) @@ -46,10 +46,10 @@ def __init__(self, self.model_loaded = False bundle_path = download_if_needed(model_bundle_uri) - bundle_dir = join(tmp_dir, 'bundle') - unzip(bundle_path, bundle_dir) + self.bundle_dir = join(tmp_dir, 'bundle') + unzip(bundle_path, self.bundle_dir) - config_path = join(bundle_dir, 'pipeline-config.json') + config_path = join(self.bundle_dir, 'pipeline-config.json') config_dict = file_to_json(config_path) rv_config.set_everett_config( config_overrides=config_dict.get('rv_config')) @@ -74,11 +74,11 @@ def __init__(self, f'Using stats for scene group "{t.scene_group}". ' 'To use a different scene group, specify ' '--scene-group .') - t.update_root(bundle_dir) + t.update_root(self.bundle_dir) if self.update_stats: stats_analyzer = StatsAnalyzerConfig( - output_uri=join(bundle_dir, 'stats.json')) + output_uri=join(self.bundle_dir, 'stats.json')) self.config.analyzers = [stats_analyzer] self.scene.label_source = None @@ -87,7 +87,7 @@ def __init__(self, self.config.dataset.train_scenes = [self.scene] self.config.dataset.validation_scenes = [self.scene] self.config.dataset.test_scenes = [] - self.config.train_uri = bundle_dir + self.config.train_uri = self.bundle_dir if channel_order is not None: self.scene.raster_source.channel_order = channel_order @@ -109,6 +109,8 @@ def predict(self, image_uris: List[str], label_uri: str) -> None: if not hasattr(self.pipeline, 'predict'): raise Exception( 'pipeline in model bundle must have predict method') + self.pipeline.build_backend( + join(self.bundle_dir, 'model-bundle.zip')) self.scene.raster_source.uris = image_uris self.scene.label_store.uri = label_uri @@ -131,3 +133,57 @@ def predict(self, image_uris: List[str], label_uri: str) -> None: 'with channels unavailable in the imagery.\nTo set a new ' 'channel_order that only uses channels available in the ' 'imagery, use the --channel-order option.') + + +class ScenePredictor: + """Class for making predictions on a scen using a model-bundle.""" + + def __init__(self, + model_bundle_uri: str, + predict_options_uri: Optional[str] = None, + tmp_dir: Optional[str] = None): + """Creates a new Predictor. + + Args: + model_bundle_uri: URI of the model bundle to use. Can be any + type of URI that Raster Vision can read. + tmp_dir: Temporary directory in which to store files that are used + by the Predictor. + """ + self.tmp_dir = tmp_dir + if self.tmp_dir is None: + self._tmp_dir = get_tmp_dir() + self.tmp_dir = self._tmp_dir.name + + bundle_path = download_if_needed(model_bundle_uri) + bundle_dir = join(self.tmp_dir, 'bundle') + unzip(bundle_path, bundle_dir) + + pipeline_config_path = join(bundle_dir, 'pipeline-config.json') + pipeline_config_dict = file_to_json(pipeline_config_path) + + if predict_options_uri is not None: + pred_opts_config_dict = file_to_json(predict_options_uri) + pipeline_config_dict['predict_options'] = pred_opts_config_dict + + rv_config.set_everett_config( + config_overrides=pipeline_config_dict.get('rv_config')) + pipeline_config_dict = upgrade_config(pipeline_config_dict) + self.pipeline_config: 'RVPipelineConfig' = build_config( + pipeline_config_dict) + + self.pipeline: 'RVPipeline' = self.pipeline_config.build(self.tmp_dir) + self.pipeline.build_backend(join(bundle_dir, 'model-bundle.zip')) + + def predict(self, scene_config_uri: str) -> None: + """Generate predictions for the given image. + + Args: + scene_config_uri: URI to a serialized :class:`.ScenConfig`. + """ + scene_config_dict = file_to_json(scene_config_uri) + scene_config: 'SceneConfig' = build_config(scene_config_dict) + class_config = self.pipeline_config.dataset.class_config + scene = scene_config.build(class_config, self.tmp_dir) + labels = self.pipeline.predict_scene(scene) + scene.label_store.save(labels) diff --git a/rastervision_core/rastervision/core/rv_pipeline/object_detection.py b/rastervision_core/rastervision/core/rv_pipeline/object_detection.py index 2e92b1c16..14c528304 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/object_detection.py +++ b/rastervision_core/rastervision/core/rv_pipeline/object_detection.py @@ -9,7 +9,6 @@ from rastervision.core.data.label import ObjectDetectionLabels if TYPE_CHECKING: - from rastervision.core.backend.backend import Backend from rastervision.core.data import (Labels, Scene, RasterSource, ObjectDetectionLabelSource) from rastervision.core.rv_pipeline.object_detection_config import ( @@ -176,13 +175,19 @@ def get_train_labels(self, window: Box, ioa_thresh=self.config.chip_options.ioa_thresh, clip=True) - def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels': + def predict_scene(self, scene: 'Scene') -> 'Labels': + if self.backend is None: + self.build_backend() + # Use strided windowing to ensure that each object is fully visible (ie. not # cut off) within some window. This means prediction takes 4x longer for object # detection :( chip_sz = self.config.predict_chip_sz stride = chip_sz // 2 - return backend.predict_scene(scene, chip_sz=chip_sz, stride=stride) + labels = self.backend.predict_scene( + scene, chip_sz=chip_sz, stride=stride) + labels = self.post_process_predictions(labels, scene) + return labels def post_process_predictions(self, labels: ObjectDetectionLabels, scene: 'Scene') -> ObjectDetectionLabels: diff --git a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py index 2b05c4818..3719c900d 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py +++ b/rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py @@ -177,25 +177,26 @@ def predict(self, split_ind=0, num_splits=1): This uses a sliding window. """ - # Cache backend so subsequent calls will be faster. This is useful for - # the predictor. if self.backend is None: - self.backend = self.config.backend.build(self.config, self.tmp_dir) - self.backend.load_model() + self.build_backend() class_config = self.config.dataset.class_config dataset = self.config.dataset.get_split_config(split_ind, num_splits) for scene_config in (dataset.validation_scenes + dataset.test_scenes): scene = scene_config.build(class_config, self.tmp_dir) - labels = self.predict_scene(scene, self.backend) - labels = self.post_process_predictions(labels, scene) + labels = self.predict_scene(scene) scene.label_store.save(labels) - def predict_scene(self, scene: Scene, backend: Backend) -> Labels: + def predict_scene(self, scene: Scene) -> Labels: + if self.backend is None: + self.build_backend() chip_sz = self.config.predict_chip_sz stride = chip_sz - return backend.predict_scene(scene, chip_sz=chip_sz, stride=stride) + labels = self.backend.predict_scene( + scene, chip_sz=chip_sz, stride=stride) + labels = self.post_process_predictions(labels, scene) + return labels def eval(self): """Evaluate predictions against ground truth.""" @@ -262,3 +263,7 @@ def bundle(self): model_bundle_path = get_local_path(model_bundle_uri, self.tmp_dir) zipdir(bundle_dir, model_bundle_path) upload_or_copy(model_bundle_path, model_bundle_uri) + + def build_backend(self, uri: Optional[str] = None) -> None: + self.backend = self.config.backend.build(self.config, self.tmp_dir) + self.backend.load_model(uri) diff --git a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py index 11dd24b4c..2ba5b21b3 100644 --- a/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py +++ b/rastervision_core/rastervision/core/rv_pipeline/semantic_segmentation.py @@ -10,7 +10,6 @@ SemanticSegmentationWindowMethod) if TYPE_CHECKING: - from rastervision.core.backend.backend import Backend from rastervision.core.data import ( ClassConfig, Labels, @@ -144,7 +143,10 @@ def post_process_batch(self, windows, chips, labels): return labels - def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels': + def predict_scene(self, scene: 'Scene') -> 'Labels': + if self.backend is None: + self.build_backend() + cfg: 'SemanticSegmentationConfig' = self.config chip_sz = cfg.predict_chip_sz stride = cfg.predict_options.stride @@ -162,5 +164,7 @@ def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels': 'still overlap after cropping.') crop_sz = overlap_sz // 2 - return backend.predict_scene( + labels = self.backend.predict_scene( scene, chip_sz=chip_sz, stride=stride, crop_sz=crop_sz) + labels = self.post_process_predictions(labels, scene) + return labels diff --git a/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py b/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py index 42e2986bf..5d5298466 100644 --- a/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py +++ b/rastervision_pytorch_backend/rastervision/pytorch_backend/pytorch_learner_backend.py @@ -119,13 +119,14 @@ def train(self, source_bundle_uri=None): learner = self.learner_cfg.build(self.tmp_dir, training=True) learner.main() - def load_model(self): - self.learner = self._build_learner_from_bundle(training=False) + def load_model(self, uri: Optional[str] = None): + self.learner = self._build_learner_from_bundle( + bundle_uri=uri, training=False) def _build_learner_from_bundle(self, - bundle_uri=None, - cfg=None, - training=False): + bundle_uri: Optional[str] = None, + cfg: Optional['LearnerConfig'] = None, + training: bool = False): if bundle_uri is None: bundle_uri = self.learner_cfg.get_model_bundle_uri() return Learner.from_model_bundle( diff --git a/tests/pytorch_backend/examples/test_tiny_spacenet.py b/tests/pytorch_backend/examples/test_tiny_spacenet.py index 5345bfe03..66767f543 100644 --- a/tests/pytorch_backend/examples/test_tiny_spacenet.py +++ b/tests/pytorch_backend/examples/test_tiny_spacenet.py @@ -1,11 +1,15 @@ import unittest +from os.path import join import shutil from click.testing import CliRunner from rastervision.pipeline.cli import main -from rastervision.pipeline.file_system.utils import get_tmp_dir -from rastervision.core.cli import predict +from rastervision.pipeline.file_system.utils import get_tmp_dir, json_to_file +from rastervision.core.cli import predict, predict_scene +from rastervision.core.data import (RasterioSourceConfig, SceneConfig, + SemanticSegmentationLabelStoreConfig) +from rastervision.core.rv_pipeline import SemanticSegmentationPredictOptions from tests import data_file_path @@ -18,7 +22,8 @@ def test_rastervision_run_tiny_spacenet(self): 'run', 'inprocess', 'rastervision.pytorch_backend.examples.tiny_spacenet' ]) - self.assertEqual(result.exit_code, 0) + if result.exit_code != 0: + raise result.exception # test predict command bundle_path = '/opt/data/output/tiny_spacenet/bundle/model-bundle.zip' @@ -28,7 +33,29 @@ def test_rastervision_run_tiny_spacenet(self): bundle_path, img_path, tmp_dir, '--channel-order', '0', '1', '2' ]) - self.assertEqual(result.exit_code, 0) + if result.exit_code != 0: + raise result.exception + + # test predict_scene command + bundle_path = '/opt/data/output/tiny_spacenet/bundle/model-bundle.zip' + img_path = data_file_path('small-rgb-tile.tif') + with get_tmp_dir() as tmp_dir: + pred_uri = join(tmp_dir, 'pred') + rs_cfg = RasterioSourceConfig(uris=img_path) + ls_cfg = SemanticSegmentationLabelStoreConfig(uri=pred_uri) + scene_cfg = SceneConfig( + id='', raster_source=rs_cfg, label_store=ls_cfg) + pred_opts = SemanticSegmentationPredictOptions() + scene_config_uri = join(pred_uri, 'scene-config.json') + json_to_file(scene_cfg.dict(), scene_config_uri) + pred_opts_uri = join(pred_uri, 'predict-options.json') + json_to_file(pred_opts.dict(), pred_opts_uri) + result = runner.invoke(predict_scene, [ + bundle_path, scene_config_uri, '--predict_options_uri', + pred_opts_uri + ]) + if result.exit_code != 0: + raise result.exception if __name__ == '__main__':