Skip to content

Commit

Permalink
Merge pull request #1979 from AdeelH/pred_scene_cmd
Browse files Browse the repository at this point in the history
Add a `predict_scene` CLI command that makes predictions on a scene specified by a `SceneConfig`
  • Loading branch information
AdeelH authored Nov 6, 2023
2 parents 90d8480 + 07afe90 commit a7d56de
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 59 deletions.
75 changes: 53 additions & 22 deletions docs/framework/cli.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ It has a main command, with some top level options, and several subcommands.
> rastervision --help
Usage: rastervision [OPTIONS] COMMAND [ARGS]...
Usage: python -m rastervision.pipeline.cli [OPTIONS] COMMAND [ARGS]...
The main click command.
Expand All @@ -26,17 +26,18 @@ It has a main command, with some top level options, and several subcommands.
--help Show this message and exit.
Commands:
predict Use a model bundle to predict on new images.
run Run sequence of commands within pipeline(s).
run_command Run an individual command within a pipeline.
predict Use a model bundle to predict on new images.
predict_scene Use a model bundle to predict on a new scene.
run Run sequence of commands within pipeline(s).
run_command Run an individual command within a pipeline.
Subcommands
------------

.. _run cli command:

run
^^^
``run``
^^^^^^^

Run is the main interface into running pipelines.

Expand Down Expand Up @@ -69,15 +70,15 @@ Some specific parameters to call out:

.. _split cli option:

-\\-splits
~~~~~~~~~~
``--splits``
~~~~~~~~~~~~

Use ``-s N`` or ``--splits N``, where ``N`` is the number of splits to create, to parallelize commands that can be split into parallelizable chunks. See :ref:`parallelizing commands` for more information.

.. _run_command cli command:

run_command
^^^^^^^^^^^
``run_command``
^^^^^^^^^^^^^^^

The ``run_command`` is used to run a specific command from a serialized ``PipelineConfig`` JSON file.
This is likely only interesting to people writing :ref:`custom runners <runners>`.
Expand All @@ -86,7 +87,8 @@ This is likely only interesting to people writing :ref:`custom runners <runners>
> rastervision run_command --help
Usage: rastervision run_command [OPTIONS] CFG_JSON_URI COMMAND
Usage: python -m rastervision.pipeline.cli run_command [OPTIONS] CFG_JSON_URI
COMMAND
Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI.
Expand All @@ -99,24 +101,53 @@ This is likely only interesting to people writing :ref:`custom runners <runners>
.. _predict cli command:

predict
^^^^^^^
``predict``
^^^^^^^^^^^

Use ``predict`` to make predictions on new imagery given a :ref:`model bundle <model bundle>`.

.. code-block:: console
> rastervision predict --help
Usage: rastervision predict [OPTIONS] MODEL_BUNDLE IMAGE_URI LABEL_URI
Usage: python -m rastervision.pipeline.cli predict [OPTIONS] MODEL_BUNDLE
IMAGE_URI LABEL_URI
Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store
the prediction output at LABEL_URI.
Make predictions on the images at IMAGE_URI using MODEL_BUNDLE and store the
prediction output at LABEL_URI.
Options:
-a, --update-stats Run an analysis on this individual image, as
opposed to using any analysis like statistics that
exist in the prediction package
--channel-order TEXT List of indices comprising channel_order. Example:
2 1 0
--help Show this message and exit.
-a, --update-stats Run an analysis on this individual image, as opposed
to using any analysis like statistics that exist in
the prediction package
--channel-order LIST List of indices comprising channel_order. Example: 2 1
0
--scene-group TEXT Name of the scene group whose stats will be used by
the StatsTransformer. Requires the stats for this
scene group to be present inside the bundle.
--help Show this message and exit.
``predict_scene``
^^^^^^^^^^^^^^^^^

Similar to ``predict`` but allows greater control by allowing the user to specify a full :class:`.SceneConfig` and :class:`.PredictOptions`.

.. code-block:: console
> rastervision predict_scene --help
Usage: python -m rastervision.pipeline.cli predict_scene [OPTIONS]
MODEL_BUNDLE_URI
SCENE_CONFIG_URI
Use a model-bundle to make predictions on a scene.
MODEL_BUNDLE_URI URI to a serialized Raster Vision model-bundle.
SCENE_CONFIG_URI URI to a serialized Raster Vision SceneConfig.
Options:
--predict_options_uri TEXT Optional URI to serialized Raster Vision
PredictOptions config.
--help Show this message and exit.
3 changes: 2 additions & 1 deletion rastervision_core/rastervision/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

def register_plugin(registry):
registry.set_plugin_version('rastervision.core', 10)
from rastervision.core.cli import predict
from rastervision.core.cli import predict, predict_scene
registry.add_plugin_command(predict)
registry.add_plugin_command(predict_scene)


import rastervision.pipeline
Expand Down
10 changes: 7 additions & 3 deletions rastervision_core/rastervision/core/backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import TYPE_CHECKING, Optional
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from contextlib import AbstractContextManager

