Skip to content

Commit

Permalink
Monkeypatch Device Mesh ND Slicing (mosaicml#3302)
Browse files Browse the repository at this point in the history
* add patch

* fix lint

* link to PR
  • Loading branch information
mvpatel2000 authored May 18, 2024
1 parent 3128cee commit bb1b7cf
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 0 deletions.
9 changes: 9 additions & 0 deletions composer/trainer/mosaic_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,12 @@ def patch_pytorch():
state_dict.set_model_state_dict = set_model_state_dict
state_dict.set_optimizer_state_dict = set_optimizer_state_dict
state_dict._get_fqns = _get_fqns

# Monkeypatch for ND child submeshes
# PR: https://github.com/pytorch/pytorch/pull/119752
from torch.distributed.device_mesh import DeviceMesh, _MeshEnv

from composer.trainer.mosaic_fsdp_utils import create_child_mesh, device_mesh__getitem__, device_mesh__init__
_MeshEnv.create_child_mesh = create_child_mesh
DeviceMesh.__getitem__ = device_mesh__getitem__
DeviceMesh.__init__ = device_mesh__init__
146 changes: 146 additions & 0 deletions composer/trainer/mosaic_fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,3 +792,149 @@ def create_global_plan(
self.metadata = metadata

return self.global_plan, self.metadata

from torch.utils._typing_utils import not_none
from torch.distributed.device_mesh import DeviceMesh

def create_child_mesh(
self,
device_mesh,
mesh_dim_names: Tuple[str],
):
"""Monkeypatch create_child_mesh to nightly version."""
# swap the current dim to the last dim then reshape to flatten out other
# dims, so we can just extract the list of ranks which contains cur_rank.
mesh_dims = [
not_none(device_mesh.mesh_dim_names).index(mesh_dim_name)
for mesh_dim_name in mesh_dim_names
]
cur_rank = device_mesh.get_rank()
mesh = device_mesh.mesh
all_mesh_dims = list(range(mesh.ndim))
for mesh_dim in mesh_dims:
# remove not pop b/c we want the value of the ind removed not it's position in the list
# because this list dynamically changes.
all_mesh_dims.remove(mesh_dim)

mesh_sizes = [device_mesh.mesh.size(mesh_dim) for mesh_dim in mesh_dims]

pg_ranks_by_dim = device_mesh.mesh.permute(
*all_mesh_dims, *mesh_dims,
).reshape(-1, *mesh_sizes)

for mesh_nd in pg_ranks_by_dim:
if cur_rank in mesh_nd:
sub_mesh = DeviceMesh(
device_mesh.device_type,
mesh_nd,
mesh_dim_names=mesh_dim_names,
)
res_sub_mesh = sub_mesh

res_sub_mesh._dim_group_infos = [ # type: ignore[possibly-undefined]
device_mesh._dim_group_infos[mesh_dim] for mesh_dim in mesh_dims
]

# Assign the current DeviceMesh as the parent of the child DeviceMesh.
self.child_to_parent_mapping[res_sub_mesh] = device_mesh
return res_sub_mesh

from torch.distributed.device_mesh import _mesh_resources

def device_mesh__init__(
self,
device_type: str,
mesh,
*,
mesh_dim_names: Optional[Tuple[str, ...]] = None,
) -> None:
"""Monkeypatch device mesh __init__ to nightly version."""
self.device_type = device_type
if isinstance(mesh, torch.Tensor) and mesh.device.type != 'cpu':
raise ValueError(f'`mesh` must be a CPU tensor, got {mesh}')
self.mesh = (
mesh.detach().cpu()
if isinstance(mesh, torch.Tensor)
else torch.tensor(mesh, dtype=torch.int)
)
self.mesh_dim_names = mesh_dim_names

# private field to pre-generate DeviceMesh's hash
self._flatten_mesh_list = tuple(self.mesh.flatten().tolist())
self._hash = hash((self._flatten_mesh_list, self.mesh.shape, id(self)))
self._parent_mesh = _mesh_resources.get_parent_mesh(self)

# Skip process group initialization if xla device.
# TODO(yeounoh) implement DeviceMesh backend and register XLA backend.
if device_type != 'xla':
# always try to create default (world) pg, even if it is not initialized
# already. The world pg is used for device mesh identity (rank) on each
# process (we need to know if the current global rank is in the mesh or not).
self._get_or_create_default_group()
if not self._parent_mesh:
self._init_process_groups()

def device_mesh__getitem__(self, mesh_dim_names: Union[str, Tuple[str]]) -> 'DeviceMesh':
"""Monkeypatch device_mesh __getitem__ to nightly version.
Slice the current DeviceMesh based on the mesh_dim_name given to create a child
DeviceMesh.
Args:
mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh
to create a child DeviceMesh for.
Returns:
A :class:`DeviceMesh` object
The following program runs on each process/rank in an SPMD manner. In this example, we have 2
hosts with 4 GPUs each.
Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]).
Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]).
Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]).
Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]).
Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]).
Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]).
Example::
>>> # xdoctest: +SKIP("no rank")
>>> from torch.distributed.device_mesh import DeviceMesh
>>>
>>> # Initialize device mesh as (2, 4) to represent the topology
>>> # of cross-host(dim 0), and within-host (dim 1).
>>> mesh = DeviceMesh(device_type="cuda", mesh=[[0, 1, 2, 3],[4, 5, 6, 7]])
"""
if not self.mesh_dim_names:
raise RuntimeError('Cannot slice a DeviceMesh without mesh_dim_names.')

mesh_dim_names = (
(mesh_dim_names,) if isinstance(mesh_dim_names, str) else mesh_dim_names
)

error_msg = (
f'Invalid mesh_dim_name {mesh_dim_names} specified. '
f'Valid mesh_dim_names should be a contiguous subsequence of {self.mesh_dim_names}.'
)

# When the dimension slicing out is equal to the mesh dimensions of the current DeviceMesh,
# we simply return self if the given slicing is valid.
if mesh_dim_names == self.mesh_dim_names:
return self
# Check if the user-provided slicing is a valid contiguous subsequence of the mesh_dim_names
# of the current DeviceMesh.
elif len(mesh_dim_names) < len(self.mesh_dim_names):
outermost_dim_name = mesh_dim_names[0]
if outermost_dim_name not in self.mesh_dim_names:
raise ValueError(error_msg)
outermost_dim_idx = self.mesh_dim_names.index(outermost_dim_name)
for i, j in zip(
mesh_dim_names,
self.mesh_dim_names[outermost_dim_idx : len(mesh_dim_names)],
):
if i != j:
raise ValueError(error_msg)
else:
raise ValueError(error_msg)

submesh = _mesh_resources.create_child_mesh(self, mesh_dim_names)
return submesh

0 comments on commit bb1b7cf

Please sign in to comment.