Skip to content

Commit

Permalink
Add support for variable length dataloaders in DDP (#3416)
Browse files Browse the repository at this point in the history
* Add support for variable length dataloaders in dist training

* Remove test file

* Fix typo

* Fixed batch referenced before assignment

* Replace sentinel with None

* Add unit test

* Update unit test

* Reduce tensor creation to one line

Co-authored-by: Mihir Patel <[email protected]>

* Remove requirement for gpu in test

---------

Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
JAEarly and mvpatel2000 authored Jun 24, 2024
1 parent 6d7b90d commit 8b32fbc
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
13 changes: 13 additions & 0 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3640,6 +3640,11 @@ def _iter_dataloader(self, trainer_mode: TrainerMode):
else:
dataloader_iter = itertools.islice(self.state.dataloader, int(self.state.dataloader_len))

# Track if iteration has finished (used for distributed training when we have variable length dataloaders)
# 0 = not finished, 1 = finished (using integer tensors so we can use dist.all_reduce)
iter_finished = self.state.device.tensor_to_device(torch.zeros(1, dtype=torch.uint8))

batch = None
while True:
try:
# [BEFORE/AFTER]_DATALOADER only runs while training
Expand All @@ -3655,7 +3660,15 @@ def _iter_dataloader(self, trainer_mode: TrainerMode):
# Otherwise, we will encounter an error at the start of the next epoch when
# Event.BEFORE_DATALOADER tries to start an unfinished marker.
self.engine.run_marker_only_event(Event.AFTER_DATALOADER)
# Mark iteration as finished - don't break yet as we need to sync across ranks
iter_finished += 1

# Sync iter finished across ranks
dist.all_reduce(iter_finished, reduce_operation='MAX')
# If any rank has finished, stop all rank iterations
if iter_finished.item() == 1:
break

yield batch

def _use_closures(self) -> bool:
Expand Down
37 changes: 37 additions & 0 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,6 +1250,43 @@ def test_accumulate_time_across_ranks(
assert num_tokens_accum == num_tokens * 2
assert batch_time_accum == datetime.timedelta(seconds=0.1 * (1 + 0))

@pytest.mark.world_size(2)
def test_rank_dependent_dataloader_lengths(
self,
model: ComposerModel,
max_duration: Time[int],
):
# Change rank 1 dataloader size to create different sized dataloaders on each rank
batch_size = 4
orig_num_samples = 16
rank_num_samples = orig_num_samples + 8 if dist.get_local_rank() == 1 else orig_num_samples
# Create train and eval dataloaders (will have rank-dependent lengths)
train_dataset = RandomClassificationDataset(size=rank_num_samples)
train_dataloader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
sampler=dist.get_sampler(train_dataset),
)
eval_dataset = RandomClassificationDataset(size=rank_num_samples)
eval_dataloader = DataLoader(
dataset=eval_dataset,
batch_size=batch_size,
sampler=dist.get_sampler(eval_dataset),
)
# Fit (train + eval)
trainer = Trainer(
model=model,
max_duration=max_duration,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)
trainer.fit()
# Check the correct number of samples and batches have been processed
assert trainer.state.timestamp.sample.value == orig_num_samples
assert trainer.state.timestamp.batch.value == orig_num_samples / batch_size / 2
assert trainer.state.eval_timestamp.sample.value == orig_num_samples
assert trainer.state.eval_timestamp.batch.value == orig_num_samples / batch_size / 2


@world_size(1, 2)
@device('cpu', 'gpu', 'gpu-amp', precision=True)
Expand Down

0 comments on commit 8b32fbc

Please sign in to comment.