-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Split Hopper MMA by warp-tile before instruction tile #3642
base: main
Are you sure you want to change the base?
Conversation
!test |
The bank conflict came from stmatrix scheduling which needs to be updated. I will do that in a separate PR. For now, I've disabled smem epilogue in the included test. |
!test |
When I manually disable stmatrix but keep TMA store, I still hit a bank conflict and misaligned address in the smem read when doing the TMA store. The epilogue looks like this: asm volatile("wgmma.commit_group.sync.aligned;\n");
asm volatile("wgmma.wait_group.sync.aligned %0;\n"::"n"(0LL):"memory");
__syncthreads();
#pragma unroll
for(nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
nvfuser_index_t i51;
i51 = 4 * i50;
#pragma unroll
for(nvfuser_index_t i52 = 0; i52 < 2; ++i52) {
nvfuser_index_t i53;
i53 = i51 + (2 * i52);
Array<__half, 2, 2> T6;
#pragma unroll
for(nvfuser_index_t i54 = 0; i54 < 2; ++i54) {
T6[i54]
= __float2half(T2[(i53 + i54)]);
}
loadGeneric<__half, 2>( &T7[(i17 + (128 * i52))], &T6[0]);
}
__syncthreads();
asm volatile("fence.proxy.async;\n");
if (b24) {
Hopper::cpAsyncBulkTensorTileS2G((Hopper::CpAsyncBulkTensorTileS2GIndex<2>{ ptr19, (Array<nvfuser_index_t, 2, 1>{(i20 + (8 * i50)), i21}) }), i18);
}
__syncthreads();
asm volatile("cp.async.bulk.commit_group;\n");
asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory");
}
asm volatile("cp.async.bulk.commit_group;\n");
asm volatile("cp.async.bulk.wait_group.read %0;\n"::"n"(0LL):"memory"); The misaligned read happens with threadIdx.y = 3;
i11 = ((nvfuser_index_t)threadIdx.y) / 2; // =1
i12 = 2048 * i11; // =2048
i14 = ((nvfuser_index_t)threadIdx.y) % 2; // =1
i18 = (toSmem(T7) + i12) + (16 * i14); // =toSmem(T7) + 2064
|
mma result before this PR:
And after this PR:
|
Note that I can enable smem epilogue and the test passes if I use |
I think this covers the motivation for #3616
// K dimension is present for mma_result | ||
tv->split(-1, params_->tile_sizes.warp_tile.k); | ||
tv->split(-1, getK(params_->mma_macro)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rdspring1 is this enough or is #3616 still needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is all that is required for scheduler changes.
// size | ||
// Original: [..., M, N(, K)] | ||
// We split this into warp tiles then instruction tiles | ||
if (is_mma_result) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: since there is no code in common between these branches, we should split this into two separate functions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to remove this limitation to handle all matmul parameter configurations?
CTA tile must match warp tile K dimension for Hopper matmul but found MatMulTileOptions: warp tile [64, 256, 32], CTA tile [128, 256, 64]
// K dimension is present for mma_result | ||
tv->split(-1, params_->tile_sizes.warp_tile.k); | ||
tv->split(-1, getK(params_->mma_macro)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is all that is required for scheduler changes.
I see |
I just checked this by modifying the test to expect 5 warpgroups instead of four and by adding a Fuser/csrc/scheduler/hopper_multi_matmul.cpp Lines 677 to 690 in f5e084c
The result for me is a passing test but perf drops. |
I might be confused here. The thing is that the K dimension is treated differently from the M and N dimensions in these tile definitions. The instruction tile's K dimension is clear, and the warp tile's K dimension (I think) signifies how much data we should load at a time then we can loop to compute instructions over all the loaded data. The CTA tile's M and N dimensions specify the tiling of the output, but what does the Note that this restriction |
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K)); | ||
} | ||
|
||
TEST_F(HopperMatmulTest, ScheduleWithTranslation) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test is pretty much identical to the previous one, but it uses a MatmulOp
instead of fusedMultiplySum
. This is currently failing (passes on main) with
C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/pass/circular_buffer.cpp":160, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. No IfThenElse should exist yet:
IF ElectSync:
MBarrierArriveExpectTx(T9_s[i408] view( T9 ), 4096)
FOR i372 in iB28{16}:
FOR i375 in iB34{2}:
FOR i373 in iB31{4}:
FOR i376 in iB35{2}:
FOR i374 in iB33{8}:
T3_s___half[iblockIdx.x24{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[0] ), 128) )}, bS22{1}, iS20{( ceilDiv(( (( (( getMetaData(T0) )).logical_size ))[1] ), 16) )}, bS23{256}, iS26{1}, iB28{16}, iB34{2}, iB31{4}, iB35{2}, iB33{8}] ca_pos( 5 )
= CpAsyncBulkTensorTile( T0_g___half[iS170{( (( (( getMetaData(T0) )).logical_size ))[0] )}, iS171{( (( (( getMetaData(T0) )).logical_size ))[1] )}] )
This is on hold temporarily while I investigate decoupling math warp groups by splitting by warp tile before the TMA/MMA scheduling. That would be a different approach that would let us schedule entire K loop of one math group before the next group's K loop, allowing some epilogue overlap between math groups in addition to overlapping the DMA warps. |
!test |
PR Reviewer Guide 🔍Here are some key observations to aid the review process:
|
Currently we ignore the warp tile parameter when scheduling Hopper matmuls (see #3636). This PR introduces a test with different CTA, warp, and instruction tiles and modifies the Hopper scheduler to split by warp tile in addition to instruction tile. Note that the instruction tile split results in two serial loop domain so we wind up executing multiple mma instructions in each main loop. In the included example,
warp_tile
is 64, 128, 16 and the macro isHopper_64_8_16
. In this case, there are 128/8 = 16 instruction tiles per warp tile so the generated main loop looks like this:Fixes #3636