
Commit

minor fix
ghstack-source-id: 0dd35232e76d80a4542a7e91b2d25fea663938b6
Pull Request resolved: #788
tianyu-l committed Jan 13, 2025
1 parent 95677cb commit 82f7387
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -75,5 +75,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_shard_degree 1
+NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_replicate_degree 1 --training.data_parallel_shard_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1 --experimental.context_parallel_degree 1
 ```
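The new flags pin every parallelism degree to 1 so the seed checkpoint is created unsharded. As a minimal sketch of why all five flags matter (the flag names come from the diff; the helper function itself is hypothetical, not torchtitan API), the parallel world size is the product of the degrees, so any degree left above 1 would require more than one device:

```python
# Hypothetical helper: the parallel world size is the product of all
# parallelism degrees. Argument names mirror the CLI flags in the diff;
# this is illustrative, not the torchtitan implementation.
def parallel_world_size(dp_replicate: int, dp_shard: int,
                        tp: int, pp: int, cp: int) -> int:
    return dp_replicate * dp_shard * tp * pp * cp

# A seed checkpoint run must resolve to a single device:
assert parallel_world_size(1, 1, 1, 1, 1) == 1
# Leaving any degree > 1 would demand a multi-device run:
assert parallel_world_size(1, 2, 1, 1, 1) == 2
```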
2 changes: 1 addition & 1 deletion torchtitan/parallelisms/pipelining_utils.py
@@ -107,7 +107,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
         )
         logger.info(
             f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
-                with {n_microbatches} and {num_total_stages} stages."
+                with {n_microbatches} microbatches and {num_total_stages} stages."
         )

if pp_schedule_csv:
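The change here only completes the log message, which previously read "with {n_microbatches} and ..." with no unit. A minimal reproduction of the corrected f-string (variable names taken from the diff, the concrete values are made up for illustration):

```python
# Reproduce the corrected log line in isolation; values are hypothetical.
pipeline_parallel_schedule = "1F1B"
n_microbatches = 8
num_total_stages = 4
msg = (
    f"Using pipeline schedule {pipeline_parallel_schedule} "
    f"with {n_microbatches} microbatches and {num_total_stages} stages."
)
print(msg)
```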
5 changes: 4 additions & 1 deletion train.py
@@ -201,7 +201,10 @@ def loss_fn(pred, labels):
     if job_config.checkpoint.create_seed_checkpoint:
         assert (
             world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
+        ), "Must create seed checkpoint using a single device, to disable sharding"
+        assert (
+            job_config.checkpoint.enable_checkpoint
+        ), "Must enable checkpointing when creating a seed checkpoint"
         checkpoint.save(curr_step=0, force=True)
         logger.info("Created seed checkpoint")
         return
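The train.py hunk tightens the seed-checkpoint preconditions: a single-device run, and checkpointing enabled. A self-contained sketch of that guard logic (the `CheckpointConfig` class below is a stand-in for the real `job_config.checkpoint` fields named in the diff, not torchtitan's actual config class):

```python
# Stand-in for job_config.checkpoint; field names mirror the diff,
# the class itself is hypothetical.
class CheckpointConfig:
    def __init__(self, enable_checkpoint: bool, create_seed_checkpoint: bool):
        self.enable_checkpoint = enable_checkpoint
        self.create_seed_checkpoint = create_seed_checkpoint


def check_seed_checkpoint_preconditions(cfg: CheckpointConfig, world_size: int) -> None:
    """Mirror the assertions added in train.py: seed checkpoints require a
    single device (so nothing is sharded) and checkpointing turned on."""
    if cfg.create_seed_checkpoint:
        assert (
            world_size == 1
        ), "Must create seed checkpoint using a single device, to disable sharding"
        assert (
            cfg.enable_checkpoint
        ), "Must enable checkpointing when creating a seed checkpoint"
```

For example, a valid seed-checkpoint config passes silently, while a multi-device run or a run without `--checkpoint.enable_checkpoint` trips the corresponding assertion.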
