From d9898423ecef131825d13c6c8b521a24e889785f Mon Sep 17 00:00:00 2001
From: Will Constable
Date: Thu, 16 Jan 2025 08:53:07 -0800
Subject: [PATCH] Fix PP+CP handling freqs_cis buffer

When developing test_pp_cp and chatting with fegin, we realized the
freqs_cis buffers are not being handled correctly in torchtitan for the
pipelining case.

CP needs to modify the freqs_cis buffer to account for sharding on the
seq dim, but the previous titan code implemented this incorrectly:
`model.freqs_cis` was passed to CP for sharding, yet pipelining does not
use `model` at all; it uses the separate stage models contained in the
`model_parts` list.

The fix is to tell the CP context about each freqs_cis buffer inside the
`model_parts` models.

Alternatively, we could tie the freqs_cis buffers of the PP stages
together by explicitly doing so after calling init_weights per pp-stage.
However, this is of limited value, so we skip it.

ghstack-source-id: 7aa393515497050c069f25604bcd984bc0f1a118
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/792
---
 train.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/train.py b/train.py
index d79f51b5..bac22772 100644
--- a/train.py
+++ b/train.py
@@ -154,6 +154,8 @@ def loss_fn(pred, labels):
         pp_schedule, model_parts = models_pipelining_fns[model_name](
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
+        # when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
+        del model
 
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
@@ -268,11 +270,12 @@ def loss_fn(pred, labels):
         optimizers.zero_grad()
 
         # apply context parallelism if cp is enabled
+        # ensure CP handles the separate freqs_cis buffer for each pp stage
        optional_context_parallel_ctx = (
             utils.create_context_parallel_ctx(
                 cp_mesh=world_mesh["cp"],
-                cp_buffers=[input_ids, labels, model.freqs_cis],
-                cp_seq_dims=[1, 1, 0],
+                cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts],
+                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={input_ids, labels},
                 cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
             )
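
Note (illustrative only, not part of the patch): the commit message mentions an
alternative of tying the per-stage freqs_cis buffers together after init_weights
rather than passing every stage's buffer to CP. A minimal sketch of that idea
follows, assuming each entry in `model_parts` is an nn.Module stage that exposes
a registered `freqs_cis` buffer as in the diff above; the tying loop itself is
hypothetical and is not what this patch implements.

    # Hypothetical sketch of the alternative: share one freqs_cis tensor across
    # all pipeline stages so CP would only need to shard a single buffer.
    shared_freqs_cis = model_parts[0].freqs_cis
    for stage_model in model_parts[1:]:
        # assigning a tensor to an existing registered buffer name replaces it
        stage_model.freqs_cis = shared_freqs_cis

The patch instead lists every stage's freqs_cis buffer in cp_buffers, which
achieves the same sharding effect without coupling stage initialization.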