Raising a better warning if train or eval did not process any data. (#…
ethantang-db authored Oct 17, 2024
1 parent 2972a2a commit 32caf5b
Showing 4 changed files with 62 additions and 5 deletions.
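The change tracks a per-rank flag that flips to True once a batch has been trained (or evaluated), and emits a UserWarning if the loop finishes with the flag still False. A minimal sketch of the pattern outside of Composer (`train_loop` and `train_one_batch` are illustrative names, not the actual trainer API):

```python
import warnings
from typing import Callable, Iterable

def train_loop(batches: Iterable, train_one_batch: Callable, global_rank: int = 0) -> None:
    # The flag starts False and flips after the first successfully trained batch.
    first_train_batch_complete = False
    for batch in batches:
        train_one_batch(batch)
        first_train_batch_complete = True
    # If the dataloader yields nothing (empty dataset, uneven sharding, bad sampler),
    # warn per rank instead of finishing silently.
    if not first_train_batch_complete:
        warnings.warn(
            f'No batches were trained for global rank {global_rank}. This may be due to '
            'an issue with the train dataset, dataloader, or sampler.',
        )

# An empty iterable triggers the warning; a non-empty one does not.
train_loop([], train_one_batch=lambda batch: None)
```

The same flag-and-warn shape is applied to `_eval_loop` below via `first_eval_batch_complete`.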
22 changes: 17 additions & 5 deletions composer/trainer/trainer.py
@@ -1266,7 +1266,7 @@ def __init__(
'the optimal device_train_microbatch_size value and then manually specify that in a '
'second run with profiler.',
)
- self.first_batch_complete = False
+ self.first_train_batch_complete = False
# If auto_microbatching is True or `device_train_microbatch_size` is not specified, the microbatch size
# will be determined when dataloader is specified. train_dataloader is parsed after `Event.INIT` or in
# fit()
@@ -2463,7 +2463,7 @@ def fit(
# update scaler since precision was provided
self.state.scaler = ClosureGradScaler() if self._use_closures() else GradScaler()

- self.first_batch_complete = False
+ self.first_train_batch_complete = False
self._train_loop()

# Zero gradients at the end of fit so same model/optimizer can be used for further training
@@ -2763,6 +2763,11 @@ def _train_loop(self) -> None:
finished_epoch_early = True
break

if not self.first_train_batch_complete:
warnings.warn(
f'No batches were trained for global rank {dist.get_global_rank()}. This may be due to an issue with the train dataset, dataloader, or sampler. This may cause other issues or crashes down the line.',
)

if not finished_epoch_early or self.state.dataloader_len == self.state.timestamp.batch_in_epoch:
# Trigger the epoch end events if the dataloader was exhausted.
# This happens if the "break" did not trigger above, or if it
@@ -2997,7 +3002,7 @@ def _train_batch(self, use_grad_scaling: bool) -> dict[str, torch.Tensor]:
memory_stats = torch.cuda.memory_stats()
self.cumulative_alloc_retries = memory_stats['num_alloc_retries']
self.logger.log_metrics({'trainer/device_train_microbatch_size': self.state.device_train_microbatch_size})
- self.first_batch_complete = True
+ self.first_train_batch_complete = True
return total_loss_dict

def _train_microbatches(
@@ -3019,7 +3024,7 @@ def _train_microbatches(
if ddp_sync or not isinstance(self.state.model, DistributedDataParallel):
context = contextlib.nullcontext
else:
- if self.state.auto_microbatching and not self.first_batch_complete:
+ if self.state.auto_microbatching and not self.first_train_batch_complete:
# PyTorch DDP rebuilds gradient reduction buckets after 1) a forward pass where the
# no_sync context was not set 2) a backward pass 3) a forward pass. If only a
# subset of ranks OOM on the first batch, this will cause a deadlock since a rank
@@ -3122,7 +3127,7 @@ def _train_microbatch(
microbatch_size = self._train_data_spec.get_num_samples_in_batch(self.state.batch)
if self.state.deepspeed_enabled or not isinstance(self.state.model, DistributedDataParallel):
sync_context = contextlib.nullcontext()
- elif self.state.auto_microbatching and not self.first_batch_complete:
+ elif self.state.auto_microbatching and not self.first_train_batch_complete:
# PyTorch DDP rebuilds gradient reduction buckets after 1) a forward pass where the
# no_sync context was not set 2) a backward pass 3) a forward pass. If only a
# subset of ranks OOM on the first batch, this will cause a deadlock since a rank
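The comment repeated in the two hunks above explains why gradient synchronization is not skipped until the first batch completes when auto-microbatching is enabled. A hedged sketch of that context selection, assuming a DDP-wrapped model (`pick_sync_context` is a hypothetical helper whose branch order mirrors the diff, not Composer's actual code):

```python
import contextlib

def pick_sync_context(model, ddp_sync: bool, is_ddp: bool,
                      auto_microbatching: bool, first_train_batch_complete: bool):
    # Sync normally when requested (final microbatch) or when the model is not DDP-wrapped.
    if ddp_sync or not is_ddp:
        return contextlib.nullcontext()
    # On the very first batch under auto-microbatching, still sync: DDP only rebuilds its
    # gradient-reduction buckets after a synced forward/backward pass, and skipping that on
    # a batch where some ranks may OOM and retry can leave ranks waiting on each other.
    if auto_microbatching and not first_train_batch_complete:
        return contextlib.nullcontext()
    # Otherwise skip the gradient all-reduce for this non-final microbatch.
    return model.no_sync()
```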
@@ -3599,6 +3604,7 @@ def _eval_loop(
drop_last = None
dataset_len = None
last_batch = False
first_eval_batch_complete = False
dist_sampler = _get_distributed_sampler(dataloader) if isinstance(dataloader, DataLoader) else None
if isinstance(dist_sampler, DistributedSampler) and isinstance(dataloader, DataLoader):
# The distributed sampler uses `set_epoch` to set the random seed
@@ -3762,6 +3768,7 @@ def _eval_loop(
evaluator.device_eval_microbatch_size,
})
# Break if we've successfully completed eval without OOMing.
first_eval_batch_complete = True
break

now = datetime.datetime.now()
@@ -3783,6 +3790,11 @@

self.engine.run_event(Event.EVAL_BATCH_END)

if not first_eval_batch_complete:
warnings.warn(
f'No batches were evaluated for global rank {dist.get_global_rank()}. This may be due to an issue with the eval dataset, dataloader, or sampler. This may cause other issues or crashes down the line.',
)

self._compute_and_log_metrics(dataloader_label=evaluator.label, metrics=metrics)

self.engine.run_event(Event.EVAL_END)
1 change: 1 addition & 0 deletions tests/loggers/test_console_logger.py
@@ -255,6 +255,7 @@ def test_console_logger_eval(


@pytest.mark.filterwarnings('ignore:The ``compute`` method of metric.*:UserWarning')
@pytest.mark.filterwarnings('ignore:No batches were evaluated*:UserWarning')
def test_console_logger_eval_empty_dataloader(
console_logger_test_stream,
console_logger_test_file_path,
1 change: 1 addition & 0 deletions tests/trainer/test_checkpoint.py
@@ -1591,6 +1591,7 @@ def __len__(self) -> int:
)
# trainer_2 will call compute if checkpoint is already at end of epoch
@pytest.mark.filterwarnings('ignore:The ``compute`` method of metric MulticlassAccuracy.*:UserWarning')
@pytest.mark.filterwarnings('ignore:No batches were trained*:UserWarning')
def test_resumption(
self,
device: str,
43 changes: 43 additions & 0 deletions tests/trainer/test_trainer.py
@@ -1690,3 +1690,46 @@ def test_autoresume_and_default_remote_uploader_downloader(self, tmp_path: pathl

# Just test that the default args for everything do not hit the above errors
_ = Trainer(**config)


class TestNoTrainDataTrained:
"""Test cases where no training data is trained with the trainer.
This can happen in the following cases:
- The dataset has no samples.
- The dataset cannot split evenly across multi nodes on the first batch even
"""

def _get_dataloader(self, dataset_size: int):
"""Get a dataloader."""
dataset = RandomClassificationDataset(size=dataset_size)
dataloader = DataLoader(dataset=dataset, batch_size=1, sampler=dist.get_sampler(dataset=dataset))
return dataloader

def test_empty_train_dataloader(self):
"""Test the case where the train dataset has no samples."""
with pytest.raises(UserWarning, match='No batches were trained for global rank'):
train_dataloader = self._get_dataloader(0)
model = SimpleModel()

trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
max_duration='1ba',
)
trainer.fit()

def test_empty_eval_dataloader(self):
"""Test the case where the eval dataset has no samples."""
with pytest.raises(UserWarning, match='No batches were evaluated for global rank'):
train_dataloader = self._get_dataloader(1)
eval_dataloader = self._get_dataloader(0)
model = SimpleModel()

trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_duration='1ba',
)
trainer.fit()
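One note on the new tests: `pytest.raises(UserWarning, ...)` only catches a warning as an exception when the test configuration promotes warnings to errors (for example a `filterwarnings = ["error", ...]` entry in the project's pytest settings, which the `ignore:` markers added above suggest is the case here). Under default pytest behavior, the equivalent check would use `pytest.warns`, as in this sketch:

```python
import warnings

import pytest

def test_warns_when_no_batches_trained():
    # Equivalent assertion under default warning handling (illustrative only).
    with pytest.warns(UserWarning, match='No batches were trained for global rank'):
        warnings.warn('No batches were trained for global rank 0.')
```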
