Skip to content

Commit

Permalink
more fp8 tuning for decode (pytorch#3576)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3576

X-link: facebookresearch/FBGEMM#661

This diff includes:

1. Allow specifying split-K in the common tempalate
2. Add a few more instances
3. Update tuning for some decode shapes

Reviewed By: jwfromm

Differential Revision: D68233557

fbshipit-source-id: 168fb31a2bb281a2879babcdb50751330e01c798
  • Loading branch information
mxz297 authored and facebook-github-bot committed Jan 16, 2025
1 parent 497bad6 commit 5c76d93
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ struct IntTupleHash {
static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTupleHash> rowwise_lookup_dispatch = {
// Support for decode for [1024, 5120]
{{16, 1024, 5120},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{{32, 1024, 5120},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
{{64, 1024, 5120},
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
{{128, 1024, 5120},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
// Support for decode for [5120, 1024]
{{16, 5120, 1024},
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
128,
16,
16,
1,
1,
S<8, 16, 1>,
S<8, 16, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Interwave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 8);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
using DeviceGemmInstance = DeviceGemmHelper<
128,
16,
32,
512,
16,
16,
1,
1,
S<32, 4, 1>,
S<32, 4, 1>,
S<1, 16, 1, 8>,
S<4, 4, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v2,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
}
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ at::Tensor f8f8bf16_rowwise_impl(
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
at::Tensor Y,
int KBatch = 1) {
// Get input information.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
Expand Down Expand Up @@ -194,7 +195,7 @@ at::Tensor f8f8bf16_rowwise_impl(
StrideB,
std::array<ck::index_t, NumDTensor>{0, 0},
StrideE,
1,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,19 @@ fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

at::Tensor
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

at::Tensor
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

0 comments on commit 5c76d93

Please sign in to comment.