From afb33105b5df0648866a0b15be3a8fae3a5f5e99 Mon Sep 17 00:00:00 2001
From: Tianyu Liu
Date: Sun, 12 Jan 2025 15:04:30 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 docs/checkpoint.md                          | 2 +-
 torchtitan/parallelisms/pipelining_utils.py | 2 +-
 train.py                                    | 5 ++++-
 3 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/docs/checkpoint.md b/docs/checkpoint.md
index 3f66e5ac..05ef6f4d 100644
--- a/docs/checkpoint.md
+++ b/docs/checkpoint.md
@@ -75,5 +75,5 @@ A seed checkpoint does initialization of the model on a single CPU, and can be l
 To create a seed checkpoint, use the same model config as you use for training.
 e.g.
 ```bash
-NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_shard_degree 1
+NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh --checkpoint.enable_checkpoint --checkpoint.create_seed_checkpoint --training.data_parallel_replicate_degree 1 --training.data_parallel_shard_degree 1 --training.tensor_parallel_degree 1 --experimental.pipeline_parallel_degree 1 --experimental.context_parallel_degree 1
 ```
diff --git a/torchtitan/parallelisms/pipelining_utils.py b/torchtitan/parallelisms/pipelining_utils.py
index 4322a031..7b2994f8 100644
--- a/torchtitan/parallelisms/pipelining_utils.py
+++ b/torchtitan/parallelisms/pipelining_utils.py
@@ -107,7 +107,7 @@ def build_pipeline_schedule(job_config, stages, loss_fn):
         )
     logger.info(
         f"Using pipeline schedule {job_config.experimental.pipeline_parallel_schedule} \
-with {n_microbatches} and {num_total_stages} stages."
+with {n_microbatches} microbatches and {num_total_stages} stages."
     )

     if pp_schedule_csv:
diff --git a/train.py b/train.py
index 2874d0d5..21dd9f8b 100644
--- a/train.py
+++ b/train.py
@@ -201,7 +201,10 @@ def loss_fn(pred, labels):
     if job_config.checkpoint.create_seed_checkpoint:
         assert (
             world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
+        ), "Must create seed checkpoint using a single device, to disable sharding"
+        assert (
+            job_config.checkpoint.enable_checkpoint
+        ), "Must enable checkpointing when creating a seed checkpoint"
         checkpoint.save(curr_step=0, force=True)
         logger.info("Created seed checkpoint")
         return
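
For reference, a minimal usage sketch of the command this patch documents, expanded one flag per line; the config path is a placeholder, and the flag set is taken directly from the diff above:

```bash
# Create a seed checkpoint on a single device. Every parallelism degree is
# pinned to 1 so the model initializes unsharded, and checkpointing must be
# enabled, or the new assertion in train.py will fail.
NGPU=1 CONFIG=<path_to_model_config> ./run_llama_train.sh \
  --checkpoint.enable_checkpoint \
  --checkpoint.create_seed_checkpoint \
  --training.data_parallel_replicate_degree 1 \
  --training.data_parallel_shard_degree 1 \
  --training.tensor_parallel_degree 1 \
  --experimental.pipeline_parallel_degree 1 \
  --experimental.context_parallel_degree 1
```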