diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index aa2d8d2fd..0bfaf5101 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -522,39 +522,22 @@ def quantize(self, x, w):
 
     def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
         if m_values is None:
-            if torch.version.cuda:
-                return torch.ops.fbgemm.f8f8bf16_grouped(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                )
-            else:
-                return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    kernel_name=kernel_name,
-                )
+            return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
+                xq,
+                wq,
+                x_scale,
+                w_scale,
+                kernel_name=kernel_name,
+            )
         else:
-            if torch.version.cuda:
-                return torch.ops.fbgemm.f8f8bf16_grouped(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    zero_start_index_M=m_values,
-                )
-            else:
-                return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq,
-                    wq,
-                    x_scale,
-                    w_scale,
-                    zero_start_index_M=m_values,
-                    kernel_name=kernel_name,
-                )
+            return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
+                xq,
+                wq,
+                x_scale,
+                w_scale,
+                zero_start_index_M=m_values,
+                kernel_name=kernel_name,
+            )
 
     def quantize_and_compute(self, x, w):
         xq, wq, x_scale, w_scale, m_values = self.quantize(x, w)
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
similarity index 81%
rename from fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu
rename to fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
index 545402061..c16fba407 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_grouped.cu
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu
@@ -161,7 +161,7 @@ __global__ void set_kernel_args_kernel(
     int problem_shape_buf_offset,
     int stride_buf_offset,
     int stride_size,
-    int problem_count,
+    int group_count,
     int problem_shape_size,
     int group_index,
     int M,
@@ -239,7 +239,7 @@ __global__ void set_dynamic_kernel_args_kernel(
     int problem_shape_buf_offset,
     int stride_buf_offset,
     int stride_size,
-    int problem_count,
+    int group_count,
     int problem_shape_size,
     int group_index,
     int64_t* zero_start_index_M,
@@ -309,17 +309,17 @@ template <
     int TBS_M,
     int TBS_N,
     int TBS_K,
-    bool PONG,
-    bool FAST_ACCUM>
-std::vector<at::Tensor> f8f8bf16_grouped_impl(
+    bool PONG>
+std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
     at::TensorList XQ, // FP8
     at::TensorList WQ, // FP8
     at::TensorList x_scale,
     at::TensorList w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M) {
-  int problem_count = XQ.size();
-  TORCH_CHECK(WQ.size() == problem_count);
-  if (problem_count == 0) {
+  int group_count = XQ.size();
+  TORCH_CHECK(WQ.size() == group_count);
+  if (group_count == 0) {
     return std::vector<at::Tensor>();
   }
   using GroupedGemmConfigs = GroupedGemmArgs::
@@ -327,36 +327,33 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
 
   int64_t total_output_size = 0;
   std::vector<int64_t> output_sizes;
-  output_sizes.reserve(problem_count);
+  output_sizes.reserve(group_count);
   at::Tensor output_args =
-      at::empty({problem_count}, XQ[0].options().dtype(at::kLong));
+      at::empty({group_count}, XQ[0].options().dtype(at::kLong));
 
-  const int64_t problem_shape_size = problem_count *
+  const int64_t problem_shape_size = group_count *
       ((int64_t)sizeof(GroupedGemmArgs::ProblemShape::UnderlyingProblemShape));
-  const int64_t stride_size = problem_count *
+  const int64_t stride_size = group_count *
       ((int64_t)sizeof(typename GroupedGemmConfigs::StrideInputA));
 
   at::Tensor input_args = at::empty(
-      {problem_count * 4 + problem_shape_size + stride_size * 3 + 1000},
+      {group_count * 4 + problem_shape_size + stride_size * 3 + 1000},
       XQ[0].options().dtype(at::kLong));
 
   int xq_ptr_offset = 0;
-  int wq_ptr_offset = problem_count * sizeof(int64_t);
-  int x_scale_ptr_offset = problem_count * 2 * sizeof(int64_t);
-  int w_scale_ptr_offset = problem_count * 3 * sizeof(int64_t);
-  int problem_shape_buf_offset = problem_count * 4 * sizeof(int64_t);
+  int wq_ptr_offset = group_count * sizeof(int64_t);
+  int x_scale_ptr_offset = group_count * 2 * sizeof(int64_t);
+  int w_scale_ptr_offset = group_count * 3 * sizeof(int64_t);
+  int problem_shape_buf_offset = group_count * 4 * sizeof(int64_t);
   int stride_buf_offset =
-      problem_count * 4 * sizeof(int64_t) + problem_shape_size;
+      group_count * 4 * sizeof(int64_t) + problem_shape_size;
 
-  for (int i = 0; i < problem_count; ++i) {
+  for (int i = 0; i < group_count; ++i) {
     const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
     total_output_size += output_size;
     output_sizes.push_back(output_size);
   }
 
-  at::Tensor output_tensor =
-      at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
-
   int blockSize = 256;
   int numBlocks = 1;
   auto stream = at::cuda::getCurrentCUDAStream().stream();
@@ -368,7 +365,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
       "zero_start_index_M must be int64.");
 
   // Set arguments
-  for (int i = 0; i < problem_count; ++i) {
+  for (int i = 0; i < group_count; ++i) {
     int N = WQ[i].size(0);
     int K = XQ[i].size(1);
     TORCH_CHECK_EQ(WQ[i].size(1), K);
@@ -382,7 +379,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
               w_scale[i].data_ptr()),
           input_args.data_ptr(),
           output_args.data_ptr(),
-          output_tensor.data_ptr(),
+          output.data_ptr(),
          output_offset,
          xq_ptr_offset,
          wq_ptr_offset,
@@ -391,7 +388,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
          problem_shape_buf_offset,
          stride_buf_offset,
          stride_size,
-          problem_count,
+          group_count,
          problem_shape_size,
          i,
          reinterpret_cast<int64_t*>(zero_start_index_M.value().data_ptr()),
@@ -408,7 +405,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
              w_scale[i].data_ptr()),
          input_args.data_ptr(),
          output_args.data_ptr(),
-          output_tensor.data_ptr(),
+          output.data_ptr(),
          output_offset,
          xq_ptr_offset,
          wq_ptr_offset,
@@ -417,7 +414,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
          problem_shape_buf_offset,
          stride_buf_offset,
          stride_size,
-          problem_count,
+          group_count,
          problem_shape_size,
          i,
          M,
@@ -451,7 +448,7 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
 
   typename GroupedGemmConfigs::Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGrouped,
-      {problem_count, problem_shape_ptr, nullptr},
+      {group_count, problem_shape_ptr, nullptr},
       {reinterpret_cast(xq_ptr),
       stride_input_A_ptr,
       reinterpret_cast(wq_ptr),
@@ -510,59 +507,119 @@ std::vector<at::Tensor> f8f8bf16_grouped_impl(
 
   C10_CUDA_KERNEL_LAUNCH_CHECK();
 
-  std::vector<at::Tensor> output_group = output_tensor.split(output_sizes);
-  for (int i = 0; i < problem_count; ++i) {
+  std::vector<at::Tensor> output_group = output.split(output_sizes);
+  for (int i = 0; i < group_count; ++i) {
     output_group[i] = output_group[i].view({XQ[i].size(0), WQ[i].size(0)});
   }
   return output_group;
 }
 
 // FP8 Tensorwise grouped cutlass kernel dispatch.
-template <bool FastAccum>
 std::vector<at::Tensor> dispatch_fp8_grouped_kernel(
-    at::TensorList xq_group, // FP8
-    at::TensorList wq_group, // FP8
+    at::TensorList XQ, // FP8
+    at::TensorList WQ, // FP8
     at::TensorList x_scale,
     at::TensorList w_scale,
+    at::Tensor output,
     std::optional<at::Tensor> zero_start_index_M) {
-  KernelMode kernel = get_grouped_kernel_mode(xq_group, wq_group);
+  KernelMode kernel = get_grouped_kernel_mode(XQ, WQ);
   if (kernel == KernelMode::Small) {
-    return f8f8bf16_grouped_impl<64, 128, 128, 2, 1, 1, true, FastAccum>(
-        xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
+    return f8f8bf16_rowwise_grouped_impl<64, 128, 128, 2, 1, 1, true>(
+        XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
   } else if (kernel == KernelMode::Large) {
-    return f8f8bf16_grouped_impl<128, 128, 128, 2, 1, 1, true, FastAccum>(
-        xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
+    return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 2, 1, 1, true>(
+        XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
   } else {
-    return f8f8bf16_grouped_impl<128, 128, 128, 1, 2, 1, true, FastAccum>(
-        xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
+    return f8f8bf16_rowwise_grouped_impl<128, 128, 128, 1, 2, 1, true>(
+        XQ, WQ, x_scale, w_scale, output, zero_start_index_M);
   }
 }
 
-std::vector<at::Tensor> f8f8bf16_grouped(
-    at::TensorList xq_group, // FP8
-    at::TensorList wq_group, // FP8
+std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
+    at::TensorList XQ, // FP8
+    at::TensorList WQ, // FP8
     at::TensorList x_scale,
     at::TensorList w_scale,
-    std::optional<at::Tensor> zero_start_index_M,
-    bool use_fast_accum) {
-  if (use_fast_accum) {
-    return dispatch_fp8_grouped_kernel<true>(
-        xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
+    std::optional<std::vector<at::Tensor>> output = std::nullopt,
+    std::optional<std::string> kernel_name = std::nullopt) {
+  at::Tensor Y;
+  int group_count = XQ.size();
+  if (output.has_value()) {
+    std::vector<at::Tensor> output_;
+    output_ = output.value();
+    TORCH_CHECK(
+        output_.size() == group_count,
+        "Output and input must have same number of groups.");
+    // Check that output shapes are correct.
+    for (int i = 0; i < group_count; i++) {
+      int M = XQ[i].size(0);
+      int N = WQ[i].size(0);
+      int out_M = output_[i].size(0);
+      int out_N = output_[i].size(1);
+      TORCH_CHECK(
+          M == out_M && N == out_N,
+          "Output tensors do not have the expected shape.");
+      TORCH_CHECK(
+          output_[i].dtype() == at::kBFloat16,
+          "Output dtype must be bfloat16.");
+    }
+    Y = at::stack(output.value(), 0);
   } else {
-    return dispatch_fp8_grouped_kernel<false>(
-        xq_group, wq_group, x_scale, w_scale, zero_start_index_M);
+    int64_t total_output_size = 0;
+    std::vector<int64_t> output_sizes;
+    for (int i = 0; i < group_count; ++i) {
+      const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
+      total_output_size += output_size;
+      output_sizes.push_back(output_size);
+    }
+    Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
   }
+  return dispatch_fp8_grouped_kernel(XQ, WQ, x_scale, w_scale, Y, std::nullopt);
+}
+
+at::Tensor f8f8bf16_rowwise_grouped_dynamic(
+    at::TensorList XQ, // FP8
+    at::TensorList WQ, // FP8
+    at::TensorList x_scale,
+    at::TensorList w_scale,
+    at::Tensor zero_start_index_M,
+    std::optional<std::string> kernel_name = std::nullopt) {
+  at::Tensor Y;
+  int group_count = XQ.size();
+  int64_t total_output_size = 0;
+  std::vector<int64_t> output_sizes;
+  for (int i = 0; i < group_count; ++i) {
+    const int64_t output_size = XQ[i].size(0) * WQ[i].size(0);
+    total_output_size += output_size;
+    output_sizes.push_back(output_size);
+  }
+  Y = at::zeros(total_output_size, XQ[0].options().dtype(at::kBFloat16));
+  return at::stack(
+      dispatch_fp8_grouped_kernel(
+          XQ, WQ, x_scale, w_scale, Y, zero_start_index_M),
+      0);
 }
 
 #else
 
-std::vector<at::Tensor> f8f8bf16_grouped(
-    at::TensorList xq_group, // FP8
-    at::TensorList wq_group, // FP8
+std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
+    at::TensorList XQ, // FP8
+    at::TensorList WQ, // FP8
+    at::TensorList x_scale,
+    at::TensorList w_scale,
+    std::optional<std::vector<at::Tensor>> output = std::nullopt,
+    std::optional<std::string> kernel_name = std::nullopt) {
+  throw std::runtime_error(
+      "CUDA version is older than 12.0"); // requires CUDA>=12
+}
+
+at::Tensor f8f8bf16_rowwise_grouped_dynamic(
+    at::TensorList XQ, // FP8
+    at::TensorList WQ, // FP8
     at::TensorList x_scale,
     at::TensorList w_scale,
-    std::optional<at::Tensor> zero_start_index_M,
-    bool use_fast_accum) {
+    at::Tensor zero_start_index_M,
+    std::optional<std::string> kernel_name = std::nullopt) {
   throw std::runtime_error(
       "CUDA version is older than 12.0"); // requires CUDA>=12
 }
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index d34321077..aac32e76f 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -55,13 +55,6 @@ at::Tensor f8f8bf16_tensorwise(
     at::Tensor WQ,
     double scale,
     bool use_fast_accum = true);
-std::vector<at::Tensor> f8f8bf16_grouped(
-    at::TensorList XQ,
-    at::TensorList WQ,
-    at::TensorList x_scale,
-    at::TensorList w_scale,
-    std::optional<at::Tensor> zero_start_index_M,
-    bool use_fast_accum = true);
 std::vector<at::Tensor> bf16bf16bf16_grouped(
     at::TensorList X,
     at::TensorList W,
@@ -186,8 +179,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "f8f8bf16_cublas(Tensor A, Tensor B, Tensor? Ainvs=None, Tensor? Binvs=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor");
   m.def(
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
-  m.def(
-      "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
@@ -197,11 +188,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.impl("i8i8bf16_dynamic", i8i8bf16_dynamic);
 #endif
 
 #ifdef USE_ROCM
-  m.def(
-      "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]");
-  m.def(
-      "f8f8bf16_rowwise_grouped_dynamic(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor zero_start_index_M, str? kernel_name=None) -> Tensor");
-  m.def("get_f8f8bf16_rowwise_grouped_kernels() -> str[]");
   m.impl(
       "get_f8f8bf16_rowwise_grouped_kernels",
@@ -219,6 +205,10 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "f8f8bf16_rowwise_out(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor(a!) output, Tensor? bias=None, bool use_fast_accum=True) -> ()");
   m.def(
       "f8f8bf16_rowwise_batched(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor? bias=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor");
+  m.def(
+      "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]");
+  m.def(
+      "f8f8bf16_rowwise_grouped_dynamic(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor zero_start_index_M, str? kernel_name=None) -> Tensor");
   m.def(
       "f8f8bf16_tensorwise(Tensor XQ, Tensor WQ, float scale, bool use_fast_accum=True) -> Tensor");
   m.def("per_tensor_quantize_i8(Tensor X, float scale) -> Tensor");
@@ -259,24 +249,22 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
   m.impl("f8f8bf16_rowwise_out", f8f8bf16_rowwise_out);
   m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched);
+  m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped);
+  m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic);
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
   m.impl("bf16bf16bf16_grouped", bf16bf16bf16_grouped);
   m.impl("bf16bf16bf16_grouped_dynamic", bf16bf16bf16_grouped_dynamic);
+
 #ifndef USE_ROCM
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
-  m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
 #endif
-#ifdef USE_ROCM
-  m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped);
-  m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic);
-#endif
 }
 
 // Though it should never be used, it still seems helpful to define these
@@ -286,6 +274,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("f8f8bf16_tensorwise", f8f8bf16_tensorwise);
   m.impl("f8f8bf16_rowwise", f8f8bf16_rowwise);
   m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched);
+  m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped);
+  m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic);
   m.impl("quantize_fp8_per_tensor", quantize_fp8_per_tensor);
   m.impl("quantize_fp8_per_row", quantize_fp8_per_row);
   m.impl("quantize_fp8_per_col", quantize_fp8_per_col);
@@ -295,15 +285,10 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("i8i8bf16", i8i8bf16);
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
-  m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
 #endif
-#ifdef USE_ROCM
-  m.impl("f8f8bf16_rowwise_grouped", f8f8bf16_rowwise_grouped);
-  m.impl("f8f8bf16_rowwise_grouped_dynamic", f8f8bf16_rowwise_grouped_dynamic);
-#endif
 }
 
 // Shape registration functions.
@@ -475,22 +460,6 @@ std::vector<at::Tensor> quantize_fp8_per_col_meta(
   return {Y, scale};
 }
 
-std::vector<at::Tensor> f8f8bf16_grouped_meta(
-    at::TensorList XQ,
-    at::TensorList WQ,
-    at::TensorList /* x_scale */,
-    at::TensorList /* w_scale */,
-    std::optional<at::Tensor> /* zero_start_index_M = std::nullopt */,
-    bool /* use_fast_accum = true */) {
-  std::vector<at::Tensor> Y;
-  for (int i = 0; i < XQ.size(); i++) {
-    const at::SymInt M = XQ[i].sym_size(0);
-    const at::SymInt N = WQ[i].sym_size(0);
-    Y.push_back(at::empty_symint({M, N}, XQ[i].options().dtype(at::kBFloat16)));
-  }
-  return Y;
-}
-
 std::vector<at::Tensor> bf16bf16bf16_grouped_meta(
     at::TensorList X,
     at::TensorList W,
@@ -534,7 +503,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
-  m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
 #endif
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index 997682d77..febb8e584 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -806,33 +806,62 @@ def test_fp8_grouped_gemm(
 
         # FP8 grouped gemm kernel
         if use_cudagraph:
-            # warmup
-            torch.ops.fbgemm.f8f8bf16_grouped(
-                xq_group,
-                wq_group,
-                x_scale_group,
-                w_scale_group,
-                zero_start_index_M if use_padding_zeros else None,
-            )
-            # With cudagraph
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_grouped(
+            if use_padding_zeros:
+                # warmup
+                torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
                     xq_group,
                     wq_group,
                     x_scale_group,
                     w_scale_group,
-                    zero_start_index_M if use_padding_zeros else None,
+                    zero_start_index_M,
                 )
-            g.replay()
+                # With cudagraph
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
+                        xq_group,
+                        wq_group,
+                        x_scale_group,
+                        w_scale_group,
+                        zero_start_index_M,
+                    )
+                g.replay()
+                y_fp8_group = y_fp8_group.unbind(dim=0)
+            else:
+                # warmup
+                torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
+                    xq_group,
+                    wq_group,
+                    x_scale_group,
+                    w_scale_group,
+                )
+                # With cudagraph
+                g = torch.cuda.CUDAGraph()
+                with torch.cuda.graph(g):
+                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
+                        xq_group,
+                        wq_group,
+                        x_scale_group,
+                        w_scale_group,
+                    )
+                g.replay()
         else:
-            y_fp8_group = torch.ops.fbgemm.f8f8bf16_grouped(
-                xq_group,
-                wq_group,
-                x_scale_group,
-                w_scale_group,
-                zero_start_index_M if use_padding_zeros else None,
-            )
+            if use_padding_zeros:
+                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
+                    xq_group,
+                    wq_group,
+                    x_scale_group,
+                    w_scale_group,
+                    zero_start_index_M,
+                )
+                y_fp8_group = y_fp8_group.unbind(dim=0)
+            else:
+                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
+                    xq_group,
+                    wq_group,
+                    x_scale_group,
+                    w_scale_group,
+                )
 
         # BF16 grouped gemm kernel
         bf16_args = (