
Commit

Move train_data_spec to state
b-chu committed Mar 21, 2024
1 parent cf031e2 commit afa3742
Showing 2 changed files with 26 additions and 20 deletions.
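
In short: this commit moves the `_train_data_spec` attribute off the `Trainer` and onto the shared `State` object, so every read and write goes through `self.state._train_data_spec`. A minimal sketch of the pattern, using simplified stand-in classes rather than Composer's real Trainer and State:

# A minimal sketch of the refactor, assuming simplified stand-in classes.
from typing import Any, Optional

class State:
    """Shared training state; after this commit it also owns the train data spec."""

    def __init__(self) -> None:
        self._train_data_spec: Optional[Any] = None  # populated by Trainer.__init__() or fit()

class Trainer:
    def __init__(self, train_dataloader: Optional[Any] = None) -> None:
        self.state = State()
        # Before: the spec lived on the Trainer as self._train_data_spec.
        # After: it lives on State, so anything holding a State reference
        # (callbacks, algorithms, checkpointing) can reach it.
        if train_dataloader is not None:
            self.state._train_data_spec = train_dataloader  # stand-in for ensure_data_spec(...)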
3 changes: 2 additions & 1 deletion composer/core/state.py
@@ -437,6 +437,7 @@ def __init__(
         self.save_metrics = save_metrics

         self._train_dataloader = train_dataloader
+        self._train_data_spec = None
         self._evaluators = list(ensure_tuple(evaluators))

         self.previous_timestamp: Optional[Timestamp] = None
@@ -666,7 +667,7 @@ def stop_training(self):
         logging, and evaluation for that batch, as well as any epoch end events.
         """
         # Set the max_duration to the current time in its unit, except if the unit is TimeUnit.EPOCH. This is because TimeUnit.EPOCH is a very crude way to measure max duration. For example, it will result in division by zero error while computing get_elapsed_duration: https://github.com/mosaicml/composer/blob/1b9c6d3c0592183b947fd89890de0832366e33a7/composer/core/state.py#L641
-        if self.max_duration is not None and Time.from_input(self.max_duration,).unit != TimeUnit.EPOCH:
+        if self.max_duration is not None and Time.from_input(self.max_duration).unit != TimeUnit.EPOCH:
             max_duration_unit = Time.from_input(self.max_duration).unit
             self.max_duration = self.timestamp.get(max_duration_unit)
         else:
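The comment in the hunk above explains why stop_training refuses to snap max_duration to TimeUnit.EPOCH. A hedged illustration of the hazard, assuming (as the linked state.py code suggests) that elapsed duration is computed by dividing the current timestamp value by the max duration value; this is a sketch, not Composer's actual get_elapsed_duration:

# Illustration only: why snapping max_duration to EPOCH can divide by zero.
def get_elapsed_duration(current_value: int, max_duration_value: int) -> float:
    # Composer computes a fraction-of-training-complete along these lines.
    return current_value / max_duration_value  # ZeroDivisionError when max_duration_value == 0

# If training stops partway through the first epoch, timestamp.get(TimeUnit.EPOCH)
# is still 0, so max_duration would become "0ep" and the division above would fail.
# Finer-grained units (batches, samples, tokens) are already nonzero mid-epoch.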
43 changes: 24 additions & 19 deletions composer/trainer/trainer.py
@@ -1471,10 +1471,10 @@ def __init__(
         self.state.evaluators = evaluators

         # Train Dataloader
-        self._train_data_spec = None if train_dataloader is None else ensure_data_spec(train_dataloader)
-        if self._train_data_spec is not None:
+        self.state._train_data_spec = None if train_dataloader is None else ensure_data_spec(train_dataloader)
+        if self.state._train_data_spec is not None:
             self.state.set_dataloader(
-                self._train_data_spec.dataloader,
+                self.state._train_data_spec.dataloader,
                 train_dataloader_label,
                 train_subset_num_batches,
             )
@@ -2031,15 +2031,15 @@ def fit(

         # Train Dataloader
         if train_dataloader is not None:
-            self._train_data_spec = ensure_data_spec(train_dataloader)
-            self.state.set_dataloader(self._train_data_spec.dataloader, train_dataloader_label)
+            self.state._train_data_spec = ensure_data_spec(train_dataloader)
+            self.state.set_dataloader(self.state._train_data_spec.dataloader, train_dataloader_label)
             self.state.train_dataloader = self.state.dataloader
             self.state.device_train_microbatch_size = _get_initial_device_train_microbatch_size(
                 self.state.device_train_microbatch_size,
                 self.state.auto_microbatching,
                 self.state.train_dataloader,
             )
-        if self._train_data_spec is None:
+        if self.state._train_data_spec is None:
             _raise_missing_argument_exception('train_dataloader')
         if train_subset_num_batches is not None:
             self.state.dataloader_len = train_subset_num_batches
@@ -2156,8 +2156,8 @@ def fit(
                 device_train_microbatch_size,
                 device=self.state.device,
             )
-        if self.state.auto_microbatching and self._train_data_spec is not None and hasattr(
-            self._train_data_spec,
+        if self.state.auto_microbatching and self.state._train_data_spec is not None and hasattr(
+            self.state._train_data_spec,
             'seq_parallel_world_size',
         ):
             raise ValueError('`device_train_microbatch_size="auto"` is not compatible with sequence parallelism.')
@@ -2323,7 +2323,7 @@ def _train_loop(self) -> None:
                 'enabled_algorithms/' + algo.__class__.__name__: True for algo in self.state.algorithms
             })
         assert self.state.dataloader is not None, 'dataloader is set in __init__() or fit()'
-        assert self._train_data_spec is not None, 'The train data spec is set in __init__() or fit()'
+        assert self.state._train_data_spec is not None, 'The train data spec is set in __init__() or fit()'
         assert self.state.scaler is not None, 'scaler should have been set in __init__()'

         self.engine.run_event(Event.FIT_START)
@@ -2372,9 +2372,9 @@ def _train_loop(self) -> None:
                     continue

                 self.state.batch = self.state.device.batch_to_device(self.state.batch)
-                self.state.batch = self._train_data_spec.device_transforms(self.state.batch)
-                rank_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
-                rank_num_tokens = self._train_data_spec.get_num_tokens_in_batch(self.state.batch)
+                self.state.batch = self.state._train_data_spec.device_transforms(self.state.batch)
+                rank_num_samples = self.state._train_data_spec.get_num_samples_in_batch(self.state.batch)
+                rank_num_tokens = self.state._train_data_spec.get_num_tokens_in_batch(self.state.batch)

                 if self.state.deepspeed_enabled:
                     self.state.batch = _fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)
@@ -2514,7 +2514,7 @@ def _train_loop(self) -> None:
         self._run_evaluators(Event.FIT_END)

     def _eval_train_metrics(self, device_batch):
-        assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
+        assert self.state._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
         assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'

         with torch.no_grad(),\
@@ -2560,7 +2560,7 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
         Returns:
             Dict[str, torch.Tensor]: a dictionary containing the total loss and individual losses if available.
         """
-        assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
+        assert self.state._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'

         # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop.
         # Any in-place changes to a microbatch will be reflected in the device batch.
@@ -2581,7 +2581,10 @@ def _train_batch(self, use_grad_scaling: bool) -> Dict[str, torch.Tensor]:
             try:
                 assert self.state.scaler is not None
                 assert self.state.device_train_microbatch_size is not None
-                microbatches = self._train_data_spec.split_batch(device_batch, self.state.device_train_microbatch_size)
+                microbatches = self.state._train_data_spec.split_batch(
+                    device_batch,
+                    self.state.device_train_microbatch_size,
+                )
                 if self._use_closures():
                     for optimizer in self.state.optimizers:
                         if use_grad_scaling:
@@ -2676,7 +2679,7 @@ def _train_microbatches(
         else:
             context = cast(Callable[[], ContextManager], self.state.model.no_sync)

-        assert self._train_data_spec is not None
+        assert self.state._train_data_spec is not None

         with context():
             self.engine.run_event(Event.BEFORE_TRAIN_BATCH)
@@ -2694,7 +2697,9 @@ def _train_microbatches(
                     optimizer.zero_grad()

             # Tracker for gradient accumulation
-            current_batch_size = sum([self._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches])
+            current_batch_size = sum([
+                self.state._train_data_spec.get_num_samples_in_batch(batch) for batch in microbatches
+            ])
             # Cache batch, which will be overwritten by microbatches. Restore after microbatches complete
             current_batch = self.state.batch
@@ -2736,12 +2741,12 @@ def _train_microbatch(
             is_final_microbatch (bool): If current microbatch is the last one.
         """
         assert self.state.scaler is not None
-        assert self._train_data_spec is not None
+        assert self.state._train_data_spec is not None

         # Cache the device batch, because `self.state.batch` gets overridden in microbatching loop
         device_batch = deepcopy(self.state.batch)

-        microbatch_num_samples = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
+        microbatch_num_samples = self.state._train_data_spec.get_num_samples_in_batch(self.state.batch)
         if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
             sync_context = contextlib.nullcontext()
         elif self.state.auto_microbatching and not self.first_batch_complete:
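Most of the trainer.py hunks are mechanical renames, but every call site funnels through the same data-spec surface: split_batch carves the device batch into microbatches, and get_num_samples_in_batch drives the gradient-accumulation bookkeeping. A rough sketch of that flow with a hypothetical minimal spec (Composer's real DataSpec is richer):

# SimpleDataSpec is hypothetical; it only mirrors the two methods used above.
from typing import List
import torch

class SimpleDataSpec:
    def split_batch(self, batch: torch.Tensor, microbatch_size: int) -> List[torch.Tensor]:
        # Carve a device batch into microbatches along the sample dimension.
        return list(batch.split(microbatch_size, dim=0))

    def get_num_samples_in_batch(self, batch: torch.Tensor) -> int:
        return batch.shape[0]

spec = SimpleDataSpec()
device_batch = torch.randn(8, 4)
microbatches = spec.split_batch(device_batch, microbatch_size=2)
# Tracker for gradient accumulation, as in _train_microbatches:
current_batch_size = sum(spec.get_num_samples_in_batch(b) for b in microbatches)
assert current_batch_size == 8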
