diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 960bac4f..b8fd4f21 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -364,6 +364,16 @@ def __init__(self):
             default=1,
             help="Context parallelism degree. 1 means disabled.",
         )
+        self.parser.add_argument(
+            "--experimental.expert_parallel_mode",
+            type=str,
+            default="none",
+            choices=["none", "tp", "tp2ep"],
+            help="""
+                Expert Parallel mode.
+                'tp2ep' would use the entire TP mesh to shard non-shared experts on the num_experts dimension.
+            """,
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
diff --git a/torchtitan/parallelisms/expert_parallel.py b/torchtitan/parallelisms/expert_parallel.py
new file mode 100644
index 00000000..a42a9a15
--- /dev/null
+++ b/torchtitan/parallelisms/expert_parallel.py
@@ -0,0 +1,327 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from functools import partial
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.distributed.tensor import (
+    DeviceMesh,
+    distribute_module,
+    distribute_tensor,
+    DTensor,
+    Partial,
+    Replicate,
+    Shard,
+)
+from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor.placement_types import Placement
+
+
+# This is similar to PrepareModuleInput and PrepareModuleOutput,
+# but applies them simultaneously.
+class PrepareModuleInputOutput(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Union[Placement, Tuple[Optional[Placement]]]] = None,
+        desired_input_layouts: Optional[
+            Union[Placement, Tuple[Optional[Placement]]]
+        ] = None,
+        input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
+        desired_input_kwarg_layouts: Optional[Dict[str, Placement]] = None,
+        use_local_input: bool = True,
+        output_layouts: Union[Placement, Tuple[Placement]],
+        desired_output_layouts: Union[Placement, Tuple[Placement]],
+        use_local_output: bool = True,
+    ):
+        # for input
+        self.input_layouts = (
+            (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts
+        )
+        self.desired_input_layouts = (
+            (desired_input_layouts,)
+            if isinstance(desired_input_layouts, Placement)
+            else desired_input_layouts
+        )
+        self.use_local_input = use_local_input
+        if self.input_layouts is not None:
+            assert (
+                self.desired_input_layouts is not None
+            ), "desired module inputs should not be None!"
+            assert len(self.input_layouts) == len(
+                self.desired_input_layouts
+            ), "input_layouts and desired_input_layouts should have same length!"
+        self.with_kwargs = input_kwarg_layouts is not None
+        self.input_kwarg_layouts = input_kwarg_layouts or {}
+        self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {}
+        if self.with_kwargs:
+            assert len(self.input_kwarg_layouts) == len(
+                self.desired_input_kwarg_layouts
+            ), "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!"
+
+        # for output
+        self.output_layouts = (
+            (output_layouts,)
+            if isinstance(output_layouts, Placement)
+            else output_layouts
+        )
+        self.desired_output_layouts = (
+            (desired_output_layouts,)
+            if isinstance(desired_output_layouts, Placement)
+            else desired_output_layouts
+        )
+        self.use_local_output = use_local_output
+        assert len(self.output_layouts) == len(
+            self.desired_output_layouts
+        ), "output_layouts and desired_output_layouts should have same length!"
+
+    def _prepare_input_arg(
+        self,
+        input: Any,
+        mesh: DeviceMesh,
+        input_layout: Optional[Placement],
+        desired_layout: Optional[Placement],
+    ):
+        if input_layout is not None:
+            if isinstance(input, DTensor):
+                # TODO: re-enable the check once we fix the compile path
+                # assert inp.placements[0] == input_layout
+                dt_inp = input
+            else:
+                assert isinstance(
+                    input, torch.Tensor
+                ), "expecting input to be a torch.Tensor!"
+                dt_inp = DTensor.from_local(
+                    input, mesh, (input_layout,), run_check=False
+                )
+
+            if desired_layout is not None and input_layout != desired_layout:
+                dt_inp = dt_inp.redistribute(placements=(desired_layout,))
+
+            return dt_inp.to_local() if self.use_local_input else dt_inp
+        else:
+            return input
+
+    def _prepare_input_fn(self, inputs, device_mesh):
+        if self.input_layouts is None:
+            return inputs
+        prepared_inputs = []
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        if len(inputs) != len(self.input_layouts):
+            raise ValueError("module inputs and input_layouts should have same length!")
+
+        assert (
+            self.desired_input_layouts is not None
+        ), "desired module inputs should not be None!"
+        for inp, input_layout, desired_layout in zip(
+            inputs, self.input_layouts, self.desired_input_layouts
+        ):
+            prepared_inputs.append(
+                self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout)
+            )
+        return tuple(prepared_inputs)
+
+    def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh):
+        prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh)
+        prepared_kwarg_inputs = {}
+        for kwarg_key in kwarg_inputs.keys():
+            kwarg_val = kwarg_inputs[kwarg_key]
+            input_layout = self.input_kwarg_layouts.get(kwarg_key)
+            desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key)
+
+            prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg(
+                kwarg_val, device_mesh, input_layout, desired_input_layout
+            )
+
+        return (prepared_arg_inputs, prepared_kwarg_inputs)
+
+    def _prepare_out_fn(self, outputs, device_mesh):
+        prepared_outputs = []
+        if not isinstance(outputs, tuple):
+            outputs = (outputs,)
+        if len(outputs) != len(self.output_layouts):
+            raise ValueError(
+                "module outputs and output_layouts should have same length!"
+            )
+        for out, out_layout, desired_out_layout in zip(
+            outputs, self.output_layouts, self.desired_output_layouts
+        ):
+            if out_layout is not None:
+                if isinstance(out, DTensor):
+                    # TODO: re-enable the check once we fix the compile path
+                    # assert out.placements[0] == out_layout
+                    dt_out = out
+                else:
+                    dt_out = DTensor.from_local(
+                        out, device_mesh, (out_layout,), run_check=False
+                    )
+
+                if out_layout != desired_out_layout:
+                    dt_out = dt_out.redistribute(placements=(desired_out_layout,))
+                prepared_outputs.append(
+                    dt_out.to_local() if self.use_local_output else dt_out
+                )
+            else:
+                prepared_outputs.append(out)
+        if len(prepared_outputs) == 1:
+            return prepared_outputs[0]
+        else:
+            return tuple(prepared_outputs)
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        # for input
+        if self.with_kwargs:
+            module.register_forward_pre_hook(
+                lambda _, inputs, kwargs: self._prepare_input_kwarg_fn(
+                    inputs, kwargs, device_mesh
+                ),
+                with_kwargs=True,
+            )  # type: ignore[misc]
+        else:
+            module.register_forward_pre_hook(
+                lambda _, inputs: self._prepare_input_fn(inputs, device_mesh)
+            )  # type: ignore[misc, call-arg]
+
+        # for output
+        module.register_forward_hook(
+            lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh)
+        )  # type: ignore[misc, call-arg]
+
+        return module
+
+
+class TensorParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Partial(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(
+        input_layouts, desired_input_layouts, mod, inputs, device_mesh
+    ):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
+
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, input_layouts, run_check=False
+            )
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(
+                placements=desired_input_layouts, async_op=True
+            )
+        return input_tensor
+
+    def _partition_fn(self, name, module, device_mesh):
+        module.register_parameter(
+            "gate_proj",
+            nn.Parameter(distribute_tensor(module.gate_proj, device_mesh, [Shard(2)])),
+        )  # Column-wise sharding
+        module.register_parameter(
+            "down_proj",
+            nn.Parameter(distribute_tensor(module.down_proj, device_mesh, [Shard(1)])),
+        )  # Row-wise sharding
+        module.register_parameter(
+            "up_proj",
+            nn.Parameter(distribute_tensor(module.up_proj, device_mesh, [Shard(2)])),
+        )  # Column-wise sharding
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._partition_fn,
+            partial(
+                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
+            ),
+            partial(
+                self._prepare_output_fn, self.output_layouts, self.use_local_output
+            ),
+        )
+
+
+class ExpertParallel(ParallelStyle):
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Shard(0),)
+        self.output_layouts = (output_layouts or Shard(0),)
+        self.desired_input_layouts = (Shard(0),)
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(
+        input_layouts, desired_input_layouts, mod, inputs, device_mesh
+    ):
+        # TODO: figure out dynamo support for instance method and switch this to instance method
+
+        # annotate module input placements/sharding with input_layouts
+        input_tensor = inputs[0]
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(
+                input_tensor, device_mesh, input_layouts, run_check=False
+            )
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(
+                placements=desired_input_layouts, async_op=True
+            )
+        return input_tensor
+
+    def _partition_fn(self, name, module, device_mesh):
+        # shard on the expert dimension
+        for name, param in module.named_parameters(recurse=False):
+            dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)]))
+            module.register_parameter(name, dist_param)
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        # outputs is a DTensor sharded on the expert dim, i.e. Shard(0)
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        # back to local tensor
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._partition_fn,
+            partial(
+                self._prepare_input_fn, self.input_layouts, self.desired_input_layouts
+            ),
+            partial(
+                self._prepare_output_fn, self.output_layouts, self.use_local_output
+            ),
+        )
diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 53f00686..56df423d 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -8,6 +8,7 @@
 # training techniques (e.g. activation checkpointing and compile) to the Llama model.
 
 from collections import defaultdict
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -64,6 +65,14 @@ def parallelize_llama(
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )
 
+    ep_mode = job_config.experimental.expert_parallel_mode
+    if ep_mode != "none":
+        apply_ep(
+            model,
+            ep_mode=ep_mode,
+            tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None,
+        )
+
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
 
@@ -214,6 +223,72 @@
     )
 
 
+def apply_ep(
+    model: nn.Module,
+    ep_mode: str,
+    tp_mesh: Optional[DeviceMesh] = None,
+):
+    from torch.distributed.tensor import Partial
+    from torchtitan.parallelisms.expert_parallel import (
+        ExpertParallel,
+        PrepareModuleInputOutput,
+        TensorParallel,
+    )
+
+    if ep_mode == "tp":
+        assert tp_mesh is not None
+
+        for _, transformer_block in model.layers.items():
+            moe_plan = {
+                # input / output sharding on the seqlen dim
+                "moe": PrepareModuleInputOutput(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                    output_layouts=(Partial(),),
+                    desired_output_layouts=(Shard(1),),
+                ),
+                # Router Parallel, sharding on the expert dim
+                "moe.router.gate": ColwiseParallel(),
+                # input / output sharding on the expert dim
+                "moe.experts": TensorParallel(
+                    input_layouts=Shard(0),
+                    output_layouts=Shard(0),
+                ),
+                "moe.shared_expert": TensorParallel(),
+            }
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=tp_mesh,
+                parallelize_plan=moe_plan,
+            )
+
+    elif ep_mode == "tp2ep":
+        assert tp_mesh is not None
+
+        for _, transformer_block in model.layers.items():
+            moe_plan = {
+                # input / output sharding on the seqlen dim
+                "moe": PrepareModuleInputOutput(
+                    input_layouts=(Shard(1),),
+                    desired_input_layouts=(Replicate(),),
+                    output_layouts=(Partial(),),
+                    desired_output_layouts=(Shard(1),),
+                ),
+                # Router Parallel, sharding on the expert dim
+                "moe.router.gate": ColwiseParallel(),
+                # input / output replicated
+                "moe.experts": ExpertParallel(),
+                "moe.shared_expert": TensorParallel(),
+            }
+            parallelize_module(
+                module=transformer_block,
+                device_mesh=tp_mesh,
+                parallelize_plan=moe_plan,
+            )
+
+    logger.info(f"Applied {ep_mode} Expert Parallelism to the model")
+
+
 # for selective op activation checkpointing
 _save_list = {
     torch.ops.aten.mm.default,
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index f681cdba..276c8e71 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -46,6 +46,7 @@ dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
 context_parallel_degree = 1
 pipeline_parallel_degree = 1
 enable_async_tensor_parallel = false
+expert_parallel_mode = "tp2ep"
 
 [checkpoint]
 enable_checkpoint = false
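
Usage sketch (not part of the patch): the snippet below shows how the ExpertParallel style added above could be applied to a stand-alone grouped-experts module, which may help when reviewing the tp2ep path. ToyGroupedExperts, its w1/w2 parameters, and the torchrun/CUDA setup are illustrative assumptions; only the ExpertParallel import comes from this diff.

# A minimal sketch of applying the ExpertParallel style from this patch to a toy
# grouped-experts module. ToyGroupedExperts and its w1/w2 parameters are hypothetical;
# run with e.g. `torchrun --nproc_per_node=2 sketch.py` on a CUDA machine, and keep
# num_experts divisible by the world size.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from torchtitan.parallelisms.expert_parallel import ExpertParallel


class ToyGroupedExperts(nn.Module):
    """num_experts independent two-layer FFNs stored as stacked, expert-dim-first weights."""

    def __init__(self, num_experts: int, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, hidden_dim))
        self.w2 = nn.Parameter(torch.randn(num_experts, hidden_dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (num_experts, tokens_per_expert, dim), already grouped by expert
        return torch.bmm(torch.relu(torch.bmm(x, self.w1)), self.w2)


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    torch.cuda.set_device(local_rank)
    mesh = init_device_mesh("cuda", (world_size,), mesh_dim_names=("tp",))

    num_experts, dim, hidden_dim = 8, 16, 32
    experts = ToyGroupedExperts(num_experts, dim, hidden_dim).cuda()

    # ExpertParallel shards every parameter on dim 0 (the expert dim) and annotates the
    # module input/output as Shard(0), so each rank computes only its local experts.
    parallelize_module(experts, mesh, ExpertParallel())

    # Each rank feeds the token groups belonging to its own expert shard.
    local_x = torch.randn(num_experts // world_size, 4, dim, device="cuda")
    out = experts(local_x)  # (local_num_experts, tokens_per_expert, dim)
    print(f"rank {local_rank}: output shape {tuple(out.shape)}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()

In the torchtitan training flow the same style is wired up by apply_ep above and enabled by setting expert_parallel_mode = "tp2ep" under [experimental], as the debug_model.toml hunk does.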