Load mma operands to shared memory with TMA #3320

Merged: 14 commits into main from multi_matmul_tma on Nov 8, 2024

Conversation

rdspring1 (Collaborator) commented Oct 31, 2024

This PR modifies schedulePrologues to use TMA loads to move mma operands to shared memory. Stacked on #3324 and #3310.

Details

  1. Load input operands into shared memory via the CpAsyncBulkTensorTile LoadStoreOp (see the sketch after this list).
  2. Replace the LdMatrix operation with a basic set.
  3. Modify scheduleOperandSmemStores to apply swizzling to avoid bank conflicts.
  4. Refactor swizzleSharedMemory by moving the analysis component into a separate function named analyzeSwizzleSharedMemory.
  5. Create a tmaSwizzleSharedMemory function that uses analyzeSwizzleSharedMemory and then selects the appropriate TMA swizzle format.
  6. Disable loop rotation. There is an issue with TMA loads and circular buffering; it is unclear whether loop rotation is required for Hopper matmul.
  7. Expect Hopper matmul tests to give incorrect results for now.
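
A minimal sketch of the idea behind items 1 and 2, assuming nvFuser's `cacheAfter` and `setMemoryType` APIs; the helper name `scheduleOperandTmaLoad` is hypothetical, and this is not the PR's actual `schedulePrologues` code:

```cpp
// Stage one mma operand through shared memory with a TMA load
// (cp.async.bulk.tensor.tile) instead of LdMatrix.
TensorView* scheduleOperandTmaLoad(TensorView* operand) {
  // The global->shared copy is realized as a CpAsyncBulkTensorTile
  // LoadStoreOp, i.e. a TMA load.
  TensorView* operand_smem =
      operand->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
  operand_smem->setMemoryType(MemoryType::Shared);
  // Downstream reads of the shared-memory buffer then become a basic
  // set rather than an LdMatrix (item 2 above).
  return operand_smem;
}
```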

rdspring1 force-pushed the hopper_matmul_tests branch from d8bc1a6 to 7c8f375 on November 1, 2024 02:54
rdspring1 marked this pull request as ready for review on November 1, 2024 03:08
rdspring1 changed the title from "Loading operands with TMA" to "Load mma operands to shared memory with TMA" on Nov 1, 2024
rdspring1 force-pushed the multi_matmul_tma branch 2 times, most recently from a04848d to 3002112, on November 2, 2024 00:41
rdspring1 (Collaborator, Author) commented:

!test

jacobhinkle (Collaborator) left a comment:

First pass. Looks good so far. One question: how will we handle partial vectorization?

Review thread on csrc/scheduler/hopper_multi_matmul.cpp (outdated, resolved):
```cpp
  return MmaInputSmemSwizzle::None; // No need to swizzle in this case.
}

// 128B swizzle results in 8 x 8 matrix given half precision inputs.
```
A collaborator commented:

Are we only using 128B swizzle or do we plan to support the smaller swizzles as well?

rdspring1 (Collaborator, Author) replied:

I need to update this comment.
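
For reference, Hopper TMA also supports 32B and 64B swizzle periods in addition to 128B. A minimal sketch of the selection idea behind tmaSwizzleSharedMemory, assuming an `MmaInputSmemSwizzle` enum with `B32`/`B64`/`B128` members; this illustrates the concept and is not the function's actual analysis:

```cpp
// Pick the widest TMA swizzle whose byte period evenly tiles the inner
// dimension of the shared-memory box.
MmaInputSmemSwizzle chooseTmaSwizzle(int64_t inner_dim_bytes) {
  if (inner_dim_bytes % 128 == 0) {
    return MmaInputSmemSwizzle::B128; // 64 half-precision elements per period
  } else if (inner_dim_bytes % 64 == 0) {
    return MmaInputSmemSwizzle::B64;
  } else if (inner_dim_bytes % 32 == 0) {
    return MmaInputSmemSwizzle::B32;
  }
  return MmaInputSmemSwizzle::None; // Cannot swizzle this box shape.
}
```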

csrc/scheduler/hopper_multi_matmul.cpp (resolved)
tests/cpp/test_matmul_scheduler.cpp (resolved)
tests/cpp/test_matmul.cpp (resolved)
```
@@ -1332,6 +1396,8 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
  }
}

/*
// TODO Investigate. Disable loop rotation with tma circular buffering
```
A collaborator commented:

My first thought was: can we just disable this parameter for loop rotation? But I realized that we do not actually respect the rotate_ldmatrix_out_of_main_loop parameter in MatmulParams. In fact, I just ran `git log --patch csrc/scheduler/ | grep -C20 rotate_` and sifted through the results. I don't think we've ever used that parameter :-D cc @zasdfgbnm.

Base automatically changed from hopper_matmul_tests to main November 3, 2024 17:26
rdspring1 (Collaborator, Author) replied:

> how will we handle partial vectorization?

Do you mean when the tensor is not 16B aligned? You can overcopy with TMA, cp.async, or regular LDG + STS.

jacobhinkle (Collaborator) commented Nov 5, 2024:

> > how will we handle partial vectorization?
>
> Do you mean when the tensor is not 16B aligned? You can overcopy with TMA, cp.async, or regular LDG + STS.

Yeah, exactly. So if we had K=60 as the inner dimension of each of the operands, the Ampere scheduler would need to handle them differently when we generate the kernel, since we can only do 4-element reads for the cp.async call instead of 8-element reads. But I don't see where that kind of alignment analysis comes in when using TMA; will TMA handle misaligned boxes dynamically using the same compiled kernel as for fully-aligned inputs?

EDIT: is this computed on the host side in the TMA descriptor?

rdspring1 (Collaborator, Author) commented Nov 5, 2024:

TMA should automatically handle the case when K=60 by filling the out-of-bounds accesses.
If the tensor is not 16B aligned, TMA will fail and you need to use regular LDG + STS accesses.
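
A minimal sketch of the host-side part of this, using the CUDA driver API's `cuTensorMapEncodeTiled`; the tile shape, data type, and swizzle choice below are illustrative, not what nvFuser actually emits:

```cpp
#include <cuda.h>
#include <cstdint>

// Encode a 2D TMA descriptor for a row-major [M, K] half-precision operand.
// The global shape is baked into the descriptor, so the hardware clamps each
// box against it at run time and zero-fills the out-of-bounds remainder
// (e.g. columns 60..63 of a 64-wide box when K = 60). The base address must
// be 16B aligned and strides must be multiples of 16 bytes; otherwise
// encoding fails and a fallback LDG + STS path is needed.
CUresult encodeOperandTmaMap(
    CUtensorMap* tmap, void* gmem_base, uint64_t M, uint64_t K) {
  const cuuint64_t global_dim[2] = {K, M};      // innermost dimension first
  const cuuint64_t global_strides[1] = {K * 2}; // row stride in bytes (fp16)
  const cuuint32_t box_dim[2] = {64, 64};       // one smem tile per load
  const cuuint32_t elem_strides[2] = {1, 1};    // dense boxes
  return cuTensorMapEncodeTiled(
      tmap,
      CU_TENSOR_MAP_DATA_TYPE_FLOAT16,
      /*tensorRank=*/2,
      gmem_base,
      global_dim,
      global_strides,
      box_dim,
      elem_strides,
      CU_TENSOR_MAP_INTERLEAVE_NONE,
      CU_TENSOR_MAP_SWIZZLE_128B,
      CU_TENSOR_MAP_L2_PROMOTION_NONE,
      CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE); // OOB elements are zero-filled
}
```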

jacobhinkle (Collaborator) left a comment:

LGTM

rdspring1 (Collaborator, Author) commented:

!test

jacobhinkle (Collaborator) commented:

Looks like you just need to guard AmpereMatmulBroadcastBatch. I noticed I needed this in #3278, but I was too lazy to merge that upstream to this PR for you: https://github.com/NVIDIA/Fuser/pull/3278/files#diff-64fc4e7bfbc5b9f95ac3dc5823bd99b683b048926805c13310ce6a8ef8032289R147-R148
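
A sketch of what that guard could look like, assuming the `NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD` macro from the test utilities; the exact guard added in #3278 may differ:

```cpp
TEST_F(MatmulTest, AmpereMatmulBroadcastBatch) {
  // Skip on devices outside [sm_80, sm_90): the Ampere-specific
  // scheduling in this test is invalid elsewhere (e.g. on Hopper).
  NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(8, 0, 9, 0);
  // ... existing test body ...
}
```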

rdspring1 (Collaborator, Author) commented:

!test

rdspring1 merged commit 114e9a1 into main on Nov 8, 2024 (47 checks passed)
rdspring1 deleted the multi_matmul_tma branch on November 8, 2024 15:33
jacobhinkle added a commit that referenced this pull request Nov 13, 2024
Stacked on #3320

This PR:
* Schedules the MMA instruction result for the HopperMultiMatmulScheduler.
* Removes some unused methods that are no longer necessary.
* Checks that there is "no prologue"; specifically, that we have `gmem -LoadStoreOp-> smem -MmaOp->`. This currently cannot be done unless we create the MmaOp at definition using `fusedMultiplySum` (see #1628 and the sketch below).
* Checks that the MmaOp output has logical order MNK. If not, a root->logical reorder should have been created at definition (maybe this should be made easier as an option in `fusedMultiplySum`).

This PR does not schedule split-K or TMA stores of the output.
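
A minimal sketch of creating the MmaOp at definition via `fusedMultiplySum`, assuming the test helpers `makeContigTensor` and `broadcast`; the shapes are illustrative:

```cpp
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* a = makeContigTensor(2, DataType::Half); // [M, K]
TensorView* b = makeContigTensor(2, DataType::Half); // [N, K]
fusion.addInput(a);
fusion.addInput(b);

// Broadcast to [M, 1, K] and [1, N, K] so M, N, K line up.
TensorView* a_b = broadcast(a, {false, true, false});
TensorView* b_b = broadcast(b, {true, false, false});

// Fused multiply + sum over K (axis 2) creates a single MmaOp whose
// output has logical order MNK, satisfying the checks above.
TensorView* mma_result = fusedMultiplySum(a_b, b_b, {2});
fusion.addOutput(mma_result);
```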

---------

Co-authored-by: Ryan Spring <[email protected]>