diff --git a/examples/fsdp_mnist.py b/examples/fsdp_mnist.py
index 37f45e7f..51328a57 100644
--- a/examples/fsdp_mnist.py
+++ b/examples/fsdp_mnist.py
@@ -147,17 +147,17 @@ def fsdp_main(rank, world_size, args):
     model = Net().to(rank)
 
     if args.msamp:
-        from msamp.nn import LinearReplacer
-        from msamp.common.dtype import Dtypes
+        from msamp.fsdp.replacer import FsdpReplacer
         from msamp.optim import FSDPAdam
-
-        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)
-
+        model = FsdpReplacer.replace(model)
+
     if rank == 0:
         print(f'model:')
         print(f'{model}')
+        for name, parameter in model.named_parameters():
+            print(f'name:{name}, numel:{parameter.numel()}')
 
-    model = FSDP(model, use_orig_params=True)
+    model = FSDP(model, use_orig_params=True, auto_wrap_policy=my_auto_wrap_policy)
 
     if rank == 0:
         print(f'FSDP model:')
@@ -167,6 +167,9 @@ def fsdp_main(rank, world_size, args):
         optimizer = FSDPAdam(model.parameters(), lr=args.lr)
     else:
         optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
+
+    if rank == 0:
+        print(f'optimizer initialized')
 
     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     init_start_event.record()
