[Bug Fixed] Support MS-AMP+TE+DDP and MS-AMP+TE+DeepSpeed (#144)
**Description**
This PR adds support for MS-AMP+TE+DDP and MS-AMP+TE+DeepSpeed.

**Major Revision**
- In FP8 DeepSpeed ZeRO and FP8 DDP, a weight gradient that arrives as a plain
  `torch.Tensor` is now converted to a `ScalingTensor` before reduction (see the sketch after this list).
- Add unit tests for MS-AMP+TE+DDP and MS-AMP+TE+DeepSpeed in
  [tests/te/test_replacer.py](https://github.com/Azure/MS-AMP/compare/main...wkcn:fix_wgrad_for_ds?expand=1#diff-9d0a56799e3844e780daf5b7bd565c4217659bd3f46748b7067aac46847b16b9).
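
For reference, a minimal sketch of the conversion applied in `fp8_stage_1_and_2.py`. It reuses the names that appear in the diff (`ScalingMeta`, `Tensor.cast`, `Dtypes.kfloat8_e4m3`); the import paths and the `ensure_scaling_grad` helper are illustrative assumptions, not part of this change.

```python
# Illustrative sketch only: cast a plain torch.Tensor weight gradient to a
# ScalingTensor before it enters the FP8 reduction path.
import msamp  # noqa: F401  # assumption: importing msamp makes Tensor.cast available, as used in the diff below

from msamp.common.dtype import Dtypes                       # assumed import path
from msamp.common.tensor import ScalingMeta, ScalingTensor  # assumed import path

WEIGHT_GRAD_QTYPE = Dtypes.kfloat8_e4m3  # FP8 qtype used for weight gradients in the diff


def ensure_scaling_grad(param, process_group=None):
    """Hypothetical helper: make sure param.grad is a ScalingTensor."""
    if param.grad is None or isinstance(param.grad, ScalingTensor):
        return param.grad
    # Scaling metadata is shared across the data-parallel group; sync=True
    # mirrors the cast call added in fp8_stage_1_and_2.py.
    meta = ScalingMeta(WEIGHT_GRAD_QTYPE, group=process_group)
    param.grad = param.grad.cast(WEIGHT_GRAD_QTYPE, meta=meta, sync=True)
    return param.grad
```

In `distributed.py` the same idea appears in a lighter form: gradients that lack a `meta` attribute get a fresh `ScalingMeta` with `Dtypes.kfloat8_e4m3` attached before the bucket reduction.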

---------

Co-authored-by: Yuxiang Yang <[email protected]>
wkcn and tocean authored Dec 14, 2023
1 parent 5f1cd5f commit d562f0f
Showing 4 changed files with 112 additions and 8 deletions.
8 changes: 4 additions & 4 deletions msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -426,10 +426,10 @@ def fp8_reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
             # Copy the grad tensor to the ipg buffer.
             new_grad_tensor = self.fp8_ipg_buffer[self.fp8_ipg_index
                                                   ].narrow(0, self.fp8_elements_in_ipg_bucket, param.numel())
-            grad = param.grad
-            if isinstance(grad, ScalingTensor):
-                # only copy ScalingTensor.value
-                grad = grad.value
+            if not isinstance(param.grad, ScalingTensor):
+                meta = ScalingMeta(WEIGHT_GRAD_QTYPE, group=self.dp_process_group)
+                param.grad = param.grad.cast(WEIGHT_GRAD_QTYPE, meta=meta, sync=True)
+            grad = param.grad.value
             new_grad_tensor.copy_(grad.view(-1))
             # param: lp
             grad.data = new_grad_tensor.data.view(grad.shape)
5 changes: 5 additions & 0 deletions msamp/nn/distributed.py
@@ -133,6 +133,11 @@ def _reduce_bucket(self, bucket_id):
         param_ids = self.bucket_to_param_ids[bucket_id]
         params = [self.parameters[i] for i in param_ids]
         grads = [p.grad for p in params]
+        wgrad_qtype = Dtypes.kfloat8_e4m3
+        for g in grads:
+            if not hasattr(g, 'meta'):
+                meta = ScalingMeta(wgrad_qtype, group=self.process_group)
+                g.meta = meta
         metas = [g.meta for g in grads]

         # step 2: synchronize the amax
9 changes: 8 additions & 1 deletion msamp/te/modules.py
@@ -10,6 +10,10 @@
 from msamp.common.tensor import ScalingTensor
 from msamp.nn import ScalingModule

+# Provide `torch.Tensor.untyped_storage` for older PyTorch releases, since TransformerEngine calls it.
+if not hasattr(torch.Tensor, 'untyped_storage'):
+    torch.Tensor.untyped_storage = lambda self: self.data.storage().untyped()
+

 def set_activation_dtype(self, inp):
     """Set activation data type for AMP.
@@ -253,8 +257,11 @@ def backward(ctx, *args):
                 assert grads[i] is not None
                 if v.grad is None:
                     v.grad = grads[i]
-                else:
+                elif torch.is_tensor(v.grad):
                     v.grad += grads[i]
+                else:
+                    assert isinstance(v.grad, ScalingTensor)
+                    v.grad = v.grad.to(grads[i].dtype) + grads[i]
                 v.backward_grad_update(v.grad)
                 grads[i] = None
         return (None, ) + tuple(grads)
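The new `else` branch in `backward()` above covers the case where the parameter already holds a `ScalingTensor` gradient while TransformerEngine hands back a plain `torch.Tensor`: the existing FP8 gradient is upcast to the incoming dtype before the two are summed. A minimal sketch of that accumulation rule; `accumulate_grad` is a hypothetical helper, not code from this PR.

```python
# Illustrative sketch of the gradient-accumulation rule added in msamp/te/modules.py.
import torch

from msamp.common.tensor import ScalingTensor  # import path taken from the diff above


def accumulate_grad(existing_grad, new_grad):
    """Hypothetical helper mirroring the branch added to backward()."""
    if existing_grad is None:
        return new_grad
    if torch.is_tensor(existing_grad):
        # Both plain tensors: accumulate in place as before.
        return existing_grad + new_grad
    # Existing grad is an FP8 ScalingTensor: upcast to the incoming dtype, then add.
    assert isinstance(existing_grad, ScalingTensor)
    return existing_grad.to(new_grad.dtype) + new_grad
```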
98 changes: 95 additions & 3 deletions tests/te/test_replacer.py
@@ -3,14 +3,19 @@

 """Tests for msamp.te.replacer module."""

+import os
 import unittest

 import torch
+import torch.distributed as dist
+from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import Format, DelayedScaling

 from tests.helper import decorator
+from msamp import deepspeed
 from msamp.nn import ScalingParameter
+from msamp.optim import LBAdamW
 from msamp.te.replacer import TeReplacer

@@ -33,9 +38,6 @@ def tearDown(self):
     @decorator.cuda_test
     def test_replace(self):
         """Test replace function in TeReplacer."""
-        # fused attention need cuda version >= 12.1
-        if torch.version.cuda < '12.1':
-            return
         te_transformer = te.TransformerLayer(
             self.hidden_size, self.ffn_hidden_size, self.num_attention_heads, fuse_qkv_params=True
         )
@@ -74,3 +76,93 @@ def _check_model(model):
             y = model(x, attention_mask=None)
             assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
             y.sum().backward()
+
+    @decorator.cuda_test
+    def test_te_with_deepspeed(self):
+        """Test TransformerEngine + MS-AMP with DeepSpeed."""
+        te_transformer = te.TransformerLayer(
+            self.hidden_size, self.ffn_hidden_size, self.num_attention_heads, fuse_qkv_params=True
+        )
+        te_transformer.to(dtype=self.dtype).cuda()
+
+        model = TeReplacer.replace(te_transformer)
+
+        ds_config = {
+            'train_batch_size': self.batch_size,
+            'train_micro_batch_size_per_gpu': self.batch_size,
+            'zero_optimization': {
+                'stage': 2,
+            }
+        }
+
+        optimizer = LBAdamW(model.parameters(), lr=1e-3, weight_decay=0)
+        model, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)
+
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            input = torch.randn(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
+            output = model(input, attention_mask=None)
+        loss = output.sum()
+        model.backward(loss)
+        model.step()
+
+
+class TeReplacerDistributedTestCast(MultiProcessTestCase):
+    """Test functions in distributed module with TransformerEngine."""
+    def setUp(self):
+        """Hook method for setting up the test fixture before exercising it."""
+        super().setUp()
+        torch.manual_seed(1000)
+
+        self._spawn_processes()
+
+    def tearDown(self):
+        """Hook method for deconstructing the test fixture after testing it."""
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        """Return the number of processes."""
+        return 2
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @decorator.cuda_test
+    def test_fp8_ddp_with_te(self):
+        """Test FP8DistributedDataParallel with TransformerEngine."""
+        hidden_size = 4096
+        ffn_hidden_size = 16384
+        num_attention_heads = 32
+        dtype = torch.float16
+        batch_size = 4
+        sequence_length = 128
+
+        rank = self.rank
+        store = dist.FileStore(self.file_name, self.world_size)
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend='nccl', store=store, rank=self.rank, world_size=self.world_size)
+
+        te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, fuse_qkv_params=True)
+        te_transformer.to(dtype=dtype).cuda()
+        model = TeReplacer.replace(te_transformer)
+        try:
+            # ddp_with_replicated_tensor is set in MultiProcessTestCase and should be disabled here. We catch the
+            # exception because replicated_tensor_ddp_utils is not available in torch 2.
+            from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
+            _set_ddp_with_replicated_tensor(False)
+        except Exception:
+            pass
+
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+        # input is different for each rank.
+        x = torch.randn(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            output = model(x, attention_mask=None)
+        output.sum().backward()
