Add validation that batch size is divisible by number of microbatches
ghstack-source-id: b111e687be74fe7d371b21536df662d622d9e6e3
Pull Request resolved: #784
H-Huang committed Jan 10, 2025
1 parent a04662f commit 95677cb
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions torchtitan/parallelisms/pipelining_utils.py
@@ -93,6 +93,13 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
             f"of stages ({num_total_stages}) which may result in a bubble in the pipeline."
         )
 
+    # validate that the batch size is divisible by the number of microbatches otherwise we'll hang or error during training
+    if job_config.training.batch_size % n_microbatches != 0:
+        raise ValueError(
+            f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
+            "Update the config arguments for either batch_size or pipeline_parallel_microbatches."
+        )
+
     schedule = schedule_class(
         stages if looped_schedule else stages[0],
         n_microbatches=n_microbatches,
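For context, the check added above can be exercised in isolation. The following is a minimal standalone sketch (not part of the commit), with hypothetical values standing in for job_config.training.batch_size and pipeline_parallel_microbatches:

# Standalone sketch of the divisibility check added in this commit,
# using hypothetical values in place of the config fields.
batch_size = 32        # hypothetical training.batch_size
n_microbatches = 8     # hypothetical pipeline_parallel_microbatches

if batch_size % n_microbatches != 0:
    raise ValueError(
        f"Batch size {batch_size} must be divisible by number of microbatches {n_microbatches}."
    )

# Each microbatch receives batch_size // n_microbatches samples (32 // 8 = 4 here).
# With, say, batch_size=30 and n_microbatches=8, the check raises immediately
# instead of hanging or failing later inside the pipeline schedule.
print(f"microbatch size: {batch_size // n_microbatches}")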
