diff --git a/composer/trainer/mosaic_fsdp.py b/composer/trainer/mosaic_fsdp.py index 44f42650eb..7ebc800f88 100644 --- a/composer/trainer/mosaic_fsdp.py +++ b/composer/trainer/mosaic_fsdp.py @@ -75,3 +75,12 @@ def patch_pytorch(): state_dict.set_model_state_dict = set_model_state_dict state_dict.set_optimizer_state_dict = set_optimizer_state_dict state_dict._get_fqns = _get_fqns + + # Monkeypatch for ND child submeshes + # PR: https://github.com/pytorch/pytorch/pull/119752 + from torch.distributed.device_mesh import DeviceMesh, _MeshEnv + + from composer.trainer.mosaic_fsdp_utils import create_child_mesh, device_mesh__getitem__, device_mesh__init__ + _MeshEnv.create_child_mesh = create_child_mesh + DeviceMesh.__getitem__ = device_mesh__getitem__ + DeviceMesh.__init__ = device_mesh__init__ diff --git a/composer/trainer/mosaic_fsdp_utils.py b/composer/trainer/mosaic_fsdp_utils.py index 32f6225391..ec75dfdfeb 100644 --- a/composer/trainer/mosaic_fsdp_utils.py +++ b/composer/trainer/mosaic_fsdp_utils.py @@ -792,3 +792,149 @@ def create_global_plan( self.metadata = metadata return self.global_plan, self.metadata + + from torch.utils._typing_utils import not_none + from torch.distributed.device_mesh import DeviceMesh + + def create_child_mesh( + self, + device_mesh, + mesh_dim_names: Tuple[str], + ): + """Monkeypatch create_child_mesh to nightly version.""" + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. + mesh_dims = [ + not_none(device_mesh.mesh_dim_names).index(mesh_dim_name) + for mesh_dim_name in mesh_dim_names + ] + cur_rank = device_mesh.get_rank() + mesh = device_mesh.mesh + all_mesh_dims = list(range(mesh.ndim)) + for mesh_dim in mesh_dims: + # remove not pop b/c we want the value of the ind removed not it's position in the list + # because this list dynamically changes. + all_mesh_dims.remove(mesh_dim) + + mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims] + + pg_ranks_by_dim = device_mesh.mesh.permute( + *all_mesh_dims, *mesh_dims, + ).reshape(-1, *mesh_sizes) + + for mesh_nd in pg_ranks_by_dim: + if cur_rank in mesh_nd: + sub_mesh = DeviceMesh( + device_mesh.device_type, + mesh_nd, + mesh_dim_names=mesh_dim_names, + ) + res_sub_mesh = sub_mesh + + res_sub_mesh._dim_group_infos = [ # type: ignore[possibly-undefined] + device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims + ] + + # Assign the current DeviceMesh as the parent of the child DeviceMesh. + self.child_to_parent_mapping[res_sub_mesh] = device_mesh + return res_sub_mesh + + from torch.distributed.device_mesh import _mesh_resources + + def device_mesh__init__( + self, + device_type: str, + mesh, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + ) -> None: + """Monkeypatch device mesh __init__ to nightly version.""" + self.device_type = device_type + if isinstance(mesh, torch.Tensor) and mesh.device.type != 'cpu': + raise ValueError(f'`mesh` must be a CPU tensor, got {mesh}') + self.mesh = ( + mesh.detach().cpu() + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, dtype=torch.int) + ) + self.mesh_dim_names = mesh_dim_names + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self))) + self._parent_mesh = _mesh_resources.get_parent_mesh(self) + + # Skip process group initialization if xla device. + # TODO(yeounoh) implement DeviceMesh backend and register XLA backend. + if device_type != 'xla': + # always try to create default (world) pg, even if it is not initialized + # already. The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not). + self._get_or_create_default_group() + if not self._parent_mesh: + self._init_process_groups() + + def device_mesh__getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> 'DeviceMesh': + """Monkeypatch device_mesh __getitem__ to nightly version. + + Slice the current DeviceMesh based on the mesh_dim_name given to create a child + DeviceMesh. + + Args: + mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh + to create a child DeviceMesh for. + + Returns: + A :class:`DeviceMesh` object + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]). + Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]). + Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]). + Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]). + Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]). + Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]). + + Example:: + >>> # xdoctest: +SKIP("no rank") + >>> from torch.distributed.device_mesh import DeviceMesh + >>> + >>> # Initialize device mesh as (2, 4) to represent the topology + >>> # of cross-host(dim 0), and within-host (dim 1). + >>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]]) + """ + if not self.mesh_dim_names: + raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names.') + + mesh_dim_names = ( + (mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names + ) + + error_msg = ( + f'Invalid mesh_dim_name {mesh_dim_names} specified. ' + f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.' + ) + + # When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh, + # we simply return self if the given slicing is valid. + if mesh_dim_names == self.mesh_dim_names: + return self + # Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names + # of the current DeviceMesh. + elif len(mesh_dim_names) < len(self.mesh_dim_names): + outermost_dim_name = mesh_dim_names[0] + if outermost_dim_name not in self.mesh_dim_names: + raise ValueError(error_msg) + outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name) + for i, j in zip( + mesh_dim_names, + self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)], + ): + if i != j: + raise ValueError(error_msg) + else: + raise ValueError(error_msg) + + submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names) + return submesh