Skip to content

Commit

Permalink
Refactor FP8 grouped GEMM with dynamic and static versions
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#647

Refactor FP8 grouped GEMM with dynamic and static versions to unify CUTLASS and CK FP8 grouped GEMM in fbgemm

Reviewed By: jwfromm

Differential Revision: D68004072
  • Loading branch information
jiawenliu64 authored and facebook-github-bot committed Jan 10, 2025
1 parent 395f065 commit 27d65f1
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 149 deletions.
47 changes: 15 additions & 32 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,39 +522,22 @@ def quantize(self, x, w):

def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
if m_values is None:
if torch.version.cuda:
return torch.ops.fbgemm.f8f8bf16_grouped(
xq,
wq,
x_scale,
w_scale,
)
else:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq,
wq,
x_scale,
w_scale,
kernel_name=kernel_name,
)
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
xq,
wq,
x_scale,
w_scale,
kernel_name=kernel_name,
)
else:
if torch.version.cuda:
return torch.ops.fbgemm.f8f8bf16_grouped(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
)
else:
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
kernel_name=kernel_name,
)
return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
xq,
wq,
x_scale,
w_scale,
zero_start_index_M=m_values,
kernel_name=kernel_name,
)

def quantize_and_compute(self, x, w):
xq, wq, x_scale, w_scale, m_values = self.quantize(x, w)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ __global__ void set_kernel_args_kernel(
int problem_shape_buf_offset,
int stride_buf_offset,
int stride_size,
int problem_count,
int group_count,
int problem_shape_size,
int group_index,
int M,
Expand Down Expand Up @@ -239,7 +239,7 @@ __global__ void set_dynamic_kernel_args_kernel(
int problem_shape_buf_offset,
int stride_buf_offset,
int stride_size,
int problem_count,
int group_count,
int problem_shape_size,
int group_index,
int64_t* zero_start_index_M,
Expand Down Expand Up @@ -309,54 +309,51 @@ template <
int TBS_M,
int TBS_N,
int TBS_K,
bool PONG,
bool FAST_ACCUM>
std::vector<at::Tensor> f8f8bf16_grouped_impl(
bool PONG>
std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor output,
std::optional<at::Tensor> zero_start_index_M) {
int problem_count = XQ.size();
TORCH_CHECK(WQ.size() == problem_count);
if (problem_count == 0) {
int group_count = XQ.size();
TORCH_CHECK(WQ.size() == group_count);
if (group_count == 0) {
return std::vector<at::Tensor>();
}
using GroupedGemmConfigs = GroupedGemmArgs::
GroupedGemmConfigs<TB_M, TB_N, TB_K, TBS_M, TBS_N, TBS_K, PONG>;

int64_t total_output_size = 0;
std::vector<int64_t> output_sizes;
output_sizes.reserve(problem_count);
output_sizes.reserve(group_count);
at::Tensor output_args =
at::empty({problem_count}, XQ[0].options().dtype(at::kLong));
at::empty({group_count}, XQ[0].options().dtype(at::kLong));

const int64_t problem_shape_size = problem_count *
const int64_t problem_shape_size = group_count *
((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape));
const int64_t stride_size = problem_count *
const int64_t stride_size = group_count *
((int64_t)sizeof(typename GroupedGemmConfigs::StrideInputA));

at::Tensor input_args = at::empty(
{problem_count * 4 + problem_shape_size + stride_size * 3 + 1000},
{group_count * 4 + problem_shape_size + stride_size * 3 + 1000},
XQ[0].options().dtype(at::kLong));

int xq_ptr_offset = 0;
int wq_ptr_offset = problem_count * sizeof(int64_t);
int x_scale_ptr_offset = problem_count * 2 * sizeof(int64_t);
int w_scale_ptr_offset = problem_count * 3 * sizeof(int64_t);
int problem_shape_buf_offset = problem_count * 4 * sizeof(int64_t);
int wq_ptr_offset = group_count * sizeof(int64_t);
int x_scale_ptr_offset = group_count * 2 * sizeof(int64_t);
int w_scale_ptr_offset = group_count * 3 * sizeof(int64_t);
int problem_shape_buf_offset = group_count * 4 * sizeof(int64_t);
int stride_buf_offset =
problem_count * 4 * sizeof(int64_t) + problem_shape_size;
group_count * 4 * sizeof(int64_t) + problem_shape_size;

for (int i = 0; i < problem_count; ++i) {
for (int i = 0; i < group_count; ++i) {
const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
total_output_size += output_size;
output_sizes.push_back(output_size);
}

at::Tensor output_tensor =
at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));

int blockSize = 256;
int numBlocks = 1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
Expand All @@ -368,7 +365,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
"zero_start_index_M must be int64.");

// Set arguments
for (int i = 0; i < problem_count; ++i) {
for (int i = 0; i < group_count; ++i) {
int N = WQ[i].size(0);
int K = XQ[i].size(1);
TORCH_CHECK_EQ(WQ[i].size(1), K);
Expand All @@ -382,7 +379,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
w_scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
Expand All @@ -391,7 +388,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
group_count,
problem_shape_size,
i,
reinterpret_cast<int64_t*>(zero_start_index_M.value().data_ptr()),
Expand All @@ -408,7 +405,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
w_scale[i].data_ptr<GroupedGemmArgs::ElementAccumulator>()),
input_args.data_ptr<int64_t>(),
output_args.data_ptr<int64_t>(),
output_tensor.data_ptr<at::BFloat16>(),
output.data_ptr<at::BFloat16>(),
output_offset,
xq_ptr_offset,
wq_ptr_offset,
Expand All @@ -417,7 +414,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
problem_shape_buf_offset,
stride_buf_offset,
stride_size,
problem_count,
group_count,
problem_shape_size,
i,
M,
Expand Down Expand Up @@ -451,7 +448,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

typename GroupedGemmConfigs::Gemm::Arguments arguments{
cutlass::gemm::GemmUniversalMode::kGrouped,
{problem_count, problem_shape_ptr, nullptr},
{group_count, problem_shape_ptr, nullptr},
{reinterpret_cast<const GroupedGemmArgs::ElementInputA**>(xq_ptr),
stride_input_A_ptr,
reinterpret_cast<const GroupedGemmArgs::ElementInputB**>(wq_ptr),
Expand Down Expand Up @@ -510,59 +507,119 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(

C10_CUDA_KERNEL_LAUNCH_CHECK();

std::vector<at::Tensor> output_group = output_tensor.split(output_sizes);
for (int i = 0; i < problem_count; ++i) {
std::vector<at::Tensor> output_group = output.split(output_sizes);
for (int i = 0; i < group_count; ++i) {
output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)});
}
return output_group;
}

// FP8 Tensorwise grouped cutlass kernel dispatch.
template <bool FastAccum>
std::vector<at::Tensor> dispatch_fp8_grouped_kernel(
at::TensorList xq_group, // FP8
at::TensorList wq_group, // FP8
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor output,
std::optional<at::Tensor> zero_start_index_M) {
KernelMode kernel = get_grouped_kernel_mode(xq_group, wq_group);
KernelMode kernel = get_grouped_kernel_mode(XQ, WQ);
if (kernel == KernelMode::Small) {
return f8f8bf16_grouped_impl<64, 128, 128, 2, 1, 1, true, FastAccum>(
xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
return f8f8bf16_rowwise_grouped_impl<64, 128, 128, 2, 1, 1, true>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
} else if (kernel == KernelMode::Large) {
return f8f8bf16_grouped_impl<128, 128, 128, 2, 1, 1, true, FastAccum>(
xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 2, 1, 1, true>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
} else {
return f8f8bf16_grouped_impl<128, 128, 128, 1, 2, 1, true, FastAccum>(
xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 1, 2, 1, true>(
XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
}
}

std::vector<at::Tensor> f8f8bf16_grouped(
at::TensorList xq_group, // FP8
at::TensorList wq_group, // FP8
std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<at::Tensor> zero_start_index_M,
bool use_fast_accum) {
if (use_fast_accum) {
return dispatch_fp8_grouped_kernel<true>(
xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
std::optional<std::vector<at::Tensor>> output = std::nullopt,
std::optional<std::string> kernel_name = std::nullopt) {
at::Tensor Y;
int group_count = XQ.size();
if (output.has_value()) {
std::vector<at::Tensor> output_;
output_ = output.value();
TORCH_CHECK(
output_.size() == group_count,
"Output and input must have same number of groups.");
// Check that output shapes are correct.
for (int i = 0; i < group_count; i++) {
int M = XQ[i].size(0);
int N = WQ[i].size(0);
int out_M = output_[i].size(0);
int out_N = output_[i].size(1);
TORCH_CHECK(
M == out_M && N == out_N,
"Output tensors do not have the expected shape.");
TORCH_CHECK(
output_[i].dtype() == at::kBFloat16,
"Output dtype must be bfloat16.");
}
Y = at::stack(output.value(), 0);
} else {
return dispatch_fp8_grouped_kernel<false>(
xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
int64_t total_output_size = 0;
std::vector<int64_t> output_sizes;
for (int i = 0; i < group_count; ++i) {
const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
total_output_size += output_size;
output_sizes.push_back(output_size);
}
Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
}
return dispatch_fp8_grouped_kernel(XQ, WQ, x_scale, w_scale, Y, std::nullopt);
}

at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
at::Tensor zero_start_index_M,
std::optional<std::string> kernel_name = std::nullopt) {
at::Tensor Y;
int group_count = XQ.size();
int64_t total_output_size = 0;
std::vector<int64_t> output_sizes;
for (int i = 0; i < group_count; ++i) {
const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
total_output_size += output_size;
output_sizes.push_back(output_size);
}
Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
return at::stack(
dispatch_fp8_grouped_kernel(
XQ, WQ, x_scale, w_scale, Y, zero_start_index_M),
0);
}

#else

std::vector<at::Tensor> f8f8bf16_grouped(
at::TensorList xq_group, // FP8
at::TensorList wq_group, // FP8
std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<std::vector<at::Tensor>> output = std::nullopt,
std::optional<std::string> kernel_name = std::nullopt) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}

at::Tensor f8f8bf16_rowwise_grouped_dynamic(
at::TensorList XQ, // FP8
at::TensorList WQ, // FP8
at::TensorList x_scale,
at::TensorList w_scale,
std::optional<at::Tensor> zero_start_index_M,
bool use_fast_accum) {
at::Tensor zero_start_index_M,
std::optional<std::string> kernel_name = std::nullopt) {
throw std::runtime_error(
"CUDA version is older than 12.0"); // requires CUDA>=12
}
Expand Down
Loading

0 comments on commit 27d65f1

Please sign in to comment.