Skip to content

Commit

Permalink
Merge pull request #1978 from AdeelH/rvconfig-pred-opts
Browse files Browse the repository at this point in the history
Allow setting some `Learner`-related params via environment variables (or other Everett config options)
  • Loading branch information
AdeelH authored Nov 6, 2023
2 parents e56f93d + 8f9fa23 commit 90d8480
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 10 deletions.
26 changes: 22 additions & 4 deletions rastervision_pipeline/rastervision/pipeline/rv_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Any, Dict, List, Optional
import os
from tempfile import TemporaryDirectory
from pathlib import Path
import logging
import json
from typing import Optional, List, Dict

from everett.manager import (ConfigManager, ConfigDictEnv, ConfigOSEnv,
ConfigurationMissingError)
Expand Down Expand Up @@ -199,10 +199,27 @@ def set_everett_config(self,
'Switch to the version being run and search for Raster Vision '
'Configuration.'))

def get_namespace_config(self, namespace: str) -> Dict[str, str]:
def get_namespace_config(self, namespace: str) -> ConfigManager:
"""Get the key-val pairs associated with a namespace."""
return self.config.with_namespace(namespace)

def get_namespace_option(self,
namespace: str,
key: str,
default: Optional[Any] = None,
as_bool: bool = False) -> str:
"""Get the value of an option from a namespace."""
namespace_options = self.config.with_namespace(namespace)
try:
val: str = namespace_options(key)
if as_bool:
val = val.lower() in ('1', 'true', 'y', 'yes')
return val
except ConfigurationMissingError:
if as_bool:
return False
return default

def get_config_dict(
self, rv_config_schema: Dict[str, List[str]]) -> Dict[str, str]:
"""Get all Everett configuration.
Expand All @@ -221,8 +238,9 @@ def get_config_dict(
for namespace, keys in rv_config_schema.items():
for key in keys:
try:
config_dict[namespace + '_' + key] = \
self.get_namespace_config(namespace)(key)
namespace_options = self.get_namespace_config(namespace)
full_key = f'{namespace}_{key}'
config_dict[full_key] = namespace_options(key)
except ConfigurationMissingError:
pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Optional, Tuple, Union, Type)
from typing_extensions import Literal
from abc import ABC, abstractmethod
import os
from os.path import join, isfile, basename, isdir
import warnings
import time
Expand All @@ -22,6 +21,7 @@
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.file_system import (
sync_to_dir, json_to_file, file_to_json, make_dir, zipdir,
download_if_needed, download_or_copy, sync_from_dir, get_local_path, unzip,
Expand Down Expand Up @@ -50,7 +50,6 @@
TRANSFORMS_DIRNAME = 'custom_albumentations_transforms'
BUNDLE_MODEL_WEIGHTS_FILENAME = 'model.pth'
BUNDLE_MODEL_ONNX_FILENAME = 'model.onnx'
USE_ONNX = os.getenv('RASTERVISION_USE_ONNX', 'false').lower() in ('true', '1')

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -216,7 +215,7 @@ def from_model_bundle(cls: Type,
tmp_dir: Optional[str] = None,
cfg: Optional['LearnerConfig'] = None,
training: bool = False,
use_onnx_model: bool = USE_ONNX,
use_onnx_model: Optional[bool] = None,
**kwargs) -> 'Learner':
"""Create a Learner from a model bundle.
Expand All @@ -238,7 +237,7 @@ def from_model_bundle(cls: Type,
model will be put into eval mode. If True, the training
apparatus will be set up and the model will be put into
training mode. Defaults to True.
use_onnx_model (bool, optional): If True and training=False and a
use_onnx_model (Optional[bool]): If True and training=False and a
model.onnx file is available in the bundle, use that for
inference rather than the PyTorch weights. Defaults to the
boolean environment variable RASTERVISION_USE_ONNX if set,
Expand Down Expand Up @@ -302,6 +301,9 @@ def from_model_bundle(cls: Type,
# config has been altered, so re-validate
cfg = build_config(cfg.dict())

if use_onnx_model is None:
use_onnx_model = rv_config.get_namespace_option(
'rastervision', 'USE_ONNX', as_bool=True)
onnx_mode = False
if not training and use_onnx_model:
onnx_path = join(model_bundle_dir, 'model.onnx')
Expand Down Expand Up @@ -639,10 +641,19 @@ def predict_dataset(self,
if return_format not in {'xyz', 'yz', 'z'}:
raise ValueError('return_format must be one of "xyz", "yz", "z".')

cfg = self.cfg

num_workers = rv_config.get_namespace_option(
'rastervision',
'PREDICT_NUM_WORKERS',
default=cfg.data.num_workers)
batch_size = rv_config.get_namespace_option(
'rastervision', 'PREDICT_BATCH_SIZE', default=cfg.solver.batch_sz)

dl_kw = dict(
collate_fn=self.get_collate_fn(),
batch_size=self.cfg.solver.batch_sz,
num_workers=self.cfg.data.num_workers,
batch_size=int(batch_size),
num_workers=int(num_workers),
shuffle=False,
pin_memory=True)
dl_kw.update(dataloader_kw)
Expand Down

0 comments on commit 90d8480

Please sign in to comment.