use pytorch estimator for training jobs
AdeelH committed Jan 24, 2024
1 parent fe37b51 commit 1752019
Showing 4 changed files with 233 additions and 46 deletions.
28 changes: 24 additions & 4 deletions docs/setup/aws.rst
@@ -80,14 +80,24 @@ Add the following to your ``~/.rastervision/default`` file.
cpu_instance_type=ml.p3.2xlarge
gpu_image=123.dkr.ecr.us-east-1.amazonaws.com/raster-vision
gpu_instance_type=ml.p3.2xlarge
use_spot_instances=yes
train_image=123.dkr.ecr.us-east-1.amazonaws.com/raster-vision
train_instance_type=ml.p3.8xlarge
train_instance_count=2
use_spot_instances=no
spot_instance_max_wait_time=86400
max_run_time=86400
* ``role`` - AWS IAM role with appropriate SageMaker permissions.
* ``cpu_image`` - Docker image URI for CPU jobs.
* ``cpu_instance_type`` - Instance type for CPU jobs.
* ``gpu_image`` - Docker image URI for GPU jobs.
* ``gpu_instance_type`` - Instance type for GPU jobs.
* ``use_spot_instances`` - Whether to use spot instances.
* ``train_image`` - Docker image URI for training jobs. Defaults to ``gpu_image``.
* ``train_instance_type`` - Instance type for training jobs. Defaults to ``gpu_instance_type``.
* ``train_instance_count`` - Number of instances to run in parallel for training jobs (values greater than 1 enable distributed training). Defaults to 1.
* ``use_spot_instances`` - Whether to use spot instances. Only applies to training jobs.
* ``spot_instance_max_wait_time`` - Maximum time, in seconds, to wait for a spot training job to complete, including any time spent waiting for spot capacity. Must be greater than or equal to ``max_run_time``. Default: ``max_run_time``.
* ``max_run_time`` - Maximum job run time in seconds. Default: 86400 (24 hours).
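These settings are read at runtime from the ``sagemaker`` namespace of the Raster Vision config, as the runner changes later in this diff show. A minimal sketch of the lookup, mirroring this commit's defaults:

    from rastervision.pipeline import rv_config_ as rv_config

    # Read the [sagemaker] settings; defaults mirror those in this commit.
    config = rv_config.get_namespace_config('sagemaker')
    gpu_image = config('gpu_image')
    train_image = config('train_image', default=gpu_image)
    train_instance_count = int(config('train_instance_count', default='1'))
    use_spot_instances = config('use_spot_instances').lower() == 'yes'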


Environment variables
@@ -102,14 +112,24 @@ Alternatively, you can set the following environment variables:
SAGEMAKER_CPU_INSTANCE_TYPE="ml.p3.2xlarge"
SAGEMAKER_GPU_IMAGE="123.dkr.ecr.us-east-1.amazonaws.com/raster-vision"
SAGEMAKER_GPU_INSTANCE_TYPE="ml.p3.2xlarge"
SAGEMAKER_USE_SPOT_INSTANCES="yes"
SAGEMAKER_TRAIN_IMAGE="123.dkr.ecr.us-east-1.amazonaws.com/raster-vision"
SAGEMAKER_TRAIN_INSTANCE_TYPE="ml.p3.8xlarge"
SAGEMAKER_TRAIN_INSTANCE_COUNT="2"
SAGEMAKER_USE_SPOT_INSTANCES="no"
SPOT_INSTANCE_MAX_WAIT_TIME="86400"
MAX_RUN_TIME="86400"
* ``SAGEMAKER_ROLE`` - AWS IAM role with appropriate SageMaker permissions.
* ``SAGEMAKER_CPU_IMAGE`` - Docker image URI for CPU jobs.
* ``SAGEMAKER_CPU_INSTANCE_TYPE`` - Instance type for CPU jobs.
* ``SAGEMAKER_GPU_IMAGE`` - Docker image URI for GPU jobs.
* ``SAGEMAKER_GPU_INSTANCE_TYPE`` - Instance type for GPU jobs.
* ``SAGEMAKER_USE_SPOT_INSTANCES`` - Whether to use spot instances.
* ``SAGEMAKER_TRAIN_IMAGE`` - Docker image URI for training jobs. Defaults to ``SAGEMAKER_GPU_IMAGE``.
* ``SAGEMAKER_TRAIN_INSTANCE_TYPE`` - Instance type for training jobs. Defaults to ``SAGEMAKER_GPU_INSTANCE_TYPE``.
* ``SAGEMAKER_TRAIN_INSTANCE_COUNT`` - Number of instances to run in parallel for training jobs (values greater than 1 enable distributed training). Defaults to 1.
* ``SAGEMAKER_USE_SPOT_INSTANCES`` - Whether to use spot instances. Only applies to training jobs.
* ``SPOT_INSTANCE_MAX_WAIT_TIME`` - Maximum time, in seconds, to wait for a spot training job to complete, including any time spent waiting for spot capacity. Must be greater than or equal to ``MAX_RUN_TIME``. Default: ``MAX_RUN_TIME``.
* ``MAX_RUN_TIME`` - Maximum job run time in seconds. Default: 86400 (24 hours).


.. seealso::
@@ -12,7 +12,12 @@ def register_plugin(registry):
'cpu_instance_type',
'gpu_image',
'gpu_instance_type',
'train_image',
'train_instance_type',
'train_instance_count',
'use_spot_instances',
'spot_instance_max_wait_time',
'max_run_time',
])
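(Presumably these keys are passed to ``registry.add_rv_config_schema(AWS_SAGEMAKER, [...])``; the enclosing call is collapsed in this diff. Registering them is what lets ``rv_config`` accept these settings in the ``sagemaker`` namespace.)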


@@ -1,13 +1,19 @@
from typing import TYPE_CHECKING, List, Optional, Union
from os.path import join, basename
import logging
from pprint import pprint
import tarfile

import boto3
from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.runner import Runner
from rastervision.pipeline.file_system import FileSystem
from rastervision.pipeline.file_system.utils import (str_to_file, get_tmp_dir,
upload_or_copy)

if TYPE_CHECKING:
from rastervision.pipeline.pipeline import Pipeline
from rastervision.core.rv_pipeline import RVPipeline, RVPipelineConfig
from sagemaker.workflow.pipeline_context import _JobStepArguments
from sagemaker import Session
from sagemaker.workflow.pipeline import Pipeline as SageMakerPipeline
@@ -18,6 +24,24 @@

AWS_SAGEMAKER = 'sagemaker'

DEFAULT_MAX_RUN_TIME = 24 * 60 * 60

PYTORCH_ESTIMATOR_SCRIPT_FILENAME = 'train.py'
PYTORCH_ESTIMATOR_TAR_FILENAME = 'train.tar.gz'
PYTORCH_ESTIMATOR_SCRIPT_TEMPLATE = """\
import os
from rastervision.pipeline import rv_config_ as rv_config
from rastervision.pipeline.cli import _run_command
if __name__ == '__main__':
    print('WORLD_SIZE', os.environ.get('WORLD_SIZE'))
    print('RANK', os.environ.get('RANK'))
    print('LOCAL_RANK', os.environ.get('LOCAL_RANK'))
    rv_config.set_tmp_dir_root('/opt/data/tmp/rv')
    _run_command('{cfg_json_uri}', '{rv_cmd}')
"""


class AWSSageMakerRunner(Runner):
"""Runs pipelines remotely using AWS SageMaker.
@@ -32,7 +56,12 @@ class AWSSageMakerRunner(Runner):
cpu_instance_type=
gpu_image=
gpu_instance_type=
train_image=
train_instance_type=
train_instance_count=
use_spot_instances=
spot_instance_max_wait_time=
max_run_time=
"""

def run(self,
@@ -75,20 +104,47 @@ def build_pipeline(self,
"""Build a SageMaker Pipeline with each command as a step within it."""
from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.pipeline import Pipeline as SageMakerPipeline
from sagemaker.workflow.pipeline_definition_config import (
PipelineDefinitionConfig)

config = rv_config.get_namespace_config(AWS_SAGEMAKER)
role = config('role')
cpu_image = config('cpu_image')
cpu_instance_type = config('cpu_instance_type')

gpu_image = config('gpu_image')
gpu_instance_type = config('gpu_instance_type')

train_image = config('train_image', default=gpu_image)
train_instance_type = config(
'train_instance_type', default=gpu_instance_type)
train_instance_count = int(config('train_instance_count', default='1'))

use_spot_instances = config('use_spot_instances').lower() == 'yes'
spot_instance_max_wait_time = int(
config(
'spot_instance_max_wait_time',
default=str(DEFAULT_MAX_RUN_TIME)))
max_run_time = int(
config('max_run_time', default=str(DEFAULT_MAX_RUN_TIME)))
sagemaker_session = PipelineSession()

steps = []

for command in commands:
use_gpu = command in pipeline.gpu_commands
if command.lower() == 'train':
use_gpu = True
instance_type = train_instance_type
instance_count = train_instance_count
image_uri = train_image
else:
use_gpu = command in pipeline.gpu_commands
image_uri = gpu_image if use_gpu else cpu_image
instance_type = (gpu_instance_type
if use_gpu else cpu_instance_type)
instance_count = 1
use_spot_instances = False

job_name = f'{pipeline_run_name}-{command}'

cmd = cmd_prefix[:]
@@ -105,91 +161,114 @@
# If the step can be split, then split it into parts
# that do not depend on each other (can run in
# parallel).
_steps = []
step_splits = [None] * num_splits
for i in range(num_splits):
cmd += [
split_cmd = cmd + [
'--split-ind',
str(i), '--num-splits',
str(num_splits)
]
step = self.build_step(
step_name=f'{job_name}_{i+1}of{num_splits}',
cmd=cmd,
split_job_name = f'{job_name}_{i+1}of{num_splits}'
step_split = self.build_step(
pipeline,
step_name=command,
job_name=split_job_name,
cmd=split_cmd,
role=role,
image_uri=gpu_image if use_gpu else cpu_image,
instance_type=(gpu_instance_type
if use_gpu else cpu_instance_type),
image_uri=image_uri,
instance_type=instance_type,
use_spot_instances=use_spot_instances,
sagemaker_session=sagemaker_session,
use_gpu=use_gpu)
step.add_depends_on(steps)
_steps.append(step)
steps.extend(_steps)
instance_count=instance_count,
max_wait=spot_instance_max_wait_time,
max_run=max_run_time,
)
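# Each split depends on all previously added steps, but not on its
# sibling splits, so the splits can run in parallel.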
step_split.add_depends_on(steps)
step_splits[i] = step_split
steps.extend(step_splits)
else:
# If the step cannot be split, then submit it as-is.
step = self.build_step(
step_name=job_name,
pipeline,
step_name=command,
job_name=job_name,
cmd=cmd,
role=role,
image_uri=gpu_image if use_gpu else cpu_image,
instance_type=(gpu_instance_type
if use_gpu else cpu_instance_type),
image_uri=image_uri,
instance_type=instance_type,
use_spot_instances=use_spot_instances,
sagemaker_session=sagemaker_session,
use_gpu=use_gpu)
instance_count=instance_count,
max_wait=spot_instance_max_wait_time,
max_run=max_run_time,
)
step.add_depends_on(steps)
steps.append(step)

# Submit the pipeline to SageMaker
pipeline_definition_config = PipelineDefinitionConfig(
use_custom_job_prefix=True)
sagemaker_pipeline = SageMakerPipeline(
name=pipeline_run_name,
steps=steps,
sagemaker_session=sagemaker_session,
)
pipeline_definition_config=pipeline_definition_config)
return sagemaker_pipeline
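For context, a pipeline built this way is submitted with the standard SageMaker SDK calls; a minimal sketch, assuming the submission flow lives in ``run`` (collapsed in this diff):

    sagemaker_pipeline.upsert(role_arn=role)  # create or update the definition
    execution = sagemaker_pipeline.start()    # launch an execution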

def build_step(self, step_name: str, cmd: List[str], role: str,
image_uri: str, instance_type: str,
def build_step(self,
pipeline: 'RVPipeline',
step_name: str,
job_name: str,
cmd: List[str],
role: str,
image_uri: str,
instance_type: str,
use_spot_instances: bool,
sagemaker_session: 'PipelineSession',
use_gpu: bool) -> Union['TrainingStep', 'ProcessingStep']:
instance_count: int = 1,
max_wait: int = DEFAULT_MAX_RUN_TIME,
max_run: int = DEFAULT_MAX_RUN_TIME,
**kwargs) -> Union['TrainingStep', 'ProcessingStep']:
"""Build an TrainingStep if use_gpu=True, otherwise a ProcessingStep.
"""
if use_gpu:
# For GPU-enabled steps, create an "Estimator". Strictly speaking,
# an Estimator is meant for training rather than prediction, but
# using it for GPU prediction too is expedient (especially given
# default service quotas).
from sagemaker.estimator import Estimator
if not use_spot_instances:
max_wait = None
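# (SageMaker only accepts max_wait when use_spot_instances=True, so
# it must be left unset for on-demand jobs.)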

if step_name.lower() == 'train':
from sagemaker.workflow.steps import TrainingStep
step_estimator = Estimator(
container_entry_point=cmd,

estimator = self._build_pytorch_estimator(
pipeline_cfg=pipeline.config,
role=role,
image_uri=image_uri,
instance_count=1,
instance_type=instance_type,
max_retry_attempts=1,
role=role,
use_spot_instances=use_spot_instances,
sagemaker_session=sagemaker_session,
use_spot=use_spot_instances,
instance_count=instance_count,
job_name=job_name,
max_wait=max_wait,
max_run=max_run,
**kwargs,
)
step_args: Optional['_JobStepArguments'] = step_estimator.fit(
step_args: Optional['_JobStepArguments'] = estimator.fit(
wait=False)
step = TrainingStep(step_name, step_args=step_args)
step = TrainingStep(job_name, step_args=step_args)
else:
# For non-GPU-enabled steps, create a Processor.
from sagemaker.processing import Processor
from sagemaker.workflow.steps import ProcessingStep

step_processor = Processor(
role=role,
image_uri=image_uri,
instance_count=1,
instance_type=instance_type,
sagemaker_session=sagemaker_session,
entrypoint=cmd,
**kwargs,
)
step_args: Optional['_JobStepArguments'] = step_processor.run(
wait=False)
step = ProcessingStep(step_name, step_args=step_args)
step = ProcessingStep(job_name, step_args=step_args)

return step

@@ -246,3 +325,65 @@ def run_command(self,
base_job_name=job_name,
)
processor.run()

    def _build_pytorch_estimator(self,
                                 pipeline_cfg: 'RVPipelineConfig',
                                 role: str,
                                 image_uri: str,
                                 instance_type: str,
                                 sagemaker_session: 'PipelineSession',
                                 use_spot_instances: bool = False,
                                 instance_count: int = 1,
                                 distribution: Optional[dict] = None,
                                 job_name: Optional[str] = None,
                                 **kwargs):
        from sagemaker.pytorch import PyTorch
        from rastervision.aws_s3.s3_file_system import S3FileSystem

        if distribution is None:
            distribution = dict(torch_distributed=dict(enabled=True))

        train_uri = pipeline_cfg.train_uri
        if FileSystem.get_file_system(train_uri) != S3FileSystem:
            raise ValueError('Pipeline\'s train_uri must be an S3 URI.')

        with get_tmp_dir() as source_dir:
            # create script from template
            script_path = join(source_dir, PYTORCH_ESTIMATOR_SCRIPT_FILENAME)
            _write_train_script(
                script_path, cfg_json_uri=pipeline_cfg.get_config_uri())
            # tar and upload to S3
            tar_path = _tar_script(script_path, source_dir)
            tar_path_s3 = join(train_uri, PYTORCH_ESTIMATOR_TAR_FILENAME)
            upload_or_copy(tar_path, tar_path_s3)

        estimator = PyTorch(
            entry_point=PYTORCH_ESTIMATOR_SCRIPT_FILENAME,
            source_dir=tar_path_s3,
            image_uri=image_uri,
            distribution=distribution,
            instance_count=instance_count,
            instance_type=instance_type,
            role=role,
            sagemaker_session=sagemaker_session,
            base_job_name=job_name,
            use_spot_instances=use_spot_instances,
            **kwargs,
        )
        return estimator
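Note: with ``torch_distributed`` enabled, SageMaker runs the ``entry_point`` script under ``torchrun``, which is what populates the ``WORLD_SIZE``/``RANK``/``LOCAL_RANK`` variables printed by the script template above.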


def _write_train_script(script_path: str, cfg_json_uri: str):
    script_str = PYTORCH_ESTIMATOR_SCRIPT_TEMPLATE.format(
        cfg_json_uri=cfg_json_uri, rv_cmd='train')
    log.debug(script_path)
    log.debug(script_str)
    str_to_file(script_str, script_path)
    return script_path


def _tar_script(script_path: str, tar_dir: str):
    tar_path = join(tar_dir, PYTORCH_ESTIMATOR_TAR_FILENAME)
    with tarfile.open(tar_path, 'w:gz') as tar:
        tar.add(script_path, arcname=basename(script_path))
    return tar_path
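SageMaker downloads and unpacks the ``source_dir`` tarball on each training instance and resolves ``entry_point`` relative to the unpacked root, hence the ``arcname=basename(script_path)`` above.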
