
Matmul with DID loop split #3651

Open · wants to merge 8 commits into main from pm/matmul_loop_split
Conversation

Priya2698 (Collaborator)

This PR modifies hasTrivialAllocationDomain to consider whether the TensorView has a DID loop split. In that case, we compare the corresponding IterDomains of the logical and allocation domains across all but the sharded logical axis.

Note: This does not guarantee that MatmulOp with a non-trivial stride order will work for DID loop split. I suspect that will require some additional changes to the MatmulOp::evaluate method.
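For intuition, here is a standalone toy sketch of the comparison described above. The struct and function names are illustrative stand-ins, not nvFuser's actual implementation, and the real check must also account for the DID loop split introducing extra IterDomains on the sharded axis:

#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in for an IterDomain; real nvFuser IterDomains carry far more state.
struct IterDomain {
  int64_t id;
};

// Allocation is "trivial" here when it matches the logical domain on every
// axis except the sharded one, which a DID loop split is allowed to rewrite.
bool trivialExceptShardedAxis(
    const std::vector<IterDomain>& logical,
    const std::vector<IterDomain>& allocation,
    int64_t sharded_logical_axis) {
  if (logical.size() != allocation.size()) {
    return false;
  }
  for (size_t i = 0; i < logical.size(); ++i) {
    if (static_cast<int64_t>(i) == sharded_logical_axis) {
      continue;  // the DID-split axis is exempt from the comparison
    }
    if (logical[i].id != allocation[i].id) {
      return false;
    }
  }
  return true;
}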

Priya2698 (Collaborator, Author)

!build

Priya2698 requested review from samnordmann and wujingyue and removed the review request for samnordmann on December 28, 2024 at 06:10.
Priya2698 (Collaborator, Author)

!test

Review threads on tests/python/test_multidevice.py (2) and csrc/ir/utils.cpp (1) were marked outdated and resolved.
Priya2698 force-pushed the pm/matmul_loop_split branch from ccffce6 to 242e4dd on January 14, 2025 at 23:53.
github-actions bot commented Jan 15, 2025

PR Reviewer Guide 🔍

(Review updated until commit 2aad6e6)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The hasTrivialAllocationDomain function has been modified to consider whether the TensorView has a DID loop split. This change may affect the behavior of the MatmulOp class.

MatmulOp::MatmulOp(IrBuilderPasskey passkey, Val* out, Val* in_a, Val* in_b)
    : Expr(passkey) {
  addOutput(out);
  addInput(in_a);
Potential Bug

The inferShape function has been modified to take a const ExpressionEvaluator& instead of an ExpressionEvaluator&. This may cause issues if any code path relies on mutating the evaluator through the now-const reference (see the illustration after the snippet below).

std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShape(
    const TensorView* tv,
    std::vector<Val*> symbolic_sizes,
    std::vector<bool> expand_flags,
    const ExpressionEvaluator& expr_eval) {
  FUSER_PERF_SCOPE("fusion_executor::allocations::inferShape");

  // Allocate should be provided for intermediates. We just need to
  // grab a chunk of memory of the size dictated by
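The const-reference concern, in general terms: if evaluation memoizes results, a const reference can only reach const member functions, so the evaluator must either recompute or keep its cache in mutable state. A generic illustration of that pattern follows; this is not nvFuser's ExpressionEvaluator API.

#include <cstdint>
#include <string>
#include <unordered_map>

class Evaluator {
 public:
  // The non-const overload can memoize computed values in the cache.
  int64_t evaluate(const std::string& sym) {
    auto it = cache_.find(sym);
    if (it == cache_.end()) {
      it = cache_.emplace(sym, compute(sym)).first;
    }
    return it->second;
  }

  // Callers holding a `const Evaluator&` bind to this overload, which
  // cannot touch the cache and must recompute each time.
  int64_t evaluate(const std::string& sym) const {
    return compute(sym);
  }

 private:
  // Placeholder for real symbolic evaluation.
  int64_t compute(const std::string& sym) const {
    return static_cast<int64_t>(sym.size());
  }

  std::unordered_map<std::string, int64_t> cache_;
};

In those terms, the bot's concern is that any step inside inferShape that previously relied on mutating the evaluator must now go through const-safe entry points.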
Test Update

A new test case has been added to exercise the matmul operation with DID loop split.

@pytest.mark.mpi
def test_matmul_loop_split(multidevice_test):
    class Model(FusionDefinition):
        def __init__(self, num_devices, batch, sequence, hidden):
            super().__init__()
            self._num_devices = num_devices
            self._batch = batch
            self._sequence = sequence
            self._hidden = hidden

        def definition(self):
            d, b, s, e = self._num_devices, self._batch, self._sequence, self._hidden
            self.inp = self.define_tensor([b, s, e])
            self.weight = self.define_tensor([e, d * e])
            self.out = self.ops.matmul(self.inp, self.weight)
            self.add_output(self.out)

        def multidevice_schedule(self):
            for t in [self.inp, self.weight, self.out]:
                self.sched._set_device_mesh(t, mesh)

            # Shard N for weight (K, N)
            self.sched.split(self.weight, -1, d, False)
            self.sched.parallelize(self.weight, -2, nvfuser.ParallelType.mesh_x)
            self.sched.set_allocation_as_loop(self.weight)

            # Output of linear: {.., i{M}, i{N}, r{K}}
            # Shard N -> axis(-2)
            self.sched.split(self.out, -2, d, False)
            self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x)
            self.sched.set_allocation_as_loop(self.out)

    d = multidevice_test.size
    mesh = nvfuser.DeviceMesh(range(d))
    rank = multidevice_test.rank

    torch.cuda.set_device(multidevice_test.local_rank)

    b, s, e = 2, 1024, 768
    inp_tensor = torch.randn(b, s, e, device="cuda")
    unsharded_weight_tensor = torch.randn(e, d * e)
    sharded_weight_tensor = multidevice_test.shard_tensor(
        unsharded_weight_tensor, -1, mesh
    )

    fd = Model(d, b, s, e)
    out_tensors = fd.execute([inp_tensor, sharded_weight_tensor])

    # [b, s, d*e]
    unsharded_out_tensor = torch.matmul(inp_tensor.cpu(), unsharded_weight_tensor)
    expected_out_tensor = multidevice_test.shard_tensor(unsharded_out_tensor, -1, mesh)
    # rtol is the same as the default for fp32. atol is slightly increased.
    torch.testing.assert_close(
        out_tensors[0], expected_out_tensor.squeeze(0), rtol=1.3e-6, atol=1e-3
    )
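Assuming the usual MPI-launched pytest workflow for these multidevice tests (e.g. mpirun -np 2 pytest tests/python/test_multidevice.py -k test_matmul_loop_split), d equals the number of ranks: each rank then holds an (e, e) shard of the (e, d*e) weight and computes its (b, s, e) slice of the output.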



Priya2698 (Collaborator, Author)

!test

Review threads on csrc/ir/nodes.cpp (2) were marked outdated and resolved.
Priya2698 (Collaborator, Author)

!test
