Commit
While developing test_pp_cp and discussing it with fegin, we realized that the freqs_cis buffers are not handled correctly in torchtitan in the pipelining case. CP needs to modify the freqs_cis buffer to account for sharding on the sequence dimension, but the previous titan code implemented this incorrectly: `model.freqs_cis` was passed to CP for sharding, but pipelining does not use `model` at all. It uses the separate stage models contained in the `model_parts` list, so the freqs_cis buffers those stages actually use were not the one passed to CP.

The fix is to tell the CP context about each freqs_cis buffer inside the `model_parts` models.

Alternatively, we could tie the freqs_cis buffers of all PP stages together by explicitly doing so after calling init_weights per PP stage. However, this is of limited value, so we skip it.

ghstack-source-id: 7aa393515497050c069f25604bcd984bc0f1a118
Pull Request resolved: #792
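The idea behind the fix can be sketched as follows. This is a minimal illustration, not the code from this commit: `StageModel` and `collect_cp_buffers` are hypothetical names, the values are plain placeholders rather than tensors, and the sequence dims (1 for `input_ids`/`labels`, 0 for `freqs_cis`) are assumptions about the tensor layouts.

```python
from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple


@dataclass
class StageModel:
    """Hypothetical stand-in for one pipeline-stage model (not a torchtitan class)."""

    freqs_cis: Any  # in real code, a tensor buffer; assumed layout (seq, ...)


def collect_cp_buffers(
    input_ids: Any, labels: Any, model_parts: Sequence[StageModel]
) -> Tuple[List[Any], List[int]]:
    """Return the buffers CP should shard and the sequence dim of each one."""
    # Assumed layout: input_ids and labels are (batch, seq), so their seq dim is 1.
    buffers = [input_ids, labels]
    seq_dims = [1, 1]
    # Each pipeline stage owns its own freqs_cis buffer, so register one per stage
    # instead of a single `model.freqs_cis` that the stages never use.
    for part in model_parts:
        buffers.append(part.freqs_cis)
        seq_dims.append(0)  # assumed layout: freqs_cis is (seq, ...)
    return buffers, seq_dims
```

The two lists stay index-aligned, so they could then be handed to whatever creates the CP context, for example as the buffers and per-buffer sequence dims it expects. With a single-element `model_parts` this reduces to the non-pipelined case.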