
A FusionDefinition wrapper that takes/produces DTensors. #3703

Open
wants to merge 7 commits into wjy/dist
Conversation

wujingyue
Collaborator

@wujingyue wujingyue commented Jan 14, 2025

This is a proof of concept for integrating nvFuser's model parallelism into the framework.
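For context, a minimal usage sketch of the intended flow (assuming a recent PyTorch where DTensor lives under torch.distributed.tensor, and a script launched with torchrun with one GPU per rank; define_add is an illustrative fusion, not part of this PR):

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from nvfuser import DataType, FusionDefinition

def define_add(fd: FusionDefinition) -> None:
    # A plain fusion definition; no multidevice_schedule here. The wrapper
    # derives the schedule from the placements of the input DTensors.
    t0 = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    t1 = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    fd.add_output(fd.ops.add(t0, t1))

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())
mesh = init_device_mesh("cuda", (dist.get_world_size(),))  # 1D mesh over all ranks

x = DTensor.from_local(torch.randn(4, 8, device="cuda"), mesh, [Shard(0)])
y = DTensor.from_local(torch.randn(4, 8, device="cuda"), mesh, [Shard(0)])

op = FusionDefinitionWrapper(define_add)
(z,) = op([x, y])  # z is a DTensor; its placement is inferred from nvFuser's output sharding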

@wujingyue wujingyue marked this pull request as draft January 14, 2025 06:18
@wujingyue
Collaborator Author

!test

@wujingyue
Collaborator Author

Cc @jjsjann123

wujingyue added a commit that referenced this pull request Jan 14, 2025

github-actions bot commented Jan 16, 2025

PR Reviewer Guide 🔍

(Review updated until commit f9082a9)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Possible Logic Change

The new local property added to the FusionDefinition class may change the behavior of the axis_sharded_on method. Reviewers should verify that the logic is correct and consistent with the rest of the codebase.

@property
def local(self) -> torch.Tensor:
    """Returns the underlying local tensor."""
    return self._dtensor.local
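As a point of reference, this is roughly how the property is consumed when an nvFuser output is rewrapped as a DTensor (a sketch mirroring the wrapper code below; out_tensor stands for an nvfuser.DistributedTensor):

# Sketch: rewrap an nvFuser DistributedTensor's local data as a DTensor.
mesh = init_device_mesh("cuda", (out_tensor.mesh.size,))
axis = out_tensor.axis_sharded_on(nvfuser.ParallelType.mesh_x)
placement = Replicate() if axis == -1 else Shard(axis)
dtensor = DTensor.from_local(out_tensor.local, mesh, [placement])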
Potential Bug

The FusionDefinitionWrapper class assumes that the define_fusion function will always return a FusionDefinition object. However, if the function returns None or an object of a different type, the wrapper will fail. Reviewers should add error handling to ensure that the wrapper can handle such cases.

# Imports assumed by this snippet (the DTensor types live under
# torch.distributed._tensor in older PyTorch releases):
from typing import Callable, Iterable, cast

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor, Placement, Replicate, Shard

import nvfuser
from nvfuser import FusionDefinition


class FusionDefinitionWrapper:
    def __init__(self, define_fusion: Callable[[FusionDefinition], None]):
        """Wraps a function that defines a fusion without `multidevice_schedule`."""
        self._define_fusion = define_fusion

    def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]:
        define_fn = self._define_fusion

        class Model(FusionDefinition):
            def definition(self):
                define_fn(self)

            def _find_tensor_by_index(self, index: int) -> nvfuser.Tensor:
                for t in self.sched.tensors():
                    if t.index == index:
                        return t
                return None

            def multidevice_schedule(self):
                for in_tensor_index, in_dtensor in zip(self.inputs(), in_dtensors):
                    in_tensor = self._find_tensor_by_index(in_tensor_index)

                    # Set the device mesh.
                    assert (
                        in_dtensor.device_mesh.ndim == 1
                    ), "nvFuser's Python API only supports 1D meshes."
                    mesh = nvfuser.DeviceMesh(
                        in_dtensor.device_mesh.mesh.view(-1).tolist()
                    )
                    self.sched._set_device_mesh(in_tensor, mesh)

                    # Parallelize.
                    assert len(in_dtensor.placements) == 1, "Expect a 1D mesh"
                    placement: Placement = in_dtensor.placements[0]
                    if placement.is_shard():
                        dim = cast(Shard, placement).dim
                        self.sched.parallelize(
                            in_tensor, dim, nvfuser.ParallelType.mesh_x
                        )

        in_tensors = [in_dtensor.to_local() for in_dtensor in in_dtensors]
        model = Model()
        out_tensors = model.execute(in_tensors)

        for i, out_tensor in enumerate(out_tensors):
            if isinstance(out_tensor, nvfuser.DistributedTensor):
                mesh = dist.device_mesh.init_device_mesh("cuda", [out_tensor.mesh.size])
                placements: list[Placement] = []
                for parallel_type in [nvfuser.ParallelType.mesh_x]:
                    axis: int = out_tensor.axis_sharded_on(parallel_type)
                    placements.append(Replicate() if axis == -1 else Shard(axis))
                out_tensors[i] = DTensor.from_local(out_tensor.local, mesh, placements)
        return out_tensors
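One illustrative way to address the error-handling concern raised above (not part of this PR) is to fail fast on a non-callable define_fusion and on missing outputs; these helper names are hypothetical:

def _validate_define_fusion(define_fusion) -> None:
    # Illustrative guard: the wrapper ignores the callback's return value,
    # so the realistic failure mode is passing something that is not callable.
    if not callable(define_fusion):
        raise TypeError(
            f"define_fusion must be a callable taking a FusionDefinition, "
            f"got {type(define_fusion).__name__}"
        )

def _validate_outputs(out_tensors) -> None:
    # Illustrative guard: surface an empty or missing result from execute()
    # with a clear message instead of failing later during DTensor conversion.
    if not out_tensors:
        raise RuntimeError("FusionDefinition.execute() returned no outputs.")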

@wujingyue wujingyue changed the title A custom op that wraps a FusionDefinition and takes/produces DTensors. A FusionDefinition wrapper that takes/produces DTensors. Jan 16, 2025
@wujingyue wujingyue changed the base branch from main to wjy/dist January 19, 2025 08:03
@wujingyue wujingyue marked this pull request as ready for review January 19, 2025 17:40
@wujingyue
Collaborator Author

!test
