diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 4447698beb..91dd0b1e19 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -3640,6 +3640,11 @@ def _iter_dataloader(self, trainer_mode: TrainerMode): else: dataloader_iter = itertools.islice(self.state.dataloader, int(self.state.dataloader_len)) + # Track if iteration has finished (used for distributed training when we have variable length dataloaders) + # 0 = not finished, 1 = finished (using integer tensors so we can use dist.all_reduce) + iter_finished = self.state.device.tensor_to_device(torch.zeros(1, dtype=torch.uint8)) + + batch = None while True: try: # [BEFORE/AFTER]_DATALOADER only runs while training @@ -3655,7 +3660,15 @@ def _iter_dataloader(self, trainer_mode: TrainerMode): # Otherwise, we will encounter an error at the start of the next epoch when # Event.BEFORE_DATALOADER tries to start an unfinished marker. self.engine.run_marker_only_event(Event.AFTER_DATALOADER) + # Mark iteration as finished - don't break yet as we need to sync across ranks + iter_finished += 1 + + # Sync iter finished across ranks + dist.all_reduce(iter_finished, reduce_operation='MAX') + # If any rank has finished, stop all rank iterations + if iter_finished.item() == 1: break + yield batch def _use_closures(self) -> bool: diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 59e8b26782..1bb5d265b6 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1250,6 +1250,43 @@ def test_accumulate_time_across_ranks( assert num_tokens_accum == num_tokens * 2 assert batch_time_accum == datetime.timedelta(seconds=0.1 * (1 + 0)) + @pytest.mark.world_size(2) + def test_rank_dependent_dataloader_lengths( + self, + model: ComposerModel, + max_duration: Time[int], + ): + # Change rank 1 dataloader size to create different sized dataloaders on each rank + batch_size = 4 + orig_num_samples = 16 + rank_num_samples = orig_num_samples + 8 if dist.get_local_rank() == 1 else orig_num_samples + # Create train and eval dataloaders (will have rank-dependent lengths) + train_dataset = RandomClassificationDataset(size=rank_num_samples) + train_dataloader = DataLoader( + dataset=train_dataset, + batch_size=batch_size, + sampler=dist.get_sampler(train_dataset), + ) + eval_dataset = RandomClassificationDataset(size=rank_num_samples) + eval_dataloader = DataLoader( + dataset=eval_dataset, + batch_size=batch_size, + sampler=dist.get_sampler(eval_dataset), + ) + # Fit (train + eval) + trainer = Trainer( + model=model, + max_duration=max_duration, + train_dataloader=train_dataloader, + eval_dataloader=eval_dataloader, + ) + trainer.fit() + # Check the correct number of samples and batches have been processed + assert trainer.state.timestamp.sample.value == orig_num_samples + assert trainer.state.timestamp.batch.value == orig_num_samples / batch_size / 2 + assert trainer.state.eval_timestamp.sample.value == orig_num_samples + assert trainer.state.eval_timestamp.batch.value == orig_num_samples / batch_size / 2 + @world_size(1, 2) @device('cpu', 'gpu', 'gpu-amp', precision=True)