bug fixed when stacking pre_scales
wkcn committed Dec 11, 2023
1 parent 6e58e1e commit df68631
Showing 1 changed file with 1 addition and 1 deletion.
msamp/megatron/optimizer/distrib_optimizer.py (1 addition, 1 deletion)
@@ -566,7 +566,7 @@ def reduce_model_grads(self, args, timers):  # noqa: C901
 # pre_scales in the partition `data_parallel_rank`
 pre_scales = [g.meta.pre_scale for g in fp8_grads[data_parallel_rank]]
 max_elems_per_rank = max(model._grad_buffer_num_params)
-pre_scales = torch.cat(pre_scales)
+pre_scales = torch.stack(pre_scales)
 # padding to max_elems_per_rank
 pad = max_elems_per_rank - pre_scales.numel()
 pre_scales = F.pad(pre_scales, (0, pad))
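
The one-line change matters if each per-gradient `pre_scale` is a zero-dimensional (scalar) tensor: `torch.cat` rejects zero-dimensional inputs, while `torch.stack` adds a new leading dimension and produces the 1-D tensor that the subsequent padding expects. Below is a minimal sketch of that behavior, assuming hypothetical scalar `pre_scale` values and an illustrative `max_elems_per_rank`; it is not taken from the MS-AMP codebase.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-ins for g.meta.pre_scale values: assumed here to be
# zero-dimensional (scalar) tensors, one per FP8 gradient.
pre_scales = [torch.tensor(1.0), torch.tensor(0.5), torch.tensor(2.0)]

# torch.cat(pre_scales) would raise:
#   RuntimeError: zero-dimensional tensor (at position 0) cannot be concatenated
# torch.stack adds a new leading dimension, giving a 1-D tensor of scales.
pre_scales = torch.stack(pre_scales)      # tensor([1.0000, 0.5000, 2.0000])

# Pad to an illustrative max_elems_per_rank, mirroring the patched code path.
max_elems_per_rank = 4
pad = max_elems_per_rank - pre_scales.numel()
pre_scales = F.pad(pre_scales, (0, pad))  # tensor([1.0000, 0.5000, 2.0000, 0.0000])
print(pre_scales)
```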
