diff --git a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
index 6ae3a548..90e1ddf0 100644
--- a/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
+++ b/msamp/deepspeed/runtime/zero/fp8_stage_1_and_2.py
@@ -426,10 +426,10 @@ def fp8_reduce_independent_p_g_buckets_and_remove_grads(self, param, i):
         # Copy the grad tensor to the ipg buffer.
         new_grad_tensor = self.fp8_ipg_buffer[self.fp8_ipg_index
                                               ].narrow(0, self.fp8_elements_in_ipg_bucket, param.numel())
-        grad = param.grad
-        if isinstance(grad, ScalingTensor):
-            # only copy ScalingTensor.value
-            grad = grad.value
+        if not isinstance(param.grad, ScalingTensor):
+            meta = ScalingMeta(WEIGHT_GRAD_QTYPE, group=self.dp_process_group)
+            param.grad = param.grad.cast(WEIGHT_GRAD_QTYPE, meta=meta, sync=True)
+        grad = param.grad.value
         new_grad_tensor.copy_(grad.view(-1))
         # param: lp
         grad.data = new_grad_tensor.data.view(grad.shape)
diff --git a/msamp/nn/distributed.py b/msamp/nn/distributed.py
index f3e342c0..cce3bb1e 100644
--- a/msamp/nn/distributed.py
+++ b/msamp/nn/distributed.py
@@ -133,6 +133,11 @@ def _reduce_bucket(self, bucket_id):
         param_ids = self.bucket_to_param_ids[bucket_id]
         params = [self.parameters[i] for i in param_ids]
         grads = [p.grad for p in params]
+        wgrad_qtype = Dtypes.kfloat8_e4m3
+        for g in grads:
+            if not hasattr(g, 'meta'):
+                meta = ScalingMeta(wgrad_qtype, group=self.process_group)
+                g.meta = meta
         metas = [g.meta for g in grads]
 
         # step 2: synchronize the amax
diff --git a/msamp/te/modules.py b/msamp/te/modules.py
index 24ef0429..9c5a5f8b 100644
--- a/msamp/te/modules.py
+++ b/msamp/te/modules.py
@@ -10,6 +10,10 @@
 from msamp.common.tensor import ScalingTensor
 from msamp.nn import ScalingModule
 
+# set the function `untyped_storage` for TransformerEngine
+if not hasattr(torch.Tensor, 'untyped_storage'):
+    torch.Tensor.untyped_storage = lambda self: self.data.storage().untyped()
+
 
 def set_activation_dtype(self, inp):
     """Set activation data type for AMP.
@@ -253,8 +257,11 @@ def backward(ctx, *args):
                 assert grads[i] is not None
                 if v.grad is None:
                     v.grad = grads[i]
-                else:
+                elif torch.is_tensor(v.grad):
                     v.grad += grads[i]
+                else:
+                    assert isinstance(v.grad, ScalingTensor)
+                    v.grad = v.grad.to(grads[i].dtype) + grads[i]
                 v.backward_grad_update(v.grad)
                 grads[i] = None
         return (None, ) + tuple(grads)
diff --git a/tests/te/test_replacer.py b/tests/te/test_replacer.py
index a682a909..a015fc70 100644
--- a/tests/te/test_replacer.py
+++ b/tests/te/test_replacer.py
@@ -3,14 +3,19 @@
 
 """Tests for msamp.te.replacer module."""
 
+import os
 import unittest
 
 import torch
+import torch.distributed as dist
+from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import Format, DelayedScaling
 
 from tests.helper import decorator
+from msamp import deepspeed
 from msamp.nn import ScalingParameter
+from msamp.optim import LBAdamW
 from msamp.te.replacer import TeReplacer
 
 
@@ -33,9 +38,6 @@ def tearDown(self):
     @decorator.cuda_test
     def test_replace(self):
         """Test replace function in TeReplacer."""
-        # fused attention need cuda version >= 12.1
-        if torch.version.cuda < '12.1':
-            return
         te_transformer = te.TransformerLayer(
             self.hidden_size, self.ffn_hidden_size, self.num_attention_heads, fuse_qkv_params=True
         )
@@ -74,3 +76,93 @@ def _check_model(model):
         y = model(x, attention_mask=None)
         assert y.shape == (self.sequence_length, self.batch_size, self.hidden_size)
         y.sum().backward()
+
+    @decorator.cuda_test
+    def test_te_with_deepspeed(self):
+        """Test TransformerEngine + MS-AMP with DeepSpeed."""
+        te_transformer = te.TransformerLayer(
+            self.hidden_size, self.ffn_hidden_size, self.num_attention_heads, fuse_qkv_params=True
+        )
+        te_transformer.to(dtype=self.dtype).cuda()
+
+        model = TeReplacer.replace(te_transformer)
+
+        ds_config = {
+            'train_batch_size': self.batch_size,
+            'train_micro_batch_size_per_gpu': self.batch_size,
+            'zero_optimization': {
+                'stage': 2,
+            }
+        }
+
+        optimizer = LBAdamW(model.parameters(), lr=1e-3, weight_decay=0)
+        model, _, _, _ = deepspeed.initialize(model=model, optimizer=optimizer, config=ds_config)
+
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            input = torch.randn(self.sequence_length, self.batch_size, self.hidden_size).cuda().to(dtype=self.dtype)
+            output = model(input, attention_mask=None)
+        loss = output.sum()
+        model.backward(loss)
+        model.step()
+
+
+class TeReplacerDistributedTestCase(MultiProcessTestCase):
+    """Test functions in distributed module with TransformerEngine."""
+    def setUp(self):
+        """Hook method for setting up the test fixture before exercising it."""
+        super().setUp()
+        torch.manual_seed(1000)
+
+        self._spawn_processes()
+
+    def tearDown(self):
+        """Hook method for deconstructing the test fixture after testing it."""
+        super().tearDown()
+        try:
+            os.remove(self.file_name)
+        except OSError:
+            pass
+
+    @property
+    def world_size(self):
+        """Return the number of processes."""
+        return 2
+
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    @decorator.cuda_test
+    def test_fp8_ddp_with_te(self):
+        """Test FP8DistributedDataParallel with TransformerEngine."""
+        hidden_size = 4096
+        ffn_hidden_size = 16384
+        num_attention_heads = 32
+        dtype = torch.float16
+        batch_size = 4
+        sequence_length = 128
+
+        rank = self.rank
+        store = dist.FileStore(self.file_name, self.world_size)
+        torch.cuda.set_device(rank)
+        dist.init_process_group(backend='nccl', store=store, rank=self.rank, world_size=self.world_size)
+
+        te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, fuse_qkv_params=True)
+        te_transformer.to(dtype=dtype).cuda()
+        model = TeReplacer.replace(te_transformer)
+        try:
+            # ddp_with_replicated_tensor is set in MultiProcessTestCase and should be disabled. We catch the exception
+            # because replicated_tensor_ddp_utils is not available in torch 2.
+            from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
+            _set_ddp_with_replicated_tensor(False)
+        except Exception:
+            pass
+
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
+        # input is different for each rank.
+        x = torch.randn(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)
+        fp8_format = Format.HYBRID
+        fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo='max')
+        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
+            output = model(x, attention_mask=None)
+        output.sum().backward()