diff --git a/msamp/fsdp/flat_param.py b/msamp/fsdp/flat_param.py
index 6b65e206..35a32f58 100644
--- a/msamp/fsdp/flat_param.py
+++ b/msamp/fsdp/flat_param.py
@@ -656,6 +656,7 @@ def _init_shard_metadata(
             ranks = list(range(start_rank, end_rank + 1))
             meta.group = dist.new_group(ranks=ranks)
 
+
     def _get_shard_metadata(
         self,
         start: int,
@@ -1469,12 +1470,19 @@ def _use_unsharded_views(self, as_params: bool) -> None:
                         tensor.data = view  # type: ignore[union-attr]
                     assert tensor is not None  # mypy
                     param_var = tensor
+
             setattr(module, param_name, param_var)
             if (
                 self._use_orig_params
                 and self._training_state == HandleTrainingState.FORWARD
             ):
                 module._parameters[param_name] = param_var  # type: ignore[assignment]
+
+            param_var._fp8 = True
+            param_var._scaling_metas = self.flat_param._scaling_metas[i]
+            param_var._meta = self.flat_param._metas[i]
+            param_var._padded = self.flat_param._paddeds[i]
+            param_var._original_shape = self.flat_param._original_shapes[i]
         for i, (
             param_name,
             module,
@@ -1615,6 +1623,10 @@ def _use_sharded_views(self) -> None:
             zip(self.flat_param._params, self.flat_param._param_infos)
         ):
             setattr(module, param_name, param)
+            if self.flat_param._metas[i] is not None:
+                param._meta = self.flat_param._metas[i]
+                param._grad_meta = self.flat_param._scaling_metas[i]['wgrad']
+
             in_sharded_flat_param = (
                 i >= start
                 and i <= end
diff --git a/msamp/fsdp/fully_sharded_data_parallel.py b/msamp/fsdp/fully_sharded_data_parallel.py
index 33c53529..ece154d1 100644
--- a/msamp/fsdp/fully_sharded_data_parallel.py
+++ b/msamp/fsdp/fully_sharded_data_parallel.py
@@ -407,26 +407,6 @@ def __init__(
         _init_prefetching_state(self, backward_prefetch, forward_prefetch)
         _init_buffer_state(self, module)
 
-        for name, submodule in module.named_modules():
-            params_to_process = list(submodule.named_parameters(recurse=False))
-            for param_name, param in params_to_process:
-                if not isinstance(param, torch.Tensor):
-                    data = param.value.view(-1)
-                    padded = 0
-                    if data.numel() % 4 != 0:
-                        padded = 4 - data.numel() % 4
-                        data = torch.nn.functional.pad(data, (0, padded))
-
-                    data = data.view(dtype=torch.float32)
-                    new_param = torch.nn.Parameter(data)
-                    new_param._fp8 = True
-                    new_param._original_shape = param.shape
-                    new_param._padded = 0
-                    new_param._meta = param.meta
-                    new_param._scaling_metas = param._scaling_metas
-
-                    setattr(submodule, param_name, new_param)
-
         _init_param_handle_from_module(
             self,
             module,
@@ -770,17 +750,6 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
                     f"{self.compute_device} but got {handle.flat_param.device}",
                 )
 
-            i = 0
-            for _, submodule in self._fsdp_wrapped_module.named_modules():
-                for param_name, param in submodule.named_parameters(recurse=False):
-                    if self._flat_param._metas[i] is not None:
-                        param._fp8 = True
-                        param._scaling_metas = self._flat_param._scaling_metas[i]
-                        param._meta = self._flat_param._metas[i]
-                        param._padded = self._flat_param._paddeds[i]
-                        param._original_shape = self._flat_param._original_shapes[i]
-                    i += 1
-
             output = self._fsdp_wrapped_module(*args, **kwargs)
             return _post_forward(self, self._handles, reshard_fn, self, unused, output)
 
@@ -928,12 +897,8 @@ def named_parameters(
         when inside the :meth:`summon_full_params` context manager.
         """
         should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS
-        i = 0
+
         for param_name, param in super().named_parameters(*args, **kwargs):
-            if self._flat_param._metas[i] is not None:
-                param._meta = self._flat_param._metas[i]
-                param._grad_meta = self._flat_param._scaling_metas[i]['wgrad']
-            i += 1
             if should_clean_name:
                 # Remove any instances of the FSDP-specific prefix; there can
                 # be multiple in the case of nested FSDP modules
diff --git a/msamp/fsdp/replacer.py b/msamp/fsdp/replacer.py
new file mode 100644
index 00000000..2c49ea1f
--- /dev/null
+++ b/msamp/fsdp/replacer.py
@@ -0,0 +1,39 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+"""MS-AMP fsdp.replacer module."""
+
+import torch
+
+from msamp.common.dtype import Dtypes
+from msamp.nn import LinearReplacer
+
+
+class FsdpReplacer:
+    """A replacer to replace the FP8 weights with FP32 nn.Parameter and attributes."""
+
+    @classmethod
+    def replace(cls, model):
+        """Replace ScalingParameter weights with flattened FP32 nn.Parameters carrying FP8 metadata."""
+
+        model = LinearReplacer.replace(model, weight_qtype=Dtypes.kfloat8_e4m3)
+        for _, submodule in model.named_modules():
+            params_to_process = list(submodule.named_parameters(recurse=False))
+            for param_name, param in params_to_process:
+                if not isinstance(param, torch.Tensor):
+                    data = param.value.view(-1)
+                    padded = 0
+                    if data.numel() % 4 != 0:
+                        padded = 4 - data.numel() % 4
+                        data = torch.nn.functional.pad(data, (0, padded))
+
+                    data = data.view(dtype=torch.float32)
+                    new_param = torch.nn.Parameter(data)
+                    new_param._fp8 = True
+                    new_param._original_shape = param.shape
+                    new_param._padded = 0
+                    new_param._meta = param.meta
+                    new_param._scaling_metas = param._scaling_metas
+
+                    setattr(submodule, param_name, new_param)
+        return model
diff --git a/msamp/nn/functional.py b/msamp/nn/functional.py
index 903a0d4c..6721eac6 100644
--- a/msamp/nn/functional.py
+++ b/msamp/nn/functional.py
@@ -26,7 +26,7 @@ def forward(ctx, input, weight, metas, dtype_holder):
             dtype_holder (torch.Tensor): A tensor to hold the output dtype. The required_grad of this tensor
                 should be if input.required_grad is False.
""" - if hasattr(weight, '_fp8'): + if isinstance(weight, torch.Tensor): padded = weight._padded original_shape = weight._original_shape meta = weight._meta @@ -36,7 +36,7 @@ def forward(ctx, input, weight, metas, dtype_holder): weight = weight[0: weight.numel() - padded] weight = weight.view(original_shape) weight = ScalingParameter(ScalingTensor(weight, meta)) - ctx._fp8 = True + ctx._returnWgrad = True ctx.metas = metas model_state.check_metas_in_flat(metas) @@ -109,7 +109,7 @@ def backward(ctx, output_grad): ) del old_wgrad - if ctx._fp8: + if ctx._returnWgrad: wgrad = wgrad.cast(Dtypes.kfloat8_e4m3, meta=wgrad_meta, sync=True) wgrad = wgrad.value.view(-1).view(dtype=torch.float32) wgrad.meta = wgrad_meta diff --git a/msamp/optim/adam.py b/msamp/optim/adam.py index 8ad7e29f..95c84d83 100644 --- a/msamp/optim/adam.py +++ b/msamp/optim/adam.py @@ -130,7 +130,6 @@ def zero_grad(self, set_to_none=False): param.grad.zero_() def step(self): - torch.set_printoptions(profile="full") # cast gradient to ScalingTensor for i, param in enumerate(self.original_params): if param.grad is None: