Skip to content

Commit

Permalink
Fix PP+CP handling freqs_cis buffer
Browse files Browse the repository at this point in the history
When developing test_pp_cp and chatting with fegin, we realized the
freqs_cis buffers are not being handled correctly in torchtitan for the
pipelining case.

CP needs to modify the freqs_cis buffer to account for sharding on seq
dim, but in the previous titan code this was implemented incorrectly.
`model.freqs_cis` was passed to CP for sharding, but pipelining does not
use `model` at all, it uses the different stage-models contained in
`model_parts` list.  The fix is to tell CP context about each freqs_cis
buffer inside `model_parts` models.

Alternatively we could tie the freqs_cis buffers for each pp stage
together, by explicitly doing so after calling init_weights per
pp-stage.  However this is of limited value so we skip it.

ghstack-source-id: 7aa393515497050c069f25604bcd984bc0f1a118
Pull Request resolved: #792
  • Loading branch information
wconstab committed Jan 16, 2025
1 parent f504a14 commit d989842
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def loss_fn(pred, labels):
pp_schedule, model_parts = models_pipelining_fns[model_name](
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)
# when PP is enabled, `model` obj is no longer used after this point, model_parts is used instead
del model

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
Expand Down Expand Up @@ -268,11 +270,12 @@ def loss_fn(pred, labels):
optimizers.zero_grad()

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
optional_context_parallel_ctx = (
utils.create_context_parallel_ctx(
cp_mesh=world_mesh["cp"],
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_buffers=[input_ids, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={input_ids, labels},
cp_rotate_method=job_config.experimental.context_parallel_rotate_method,
)
Expand Down

0 comments on commit d989842

Please sign in to comment.