Add support for 32B and 64B swizzles to hopper matmul scheduler (#3544)
This PR adds support for 32B and 64B swizzles to StMatrix indexing and
to the hopper matmul scheduler.

### Key Index Change
The number of distinct swizzle rows is the swizzle size in bytes divided by the
size of a megabank (16B). The number of times the swizzle pattern is repeated to
fill the core (8, 8) matrix is the number of swizzle rows (8) divided by the
number of distinct rows.

```cpp
MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);
constexpr int64_t megabank_size_bytes = 16;
const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;

int row = ...;
int col = ...;
constexpr int64_t swizzle_row_size = 8;
const int64_t swizzle_row_repetitions = swizzle_row_size / distinct_swizzle_row_size;
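// Collapse the 8 rows of the core matrix onto the distinct swizzle rows:
// when the pattern repeats (32B and 64B swizzles), adjacent rows share a swizzle row.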
int64_t row_in_swizzle_pattern = (row % swizzle_row_size) / swizzle_row_repetitions;
int64_t swizzle_col = col ^ row_in_swizzle_pattern;
```
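
As a sanity check on the two formulas above, here is a small self-contained sketch (illustrative only, not part of this PR). It reuses the 16B megabank size and the 8-row core matrix from the description; the tiny `main` harness is an assumption added purely for demonstration:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  constexpr int64_t megabank_size_bytes = 16; // shared-memory megabank size
  constexpr int64_t swizzle_row_size = 8;     // rows in a core (8, 8) matrix

  const int64_t swizzle_sizes[] = {32, 64, 128};
  for (int64_t swizzle_bytes : swizzle_sizes) {
    // Distinct swizzle rows and how often the pattern repeats across the 8 rows.
    const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
    const int64_t swizzle_row_repetitions = swizzle_row_size / distinct_swizzle_row_size;
    std::cout << swizzle_bytes << "B swizzle: " << distinct_swizzle_row_size
              << " distinct rows, repeated " << swizzle_row_repetitions
              << " times\n";
  }
  return 0;
}
```

For a 64B swizzle, for example, there are 4 distinct rows repeated twice, so `row_in_swizzle_pattern = (row % 8) / 2` takes the values 0-3 and each pair of adjacent rows is XORed against the column with the same value.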

### Testing Changes
* Added `mma_macro` as a testing value.
* Created a separate test suite called `Swizzle/HopperMatmulSchedulerTest`
to test `32B`, `64B`, and `128B` swizzles.
rdspring1 authored Dec 9, 2024
1 parent 07be8b9 commit 5a2184c
Showing 5 changed files with 144 additions and 42 deletions.
66 changes: 45 additions & 21 deletions csrc/device_lower/pass/index.cpp
@@ -1559,15 +1559,15 @@ void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
}

static DataType getMmaInputAType(MmaMacro macro) {
int warp_group_size = isHopper(macro) ? 128 : 32;
int size = getM(macro) * getK(macro) / warp_group_size /
2 /* halves per 32bit register */;
int64_t warp_group_size = isHopper(macro) ? 128L : 32L;
int64_t size = getM(macro) * getK(macro) / warp_group_size /
2L /* halves per 32bit register */;
return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
}

static DataType getMmaInputBType(MmaMacro macro) {
int size = getN(macro) * getK(macro) / 32 /* threads per warp */ /
2 /* halves per 32bit register */;
int64_t size = getN(macro) * getK(macro) / 32L /* threads per warp */ /
2L /* halves per 32bit register */;
return ArrayType{std::make_shared<DataType>(DataType::UInt32), (size_t)size};
}

@@ -1842,8 +1842,8 @@ Val* hardCodedIndexGenerationForStMatrix(
// To account for the threadIdx.y, we have to add it to the offset:
// offset_from_tdy = threadIdx.y * tma_m * tma_n * 2 (half)
//
// Now, lets apply stmatrix tile to the TMA Box.
// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
// Now, let's apply stmatrix tile (16, 16) to the TMA Box [NO(2), M(64), NI(64)].
// [NO(2), MO(4), MI(16), NIO(4), NII(16)].
//
// A warp group of 128 threads contains four warps. StMatrix is a warp-level
// operation, so four StMatrix operations can be issued simultaneously by the
@@ -1865,6 +1865,7 @@ Val* hardCodedIndexGenerationForStMatrix(
// domain is scheduled as [NO(2), M(64), NI(64)]. Therefore, we must store the
// data in shared memory in [M(64), NI(64)] contiguous tiles.
//
// NOTE: This offset is skipped if the for-loop is trivial
// To account for the outer_index, we have to add it to the offset:
// offset_from_outer_index = outer_index * tma_m * NI(64) * 2 (half)
//
@@ -1928,8 +1929,13 @@ Val* hardCodedIndexGenerationForStMatrix(
// with the 8 rows of the matrix to avoid bank conflicts. This swizzle pattern
// is repeated along the rows of the TMA box.
//
// The number of distinct swizzle rows is the swizzle size in bytes divided by
// the size of a megabank (16B). The number of times the swizzle pattern is
// repeated to fill the core (8, 8) matrix is the number of swizzle rows (8)
// divided by the number of distinct rows.
//
// Swizzle column
// row_in_swizzle_pattern = row % swizzle_row_size(8)
// row_in_swizzle_pattern = (row % swizzle_row_size(8)) / swizzle_repetitions
// swizzle_col = column XOR row_in_swizzle_pattern
//
// Calculate Tile Offset
@@ -1939,7 +1945,7 @@ Val* hardCodedIndexGenerationForStMatrix(
//
// Get shared memory offset
// smem_offset = offset_from_tdy + offset_from_outer_index + tile_offset
Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* hardCodedIndexGenerationForStMatrixSwizzle(
const LoadStoreOp* ldst,
ForLoop* loop,
const int64_t stsm_m_tile,
@@ -1958,16 +1964,19 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(

NVF_ERROR(ldst->out()->isA<TensorView>());
TensorView* out_tv = ldst->out()->as<TensorView>();
NVF_ERROR(getSwizzle(out_tv) == MmaInputSmemSwizzle::B128);
MmaInputSmemSwizzle swizzle = getSwizzle(out_tv);
int64_t swizzle_bytes = getBytesFromSwizzle(swizzle);

// Constants
constexpr int64_t dtype_size = 2;
constexpr int64_t warp_size = 32;
constexpr int64_t swizzle_row_size = 8;
constexpr int64_t stsm_column_size = 8;
constexpr int64_t swizzle_n_tile = 64;
constexpr int64_t megabank_size_bytes = 16;

// Derived constants
const int64_t swizzle_n_tile = swizzle_bytes / dtype_size;
const int64_t distinct_swizzle_row_size = swizzle_bytes / megabank_size_bytes;
constexpr int64_t stsm_column_stride = stsm_column_size * dtype_size;
const int64_t swizzle_n_iter = swizzle_n_tile / stsm_n_tile;
const int64_t swizzle_n_tile_stride = swizzle_n_tile * dtype_size;
@@ -2000,8 +2009,6 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* warp_id = SimplifyingIrBuilder::divExpr(TDX, warp_size_val);
Val* lane_id = SimplifyingIrBuilder::modExpr(TDX, warp_size_val);

Val* outer_index =
SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
Val* inner_index =
SimplifyingIrBuilder::modExpr(loop->index(), swizzle_n_iter_val);

@@ -2021,6 +2028,17 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
// Swizzle Column
Val* row_in_swizzle_pattern =
SimplifyingIrBuilder::modExpr(row, swizzle_row_size_val);

// The swizzle pattern is repeated to fill the (8, 8) matrix for 64B and 32B
// swizzles. swizzle_row_iter is the number of repetitions needed to fill 8 rows
// with distinct swizzle rows.
const int64_t swizzle_row_iter = swizzle_row_size / distinct_swizzle_row_size;
if (swizzle_row_iter > 1) {
Val* swizzle_row_iter_val =
IrBuilder::create<Val>(swizzle_row_iter, DataType::Index);
row_in_swizzle_pattern = SimplifyingIrBuilder::divExpr(
row_in_swizzle_pattern, swizzle_row_iter_val);
}
Val* swizzle_col = bitwise_xor(col, row_in_swizzle_pattern);

// Calculate Tile Offset
@@ -2031,16 +2049,22 @@ Val* hardCodedIndexGenerationForStMatrix128BSwizzle(
Val* offset = SimplifyingIrBuilder::addExpr(row_offset, col_offset);

// Calculate Tile offset
Val* tile_offset = IrBuilder::mulExpr(outer_index, tile_stride_val);
// Skip tile offset if loop is trivial.
if (!loop->stop()->isOneInt()) {
Val* outer_index =
SimplifyingIrBuilder::divExpr(loop->index(), swizzle_n_iter_val);
Val* tile_offset =
SimplifyingIrBuilder::mulExpr(outer_index, tile_stride_val);
offset = SimplifyingIrBuilder::addExpr(tile_offset, offset);
}

// Calculate TDY offset
Val* tdy_offset = IrBuilder::mulExpr(TDY, tdy_stride_val);
Val* tdy_offset = SimplifyingIrBuilder::mulExpr(TDY, tdy_stride_val);
offset = SimplifyingIrBuilder::addExpr(tdy_offset, offset);

// Create shared memory TensorIndex
Val* out_index = SimplifyingIrBuilder::addExpr(
IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)),
SimplifyingIrBuilder::addExpr(
tdy_offset, SimplifyingIrBuilder::addExpr(tile_offset, offset)));
IrBuilder::baseAddressExpr(ir_utils::getTvOutput(ldst)), offset);
Val* out = IrBuilder::create<kir::TensorIndex>(
dynamic_cast<TensorView*>(ldst->out()), out_index);
return out;
@@ -2092,11 +2116,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
ldst, for_loops_[0], m_tile, n_tile, m, n);
break;
case MmaInputSmemSwizzle::B128:
out = hardCodedIndexGenerationForStMatrix128BSwizzle(
case MmaInputSmemSwizzle::B64:
case MmaInputSmemSwizzle::B32:
out = hardCodedIndexGenerationForStMatrixSwizzle(
ldst, for_loops_[0], m_tile, n_tile, m, n);
break;
case MmaInputSmemSwizzle::B32:
case MmaInputSmemSwizzle::B64:
default:
NVF_ERROR("Unsupported Swizzle Type for StMatrix");
}
18 changes: 7 additions & 11 deletions csrc/scheduler/hopper_multi_matmul.cpp
@@ -1027,13 +1027,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
const int64_t tma_m = getM(params_->mma_macro);
const int64_t tma_n = getN(params_->mma_macro);

NVF_ERROR(
tma_n >= 64,
"Scheduler only supports 128B swizzle that requires N dimension of MMA ",
"macro to be >= 64, but received ",
tma_n,
".");

fusion_->manage("st_matrix_m_tile", stmatrix_tile_m);
fusion_->manage("st_matrix_n_tile", stmatrix_tile_n);
fusion_->manage("st_matrix_m", tma_m);
@@ -1084,12 +1077,14 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() {
dc->setAllocationDomain(s.as<IterDomain*>(), true);
}

MmaInputSmemSwizzle swizzle = tmaSwizzleSharedMemory(d_smem);

// Schedule shared memory cache; Output from StMatrix
scheduleStMatrixForMmaOutput(
d_smem, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n);
d_smem, swizzle, stmatrix_tile_m, stmatrix_tile_n, tma_m, tma_n);

// Schedule global memory output; Output from TMA Store
scheduleTMAStoreForMmaOutput(d, tma_m, tma_n);
scheduleTMAStoreForMmaOutput(d, swizzle, tma_m, tma_n);
}
}
}
@@ -1247,6 +1242,7 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() {

void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
int64_t tile_m,
int64_t tile_n,
int64_t tma_m,
@@ -1263,7 +1259,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(
mma_utils::MmaSwizzler::scheduleMmaOutputAllocation(tv->getLoopDomain());

// Create tma store allocation domain with swizzle
scheduleTMAStoreForMmaOutput(tv, tma_m, tma_n);
scheduleTMAStoreForMmaOutput(tv, swizzle, tma_m, tma_n);

tv->setLoopDomain(s.as<IterDomain*>());

Expand All @@ -1290,6 +1286,7 @@ void HopperMultipleMatmulScheduler::scheduleStMatrixForMmaOutput(

void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
int64_t m,
int64_t n) {
// [M(m), N(n)] -> [MO(1), MI(m), NO(1), NI(n)]
@@ -1301,7 +1298,6 @@ void HopperMultipleMatmulScheduler::scheduleTMAStoreForMmaOutput(
// [BDX, BDY, TDY, MO(1), NO(1), MI, NI]
// skip the first 5 iterDomains
int64_t num_ids_to_skip = 5;
MmaInputSmemSwizzle swizzle = MmaInputSmemSwizzle::B128;

NVF_ERROR(num_ids_to_skip >= 0);
if (swizzle == MmaInputSmemSwizzle::None) {
7 changes: 6 additions & 1 deletion csrc/scheduler/hopper_multi_matmul.h
@@ -182,14 +182,19 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler {
//! registers to shared memory.
void scheduleStMatrixForMmaOutput(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
int64_t tile_m,
int64_t tile_n,
int64_t tma_m,
int64_t tma_n);

//! Schedules the copy operation of the output of a Mma op, which resides in
//! shared memory, to global memory.
void scheduleTMAStoreForMmaOutput(TensorView* tv, int64_t m, int64_t n);
void scheduleTMAStoreForMmaOutput(
TensorView* tv,
MmaInputSmemSwizzle swizzle,
int64_t m,
int64_t n);

// Map TensorView's iterDomain to its ValGroup.
// Then, find the MatmulDimRole for the ValGroup.
61 changes: 52 additions & 9 deletions tests/cpp/test_matmul_scheduler.cpp
@@ -3119,32 +3119,51 @@ using HopperMatmulSchedulerTestParams = std::tuple<
bool, // b_k_inner
int64_t, // M
int64_t, // N
int64_t // K
>;
int64_t, // K
MmaMacro>;

std::string hopperTestName(
const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
std::ostringstream os;
bool use_smem_epilogue;
bool a_k_inner, b_k_inner;
int64_t M, N, K;
std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = info.param;
MmaMacro mma_macro;
std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
info.param;
os << (a_k_inner ? "K" : "M");
os << (b_k_inner ? "K" : "N");
os << "_" << M << "_" << N << "_" << K;
os << "_MmaMacro_" << mma_macro_to_str_map.at(mma_macro);
if (use_smem_epilogue) {
os << "_tma_store";
}
return os.str();
}

std::string hopperTestNameSwizzle(
const testing::TestParamInfo<HopperMatmulSchedulerTestParams>& info) {
std::unordered_map<MmaMacro, std::string> mma_macro_to_swizzle_str_map = {
{MmaMacro::Hopper_64_256_16, "128BSwizzle"},
{MmaMacro::Hopper_64_128_16, "128BSwizzle"},
{MmaMacro::Hopper_64_64_16, "128BSwizzle"},
{MmaMacro::Hopper_64_32_16, "64BSwizzle"},
{MmaMacro::Hopper_64_16_16, "32BSwizzle"}};
MmaMacro mma_macro = std::get<6>(info.param);
std::ostringstream os;
os << hopperTestName(info);
os << "_" << mma_macro_to_swizzle_str_map.at(mma_macro);
return os.str();
}

class HopperMatmulSchedulerTest
: public NVFuserFixtureParamTest<HopperMatmulSchedulerTestParams> {
protected:
void SetUp() {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(9, 0, 10, 0);

std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K) = GetParam();
std::tie(use_smem_epilogue, a_k_inner, b_k_inner, M, N, K, mma_macro) =
GetParam();

if (a_k_inner) {
layout = b_k_inner ? MmaLayout::TN : MmaLayout::TT;
@@ -3159,14 +3178,17 @@ class HopperMatmulSchedulerTest
// Create custom Matmul Params
MatMulTileOptions gemm_tile;
// TODO cta tile is a multiple of mma macro for hopper.
gemm_tile.cta_tile = GemmTile(128, 256, 16);
// Default cta_tile configuration is 2-CTA.
gemm_tile.cta_tile =
GemmTile(2 * getM(mma_macro), getN(mma_macro), getK(mma_macro));

// TODO warp tile is (macroM, macroN, macroK) for hopper.
gemm_tile.warp_tile = GemmTile(64, 128, 16);
gemm_tile.warp_tile =
GemmTile(getM(mma_macro), getN(mma_macro), getK(mma_macro));

mparams.supported_vec_size = {8, 8, 4};

mparams.mma_macro = MmaMacro::Hopper_64_128_16;
mparams.mma_macro = mma_macro;

mparams.use_smem_epilogue = use_smem_epilogue;

@@ -3203,6 +3225,7 @@ class HopperMatmulSchedulerTest
bool use_smem_epilogue;
bool a_k_inner, b_k_inner;
int64_t M, N, K;
MmaMacro mma_macro;
std::unique_ptr<Fusion> fusion_up;
Fusion* fusion;
std::unique_ptr<FusionGuard> fusion_guard;
@@ -3275,16 +3298,36 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) {
}

INSTANTIATE_TEST_SUITE_P(
,
General,
HopperMatmulSchedulerTest,
testing::Combine(
testing::Bool(), // use_smem_epilogue
testing::Bool(), // a_k_inner
testing::Bool(), // b_k_inner
testing::Values(512), // M
testing::Values(256), // N
testing::Values(64) // K
testing::Values(64), // K
testing::Values(MmaMacro::Hopper_64_128_16) // mma_macros
),
hopperTestName);

INSTANTIATE_TEST_SUITE_P(
Swizzle,
HopperMatmulSchedulerTest,
testing::Combine(
testing::Values(true), // use_smem_epilogue
testing::Bool(), // a_k_inner
testing::Bool(), // b_k_inner
testing::Values(512), // M
testing::Values(256), // N
testing::Values(64), // K
testing::Values(
MmaMacro::Hopper_64_256_16,
MmaMacro::Hopper_64_128_16,
MmaMacro::Hopper_64_64_16,
MmaMacro::Hopper_64_32_16,
MmaMacro::Hopper_64_16_16) // mma_macros
),
hopperTestNameSwizzle);

} // namespace nvfuser