if TYPE_CHECKING:
Expand Down Expand Up @@ -42,8 +42,12 @@ def train(self):
pass

@abstractmethod
def load_model(self):
"""Load the model in preparation for one or more prediction calls."""
def load_model(self, uri: Optional[str] = None):
"""Load the model in preparation for one or more prediction calls.
Args:
uri: Optional URI to load the model from.
"""
pass

@abstractmethod
Expand Down
25 changes: 24 additions & 1 deletion rastervision_core/rastervision/core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import click

from rastervision.pipeline.file_system import get_tmp_dir
from rastervision.core.predictor import Predictor
from rastervision.core.predictor import Predictor, ScenePredictor


# https://stackoverflow.com/questions/48391777/nargs-equivalent-for-options-in-click
Expand Down Expand Up @@ -80,3 +80,26 @@ def predict(model_bundle: str,
predictor = Predictor(model_bundle, tmp_dir, update_stats,
channel_order, scene_group)
predictor.predict([image_uri], label_uri)


@click.command(
'predict_scene',
short_help='Use a model bundle to predict on a new scene.')
@click.argument('model_bundle_uri')
@click.argument('scene_config_uri')
@click.option(
'--predict_options_uri',
type=str,
default=None,
help='Optional URI to serialized Raster Vision PredictOptions config.')
def predict_scene(model_bundle_uri: str,
scene_config_uri: str,
predict_options_uri: Optional[str] = None):
"""Use a model-bundle to make predictions on a scene.
\b
MODEL_BUNDLE_URI URI to a serialized Raster Vision model-bundle.
SCENE_CONFIG_URI URI to a serialized Raster Vision SceneConfig.
"""
predictor = ScenePredictor(model_bundle_uri, predict_options_uri)
predictor.predict(scene_config_uri)
74 changes: 65 additions & 9 deletions rastervision_core/rastervision/core/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.config import (build_config, upgrade_config)
from rastervision.pipeline.file_system.utils import (download_if_needed,
file_to_json, unzip)
from rastervision.pipeline.file_system.utils import (
download_if_needed, file_to_json, get_tmp_dir, unzip)
from rastervision.core.data.raster_source import ChannelOrderError
from rastervision.core.data import (SemanticSegmentationLabelStoreConfig,
PolygonVectorOutputConfig,
StatsTransformerConfig)
from rastervision.core.analyzer import StatsAnalyzerConfig

if TYPE_CHECKING:
from rastervision.core.rv_pipeline import RVPipelineConfig
from rastervision.core.rv_pipeline import RVPipeline, RVPipelineConfig
from rastervision.core.data import SceneConfig

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,10 +46,10 @@ def __init__(self,
self.model_loaded = False

bundle_path = download_if_needed(model_bundle_uri)
bundle_dir = join(tmp_dir, 'bundle')
unzip(bundle_path, bundle_dir)
self.bundle_dir = join(tmp_dir, 'bundle')
unzip(bundle_path, self.bundle_dir)

config_path = join(bundle_dir, 'pipeline-config.json')
config_path = join(self.bundle_dir, 'pipeline-config.json')
config_dict = file_to_json(config_path)
rv_config.set_everett_config(
config_overrides=config_dict.get('rv_config'))
Expand All @@ -74,11 +74,11 @@ def __init__(self,
f'Using stats for scene group "{t.scene_group}". '
'To use a different scene group, specify '
'--scene-group <scene-group-name>.')
t.update_root(bundle_dir)
t.update_root(self.bundle_dir)

if self.update_stats:
stats_analyzer = StatsAnalyzerConfig(
output_uri=join(bundle_dir, 'stats.json'))
output_uri=join(self.bundle_dir, 'stats.json'))
self.config.analyzers = [stats_analyzer]

self.scene.label_source = None
Expand All @@ -87,7 +87,7 @@ def __init__(self,
self.config.dataset.train_scenes = [self.scene]
self.config.dataset.validation_scenes = [self.scene]
self.config.dataset.test_scenes = []
self.config.train_uri = bundle_dir
self.config.train_uri = self.bundle_dir

if channel_order is not None:
self.scene.raster_source.channel_order = channel_order
Expand All @@ -109,6 +109,8 @@ def predict(self, image_uris: List[str], label_uri: str) -> None:
if not hasattr(self.pipeline, 'predict'):
raise Exception(
'pipeline in model bundle must have predict method')
self.pipeline.build_backend(
join(self.bundle_dir, 'model-bundle.zip'))

self.scene.raster_source.uris = image_uris
self.scene.label_store.uri = label_uri
Expand All @@ -131,3 +133,57 @@ def predict(self, image_uris: List[str], label_uri: str) -> None:
'with channels unavailable in the imagery.\nTo set a new '
'channel_order that only uses channels available in the '
'imagery, use the --channel-order option.')


class ScenePredictor:
"""Class for making predictions on a scen using a model-bundle."""

def __init__(self,
model_bundle_uri: str,
predict_options_uri: Optional[str] = None,
tmp_dir: Optional[str] = None):
"""Creates a new Predictor.
Args:
model_bundle_uri: URI of the model bundle to use. Can be any
type of URI that Raster Vision can read.
tmp_dir: Temporary directory in which to store files that are used
by the Predictor.
"""
self.tmp_dir = tmp_dir
if self.tmp_dir is None:
self._tmp_dir = get_tmp_dir()
self.tmp_dir = self._tmp_dir.name

bundle_path = download_if_needed(model_bundle_uri)
bundle_dir = join(self.tmp_dir, 'bundle')
unzip(bundle_path, bundle_dir)

pipeline_config_path = join(bundle_dir, 'pipeline-config.json')
pipeline_config_dict = file_to_json(pipeline_config_path)

if predict_options_uri is not None:
pred_opts_config_dict = file_to_json(predict_options_uri)
pipeline_config_dict['predict_options'] = pred_opts_config_dict

rv_config.set_everett_config(
config_overrides=pipeline_config_dict.get('rv_config'))
pipeline_config_dict = upgrade_config(pipeline_config_dict)
self.pipeline_config: 'RVPipelineConfig' = build_config(
pipeline_config_dict)

self.pipeline: 'RVPipeline' = self.pipeline_config.build(self.tmp_dir)
self.pipeline.build_backend(join(bundle_dir, 'model-bundle.zip'))

def predict(self, scene_config_uri: str) -> None:
"""Generate predictions for the given image.
Args:
scene_config_uri: URI to a serialized :class:`.ScenConfig`.
"""
scene_config_dict = file_to_json(scene_config_uri)
scene_config: 'SceneConfig' = build_config(scene_config_dict)
class_config = self.pipeline_config.dataset.class_config
scene = scene_config.build(class_config, self.tmp_dir)
labels = self.pipeline.predict_scene(scene)
scene.label_store.save(labels)
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from rastervision.core.data.label import ObjectDetectionLabels

if TYPE_CHECKING:
from rastervision.core.backend.backend import Backend
from rastervision.core.data import (Labels, Scene, RasterSource,
ObjectDetectionLabelSource)
from rastervision.core.rv_pipeline.object_detection_config import (
Expand Down Expand Up @@ -176,13 +175,19 @@ def get_train_labels(self, window: Box,
ioa_thresh=self.config.chip_options.ioa_thresh,
clip=True)

def predict_scene(self, scene: 'Scene', backend: 'Backend') -> 'Labels':
def predict_scene(self, scene: 'Scene') -> 'Labels':
if self.backend is None:
self.build_backend()

# Use strided windowing to ensure that each object is fully visible (ie. not
# cut off) within some window. This means prediction takes 4x longer for object
# detection :(
chip_sz = self.config.predict_chip_sz
stride = chip_sz // 2
return backend.predict_scene(scene, chip_sz=chip_sz, stride=stride)
labels = self.backend.predict_scene(
scene, chip_sz=chip_sz, stride=stride)
labels = self.post_process_predictions(labels, scene)
return labels

def post_process_predictions(self, labels: ObjectDetectionLabels,
scene: 'Scene') -> ObjectDetectionLabels:
Expand Down
21 changes: 13 additions & 8 deletions rastervision_core/rastervision/core/rv_pipeline/rv_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,25 +177,26 @@ def predict(self, split_ind=0, num_splits=1):
This uses a sliding window.
"""
# Cache backend so subsequent calls will be faster. This is useful for
# the predictor.
if self.backend is None:
self.backend = self.config.backend.build(self.config, self.tmp_dir)
self.backend.load_model()
self.build_backend()

class_config = self.config.dataset.class_config
dataset = self.config.dataset.get_split_config(split_ind, num_splits)

for scene_config in (dataset.validation_scenes + dataset.test_scenes):
scene = scene_config.build(class_config, self.tmp_dir)
labels = self.predict_scene(scene, self.backend)
labels = self.post_process_predictions(labels, scene)
labels = self.predict_scene(scene)
scene.label_store.save(labels)

def predict_scene(self, scene: Scene, backend: Backend) -> Labels:
def predict_scene(self, scene: Scene) -> Labels:
if self.backend is None:
self.build_backend()
chip_sz = self.config.predict_chip_sz
stride = chip_sz
return backend.predict_scene(scene, chip_sz=chip_sz, stride=stride)
labels = self.backend.predict_scene(
scene, chip_sz=chip_sz, stride=stride)
labels = self.post_process_predictions(labels, scene)
return labels

def eval(self):
"""Evaluate predictions against ground truth."""
Expand Down Expand Up @@ -262,3 +263,7 @@ def bundle(self):
model_bundle_path = get_local_path(model_bundle_uri, self.tmp_dir)
zipdir(bundle_dir, model_bundle_path)
upload_or_copy(model_bundle_path, model_bundle_uri)

def build_backend(self, uri: Optional[str] = None) -> None:
self.backend = self.config.backend.build(self.config, self.tmp_dir)
self.backend.load_model(uri)
Loading

0 comments on commit a7d56de

Please sign in to comment.