Load mma operands to shared memory with TMA #3320
Conversation
!test
First pass. Looks good so far. One question: how will we handle partial vectorization?
```cpp
  return MmaInputSmemSwizzle::None; // No need to swizzle in this case.
}

// 128B swizzle results in 8 x 8 matrix given half precision inputs.
```
Are we only using 128B swizzle or do we plan to support the smaller swizzles as well?
I need to update this comment.
```cpp
@@ -1332,6 +1396,8 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {
  }
}

/*
// TODO Investigate. Disable loop rotation with tma circular buffering
```
My first thought was: can we just disable this parameter for loop rotation? But I realized that we do not actually respect the `rotate_ldmatrix_out_of_main_loop` parameter in `MatmulParams`. In fact, I just ran `git log --patch csrc/scheduler/ | grep -C20 rotate_` and sifted through the results. I don't think we've ever used that parameter :-D cc @zasdfgbnm.
add tmaSwizzleSharedMemory and disable loop rotation update cacheOperandsToSmem
Do you mean when the tensor is not 16B aligned? You can overcopy with TMA, cp.async, or regular LDG + STS.

Yeah exactly. So if we had K=60 and that is the inner dimension of each of the operands, in the Ampere scheduler we need to handle them differently when we generate the kernel, since we can only do 4-element reads for the cp.async call then instead of 8-element reads. But I don't see where that kind of alignment analysis comes in when using TMA; will TMA handle misaligned boxes dynamically using the same compiled kernel as for fully-aligned inputs? EDIT: is this computed on the host side in the TMA descriptor?

TMA should automatically handle the case when K=60 by filling the out-of-bounds accesses.
LGTM
!test

Looks like you just need to guard

!test
Stacked on #3320. This PR:

* Schedules the MMA instruction result for the HopperMultipleMatmulScheduler.
* Removes some unused methods that are no longer necessary.
* Checks that there is "no prologue". Specifically, that we have `gmem -LoadStoreOp-> smem -MmaOp->`. This currently cannot be done unless we create the MmaOp at definition using `fusedMultiplySum` (see #1628).
* Checks that the MmaOp output has logical order MNK. If not, a root->logical reorder should have been created at definition (maybe this should be made easier as an option in `fusedMultiplySum`).

This PR does not schedule split-K or TMA stores of the output.

---------

Co-authored-by: Ryan Spring <[email protected]>
This PR modifies `schedulePrologues` to use TMA loads to move mma operands to shared memory. Stacked on #3324 and #3310.

Details:

* Use the `CpAsyncBulkTensorTile` LoadStoreOp to load operands to shared memory.
* Replace the `LdMatrix` operation with basic set.
* Modify `scheduleOperandSmemStores` to apply swizzling to avoid bank conflicts.
* Refactor `swizzleSharedMemory` by moving the analysis component to a separate function named `analyzeSwizzleSharedMemory`.
* Create a `tmaSwizzleSharedMemory` function that uses `analyzeSwizzleSharedMemory` and then finds the appropriate TMA swizzle format.