Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split Hopper MMA by warp-tile before instruction tile #3642

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
65 changes: 49 additions & 16 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,53 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
bool is_mma_result) {
// TODO Add constraints

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};

// Original: [..., Mo, No, Mi, Ni]
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
// After Split: [..., Mo, No, Mio, Mii, Nio, Nii]
tv->reorder({{apply_k_dim_offset(-3), apply_k_dim_offset(-2)}});
// After Reorder: [..., Mo, No, Mio, Nio, Mii, Nii]
tv->merge(apply_k_dim_offset(-4));
// After Merge: [..., Mo, No, Mio * Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-3))->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo, No, Mio * Nio (TIDy), Mii, Nii]
// The input is originally block tiled so that the inner dims are the CTA tile
// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
if (is_mma_result) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: since there is no code in common between these branches, we should split this into two separate functions.

// Original: [..., M, N, K]
tv->split(-3, params_->tile_sizes.warp_tile.m);
tv->split(-3, getM(params_->mma_macro));
tv->split(-2, params_->tile_sizes.warp_tile.n);
tv->split(-2, getN(params_->mma_macro));
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
Comment on lines +47 to +49
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdspring1 is this enough or is #3616 still needed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is all that is required for scheduler changes.

// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki]
tv->reorder({
{-9, -9}, // Mo
{-8, -6}, // Mw
{-7, -3}, // Mi
{-6, -8}, // No
{-5, -5}, // Nw
{-4, -2}, // Ni
{-3, -7}, // Ko
{-2, -4}, // Kw
{-1, -1}, // Ki
});
// After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki]
tv->merge(-9);
// After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni]
tv->axis(-8)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki]
} else {
// Original: [..., M, N]
tv->split(-2, params_->tile_sizes.warp_tile.m);
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, params_->tile_sizes.warp_tile.n);
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
tv->reorder({
{-3, -5},
{-2, -3},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
tv->merge(-6);
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
tv->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
}

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
Expand Down Expand Up @@ -490,8 +523,8 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
// tile is a multiple of the macro size because stmatrix stores results from
// wgmma to shared memory. For maximum inlining and to reduce shared memory
// usage, the tma tile is mma_macro size.
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);
const int64_t tma_m = params_->tile_sizes.warp_tile.m;
const int64_t tma_n = params_->tile_sizes.warp_tile.n;

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
Expand Down
76 changes: 76 additions & 0 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4243,4 +4243,80 @@ TEST_F(HopperMatmulTest, HSH_TT_UseScheduler) {
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

// This tests that we can use a small instruction tile with a medium size
// warpgroup tile and a large CTA tile.
TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
Fusion fusion;
FusionGuard fg(&fusion);

constexpr int64_t M = 2048, N = 2048, K = 8192;
const auto dtype = DataType::Half;

auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M
auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N
fusion.addInput(tv0);
fusion.addInput(tv1);

auto tv2 = fusedMultiplySum(tv0, tv1, {0});

// Reorder the accumulator as [M, N, K]
// [K, M, N] -> [M, N, K]
tv2->reorder({{-3, -1}});
tv2->commitLeafToLogical();

auto tv3 = castOp(DataType::Half, tv2);
fusion.addOutput(tv3);

auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
auto a_ref = at::randn({K, M, 1}, options);
auto b_ref = at::randn({K, 1, N}, options);
auto out_ref = at::matmul(a_ref.squeeze().t(), b_ref.squeeze()).to(at::kHalf);

MatMulTileOptions gemm_tile;
// Regardless of the instruction, this should result in 2 warp groups i.e. 256
// threads
gemm_tile.cta_tile = GemmTile(256, 256, 32);
gemm_tile.warp_tile = GemmTile(128, 128, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
mparams.mma_macro = MmaMacro::Hopper_64_64_16;
mparams.tile_sizes = gemm_tile;
mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currrently hit a bank
// conflict.
// TODO: enable smem epilogue once stmatrix is updated
mparams.use_smem_epilogue = false;
mparams.cluster_dims = {2, 1, 1};
mparams.promote_prologue_smem_reuse = false;

SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
->schedule(&fusion, &mparams);

std::vector<c10::IValue> inputs = {a_ref, b_ref};

KernelExecutor ke;
ke.compile(&fusion, inputs);
EXPECT_TRUE(getBankConflictInfo(ke.kernel()).empty());
EXPECT_FALSE(
PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse(ke.kernel()));

auto cg_outputs = ke.run(inputs);

// Check number of launched threads matches what we expect
EXPECT_EQ(ke.lastLaunchParams().bdimx(), 128);
EXPECT_EQ(ke.lastLaunchParams().bdimy(), 4)
<< " expected 4 warp groups (BIDy==4) but found BIDy=="
<< ke.lastLaunchParams().bdimy();

// Relax tolerance for larger sum due to large K
EXPECT_TRUE(cg_outputs[0].allclose(out_ref, 1e-6 * K, 1e-6 * K));
}

} // namespace nvfuser
Loading