From cfced91d998d8a0a38b4b1a1eb2949ce9088c221 Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Mon, 6 Jan 2025 15:50:24 -0800 Subject: [PATCH] Switch dynamic FP8 grouped gemm to accept tensor inputs Summary: As we continue the long march towards optimal MOE performance, we've identified that in prefill, having to artificially split inputs and then check that each group is valid introduces non-trivial overhead. Since all the inputs must be in consecutive memory anyway, it's better to just require them to be contiguous tensors rather than TensorLists. While making this change may sound simple, it required switching the kernels to a templated implementation. This is the most elegant way to support various input and output types for shared kernels, despite it being a large refactor. I also removed some of the now outdated fbgemm profiling scripts. They likely aren't useful going forward anyway. Differential Revision: D67881909 --- .../gen_ai/bench/profile_grouped_gemm.py | 103 -- .../experimental/gen_ai/bench/quantize_ops.py | 9 +- .../fp8_rowwise_grouped_gemm.hip | 264 ++-- ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 34 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 128 +- ...4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 128 +- ...4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 128 +- ...x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip | 128 +- ...x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 128 +- ...4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip | 128 +- ...8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip | 128 +- ...8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip | 
128 +- ...8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 128 +- ...4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip | 128 +- ...4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip | 128 +- ...8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip | 128 +- ...8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip | 128 +- ...x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip | 128 +- ...x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip | 128 +- ...8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 128 +- ...16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip | 34 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 128 +- ...32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 34 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip | 128 +- ..._8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 128 +- ...4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip | 128 +- ...4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip | 128 +- .../kernels/fp8_rowwise_grouped_common.h | 75 ++ .../fp8_rowwise_grouped_kernel_manifest.h | 1086 ++++++++--------- .../gen_ai/src/quantize/quantize.cpp | 24 +- 76 files changed, 5776 insertions(+), 4463 deletions(-) delete mode 100644 fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py diff --git a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py b/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py deleted file mode 100644 index 82019a769c..0000000000 --- a/fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import argparse -from typing import Any - -import pandas as pd -import torch - -from .quantize_ops import FP8RowwiseGroupedGemm - - -def main(args: Any): - # Extract and format shape arguments. - M = [int(m) for m in args.M.strip().split(",")] - N = [int(n) for n in args.N.strip().split(",")] - K = [int(k) for k in args.K.strip().split(",")] - assert len(M) == len(N) == len(K), "M, N, and K must have the same length." - - # initialize tensors for benchmarking. - A = [] - B = [] - num_groups = len(M) - for i in range(num_groups): - A.append(torch.randn(M[i], K[i], device="cuda")) - B.append(torch.randn(N[i], K[i], device="cuda")) - - # Get quantized tensors. - group_gemm_op = FP8RowwiseGroupedGemm() - quantized_vals = group_gemm_op.quantize(A, B) - # Iterate over kernels to find the most performant one. - benchmark_results = [] - for kernel_name in torch.ops.fbgemm.get_f8f8bf16_rowwise_grouped_kernels(): - # Do a warmup run of the kernel. - output = group_gemm_op.compute(*quantized_vals, kernel_name=kernel_name) - # Benchmark this kernel implementation. - ms_runtime = group_gemm_op.benchmark( - *quantized_vals, use_cuda_graph=True, kernel_name=kernel_name - ) - # Compute statistics for this kernel. 
- tflops = 0 - gbps = 0 - for i in range(num_groups): - tflops += 2 * M[i] * N[i] * K[i] / (ms_runtime / 1e3) / 1e12 - gbps += ( - ( - quantized_vals[0][i].numel() * quantized_vals[0][i].element_size() - + quantized_vals[1][i].numel() * quantized_vals[1][i].element_size() - + output[i].numel() * output[i].element_size() - ) - / (ms_runtime / 1e3) - / 1e9 - ) - # Record results. - print(f"Kernel: {kernel_name}, ms: {ms_runtime:.4f}, TFLOPS: {tflops:.2f}") - benchmark_results.append( - { - "kernel_name": kernel_name, - "ms_runtime": ms_runtime, - "tflops": tflops, - "gbps": gbps, - } - ) - # Report best kernel. - best_kernel = min(benchmark_results, key=lambda x: x["ms_runtime"]) - print( - f"Best kernel for this shape: {best_kernel['kernel_name']}: {best_kernel['tflops']:.2f} TFLOPS" - ) - - # If specified, save all results. - if args.export_csv: - df = pd.DataFrame(benchmark_results) - df.to_csv("grouped_gemm_benchmark.csv", index=False) - - -def invoke_main() -> None: - parser = argparse.ArgumentParser() - parser.add_argument( - "--export_csv", - action="store_true", - help="Export results to a CSV file.", - ) - parser.add_argument( - "--M", - required=True, - help="Comma separated list of M values of each group to benchmark.", - ) - parser.add_argument( - "--N", - required=True, - help="Comma separated list of N values of each group to benchmark", - ) - parser.add_argument( - "--K", - required=True, - help="Comma separated list of K values of each group to benchmark.", - ) - - args = parser.parse_args() - main(args) diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index ab1d28e026..1a21bc21e2 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -486,11 +486,6 @@ def quantize_fixed_nk(self, x, w): # Apply quantization. xq, x_scale = quantize_fp8_row(xq) wq, w_scale = quantize_fp8_row(wq) - # View these unified tensors as lists of tensors. - xq = [x.squeeze() for x in xq.split(1, dim=0)] - wq = [w.squeeze() for w in wq.split(1, dim=0)] - x_scale = [xs.squeeze() for xs in x_scale.view(group_size, -1).split(1, dim=0)] - w_scale = [ws.squeeze() for ws in w_scale.view(group_size, -1).split(1, dim=0)] # Return processed tensors. 
return ( @@ -520,14 +515,13 @@ def quantize(self, x, w): m_values = None return xq, wq, x_scale, w_scale, m_values - def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): + def compute(self, xq, wq, x_scale, w_scale, m_values): if m_values is None: return torch.ops.fbgemm.f8f8bf16_rowwise_grouped( xq, wq, x_scale, w_scale, - kernel_name=kernel_name, ) else: return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic( @@ -536,7 +530,6 @@ def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None): x_scale, w_scale, zero_start_index_M=m_values, - kernel_name=kernel_name, ) def quantize_and_compute(self, x, w): diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip index e053eebb00..5fce43734c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip @@ -27,6 +27,15 @@ namespace fbgemm_gpu { +template +using RowwiseGroupedKernel = std::function; + // Define useful types that are needed for various kernels. using KernelArguments = ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>; @@ -37,48 +46,49 @@ using D1DataType = float; using DsDataType = ck::Tuple; using EDataType = ck::bhalf_t; -RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { +template +RowwiseGroupedKernel rowwise_grouped_heuristic_dispatch(int M, int N, int K) { // We use shape heuristics to find the best kernel. // To do this, we divide by the size of M and find the best // option within that grouping. 
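Note on the dispatch that follows: it buckets shapes by the largest M in the batch and then refines the choice within each bucket using N and K. A rough Python rendering of the bucketing, purely as a reading aid and not part of this diff (the bucket labels are invented; the real function returns kernel function pointers):

    def dispatch_bucket(max_m: int) -> str:
        # Mirrors the M thresholds used by rowwise_grouped_heuristic_dispatch below;
        # inside each bucket the real code further refines the choice using N and K.
        for bound in (16, 32, 64, 128, 256, 512):
            if max_m <= bound:
                return f"M<={bound} bucket"
        return "default bucket"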
if (M <= 16) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; + return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; } if (K <= 8192) { - return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; + return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; } - return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2; } if (M <= 32) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; } if (K <= 8192) { - return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2; } - return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2; + return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2; } if (M <= 64) { - return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 128) { if (N < 8192 && K <= 8192) { - return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 256) { - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } if (M <= 512) { if (K <= 8192) { - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; } - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; } // Default kernel for all other shapes. - return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; } __global__ void set_kernel_args_kernel( @@ -239,18 +249,18 @@ __global__ void set_kernel_args_fixed_nk_kernel( void set_dynamic_kernel_args( at::Tensor kernel_args, - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::vector output, + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor output, at::Tensor zero_start_index_M) { // Get current cuda stream. auto stream = at::cuda::getCurrentHIPStream().stream(); - int group_count = XQ.size(); + int group_count = XQ.size(0); // Confirm M is on the proper device. 
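Context for the dynamic mode being set up here: every group's activations occupy a fixed-size slot in one contiguous allocation, and zero_start_index_M records how many rows of each slot are actually valid (the "actual m values" referred to below). A hypothetical packing helper, not part of this patch and with invented names, sketches the layout the op now expects:

    import torch

    def pack_groups(x_list):
        # Stack G variable-length (m_i, K) activations into one zero-padded
        # (G, max_M, K) buffer and record the per-group valid row counts.
        g = len(x_list)
        k = x_list[0].shape[1]
        max_m = max(x.shape[0] for x in x_list)
        packed = x_list[0].new_zeros(g, max_m, k)
        m_values = torch.zeros(g, dtype=torch.int64, device=x_list[0].device)
        for i, x in enumerate(x_list):
            packed[i, : x.shape[0]] = x
            m_values[i] = x.shape[0]
        return packed, m_values

This is the same stacked layout that the benchmark's quantize_fixed_nk path now passes through unchanged, instead of splitting it back into per-group views.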
TORCH_CHECK( - XQ[0].device() == zero_start_index_M.device(), + XQ.device() == zero_start_index_M.device(), "zero_start_index_M and inputs must be on the same device."); TORCH_CHECK( zero_start_index_M.size(0) == group_count, @@ -261,35 +271,9 @@ void set_dynamic_kernel_args( // We assume that M, N, and K are fixed across groups. // The actual m values are sstored in the passed M tensor. - int M = XQ[0].size(0); - int K = XQ[0].size(1); - int N = WQ[0].size(0); - - // Make sure that inputs are allocated in sequential memory as required by - // this mode. - for (int i = 1; i < group_count; i++) { - // Check that all inputs are allocated directly following preceding input. - TORCH_CHECK( - XQ[i].data_ptr() == - (reinterpret_cast(XQ[i - 1].data_ptr()) + (M * K)), - "Inputs must be sequential in memory to support dynamic M, but XQ is not."); - TORCH_CHECK( - WQ[i].data_ptr() == - (reinterpret_cast(WQ[i - 1].data_ptr()) + (N * K)), - "Inputs must be sequential in memory to support dynamic M, but WQ is not."); - TORCH_CHECK( - x_scale[i].data_ptr() == - (reinterpret_cast(x_scale[i - 1].data_ptr()) + (M)), - "Inputs must be sequential in memory to support dynamic M, but x_scale is not."); - TORCH_CHECK( - w_scale[i].data_ptr() == - (reinterpret_cast(w_scale[i - 1].data_ptr()) + (N)), - "Inputs must be sequential in memory to support dynamic M, but w_scale is not."); - TORCH_CHECK( - output[i].data_ptr() == - (reinterpret_cast(output[i - 1].data_ptr()) + (M * N)), - "Inputs must be sequential in memory to support dynamic M, but output is not."); - } + int M = XQ.size(1); + int K = XQ.size(2); + int N = WQ.size(1); // Launch a kernel that sets kernel argument memory. const int BLOCK_SIZE = 8; @@ -298,11 +282,11 @@ void set_dynamic_kernel_args( int numBlocks = (block_factor + blockSize - 1) / blockSize; set_kernel_args_fixed_nk_kernel<<>>( reinterpret_cast(kernel_args.data_ptr()), - reinterpret_cast(XQ[0].data_ptr()), - reinterpret_cast(WQ[0].data_ptr()), - reinterpret_cast(w_scale[0].data_ptr()), - reinterpret_cast(x_scale[0].data_ptr()), - reinterpret_cast(output[0].data_ptr()), + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr()), + reinterpret_cast(output.data_ptr()), reinterpret_cast(zero_start_index_M.data_ptr()), M, N, @@ -311,50 +295,12 @@ void set_dynamic_kernel_args( BLOCK_SIZE); } -at::Tensor get_grouped_kernel_args( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - std::optional zero_start_index_M, - std::vector output) { - int group_count = XQ.size(); - // Get space on device for the kernel argument tensor. - at::Tensor kernel_args = at::empty( - {static_cast(group_count * sizeof(KernelArguments))}, - XQ[0].options().dtype(at::kByte)); - - // There are two different modes for this kernel. - // When zero_start_index_M is provided, we assume that data is sequential and - // that N and K are constants. This allows a more efficient kernel - // launch and is best suited to MOE use cases where M is truly dynamic. - // When zero_start_index_M is not provided, we assume M, N, and K can all vary - // and set them for each group. It is important to note that this does not - // work well with cuda graphs and runtime dynamism so if possible we recommend - // using zero_start_index_M. 
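The comment above (removed along with its helper) still describes the two launch modes accurately, and the zero_start_index_M mode is exactly what the new tensor-based signature enforces. As an illustration that is not part of this diff, the dynamic op is now called with stacked 3-D tensors rather than lists of 2-D tensors; the shapes follow the checks added later in this file, while the concrete sizes and dtypes (float8_e4m3fnuz inputs, int64 row counts) are assumptions for the sketch:

    import torch

    G, M, N, K = 4, 128, 1024, 2048     # illustrative sizes; N and K must be >= 512
    fp8 = torch.float8_e4m3fnuz         # assumed ROCm fp8 dtype, matching at::kFloat8_e4m3fnuz

    xq = torch.randn(G, M, K, device="cuda").to(fp8)   # was: list of G (M, K) tensors
    wq = torch.randn(G, N, K, device="cuda").to(fp8)   # was: list of G (N, K) tensors
    x_scale = torch.rand(G, M, device="cuda", dtype=torch.float32)
    w_scale = torch.rand(G, N, device="cuda", dtype=torch.float32)
    # Per-group count of valid rows in xq; rows past the count are treated as padding.
    zero_start_index_M = torch.tensor([128, 64, 128, 32], device="cuda", dtype=torch.int64)

    y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
        xq, wq, x_scale, w_scale, zero_start_index_M=zero_start_index_M
    )
    # y is a single (G, M, N) bfloat16 tensor rather than a list of G outputs.

Because each group lives at a fixed offset inside one allocation, the host-side pointer-contiguity checks that the Summary calls out as overhead are no longer needed, which is what this hunk removes.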
- - if (zero_start_index_M.has_value()) { - set_dynamic_kernel_args( - kernel_args, - XQ, - WQ, - x_scale, - w_scale, - output, - zero_start_index_M.value()); - } else { - set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, output); - } - return kernel_args; -} - std::vector f8f8bf16_rowwise_grouped( at::TensorList XQ, at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt) { + std::optional> output = std::nullopt) { // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. TORCH_CHECK( @@ -414,20 +360,11 @@ std::vector f8f8bf16_rowwise_grouped( } // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = - get_grouped_kernel_args(XQ, WQ, x_scale, w_scale, std::nullopt, Y); + at::Tensor kernel_args = at::empty( + {static_cast(group_count * sizeof(KernelArguments))}, + XQ[0].options().dtype(at::kByte)); + set_static_kernel_args(kernel_args, XQ, WQ, x_scale, w_scale, Y); - // If provided a specific kernel implementation, dispatch to it. - if (kernel_name.has_value()) { - auto it = kernel_name_map.find(kernel_name.value()); - // If not found, raise an error. - TORCH_CHECK( - it != kernel_name_map.end(), - "Could not find kernel " + kernel_name.value()); - // If found, always use requested kernel. - return it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y); - } - // Otherwise, use heuristics to find the best kernel options. // We use the largest of each shape for heuristics. int MaxM = 0; int MaxN = 0; @@ -437,76 +374,63 @@ std::vector f8f8bf16_rowwise_grouped( MaxN = max(MaxN, WQ[i].size(0)); MaxK = max(MaxK, XQ[i].size(1)); } - RowwiseGroupedKernel selected_kernel = - rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + RowwiseGroupedKernel> selected_kernel = + rowwise_grouped_heuristic_dispatch>(MaxM, MaxN, MaxK); return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); } at::Tensor f8f8bf16_rowwise_grouped_dynamic( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor zero_start_index_M, - std::optional kernel_name = std::nullopt) { + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M) { // Check that input datatypes are valid. // First confirm that there are the same number of groups in all inputs. + int group_count = XQ.size(0); + int M = XQ.size(1); + int N = WQ.size(1); TORCH_CHECK( - XQ.size() == WQ.size() && XQ.size() == x_scale.size() && - XQ.size() == w_scale.size(), + WQ.size(0) == group_count && x_scale.numel() / group_count == M && + w_scale.numel() / group_count == N, "All inputs must have the same number of groups."); - int group_count = XQ.size(); // Iterate over inputs and check they are valid. - for (at::Tensor x : XQ) { - TORCH_CHECK(x.is_cuda() && x.is_contiguous()); - TORCH_CHECK(x.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - x.dtype() == at::kFloat8_e4m3fnuz, - "Inputs must be type float8_e4m3fnuz."); - } - for (at::Tensor w : WQ) { - TORCH_CHECK(w.is_cuda() && w.is_contiguous()); - TORCH_CHECK(w.dim() == 2, "Inputs must be 2D."); - TORCH_CHECK( - w.dtype() == at::kFloat8_e4m3fnuz, - "Inputs must be type float8_e4m3fnuz."); - TORCH_CHECK( - w.size(0) >= 512 && w.size(1) >= 512, - "N and K must be at least 512 for grouped gemm. 
For smaller inputs, consider unrolling."); - } - for (at::Tensor xs : x_scale) { - TORCH_CHECK(xs.dtype() == at::kFloat, "Scales must be float32."); - } - for (at::Tensor ws : x_scale) { - TORCH_CHECK(ws.dtype() == at::kFloat, "Scales must be float32."); - } + TORCH_CHECK(XQ.is_cuda() && XQ.is_contiguous()); + TORCH_CHECK(XQ.dim() == 3, "Input XQ must be 3D (G,M,K)."); + TORCH_CHECK( + XQ.dtype() == at::kFloat8_e4m3fnuz, + "Input XQ must be type float8_e4m3fnuz."); + + TORCH_CHECK(WQ.is_cuda() && WQ.is_contiguous()); + TORCH_CHECK(WQ.dim() == 3, "Input WQ must be 3D (G,N,K)."); + TORCH_CHECK( + WQ.dtype() == at::kFloat8_e4m3fnuz, + "Input WQ must be type float8_e4m3fnuz."); + TORCH_CHECK( + WQ.size(1) >= 512 && WQ.size(2) >= 512, + "N and K must be at least 512 for grouped gemm. For smaller inputs, consider unrolling."); + + TORCH_CHECK(x_scale.dtype() == at::kFloat, "Scales must be float32."); + TORCH_CHECK(w_scale.dtype() == at::kFloat, "Scales must be float32."); - // Create a single chunk of tensor but view it as a list for compatibility. - int M = XQ[0].size(0); - int N = WQ[0].size(0); // Allocate an empty output array. We will set its values to zero as part // of kernel setup. - at::Tensor Y_full = + at::Tensor Y = at::empty({group_count, M, N}, XQ[0].options().dtype(at::kBFloat16)); - // Split the output into groups. - std::vector Y = at::unbind(Y_full, 0); // Prepare kernel arguments by copying them to the proper device location. - at::Tensor kernel_args = - get_grouped_kernel_args(XQ, WQ, x_scale, w_scale, zero_start_index_M, Y); + at::Tensor kernel_args = at::empty( + {static_cast(group_count * sizeof(KernelArguments))}, + XQ.options().dtype(at::kByte)); + set_dynamic_kernel_args( + kernel_args, + XQ, + WQ, + x_scale, + w_scale, + Y, + zero_start_index_M); - // If provided a specific kernel implementation, dispatch to it. - if (kernel_name.has_value()) { - auto it = kernel_name_map.find(kernel_name.value()); - // If not found, raise an error. - TORCH_CHECK( - it != kernel_name_map.end(), - "Could not find kernel " + kernel_name.value()); - // If found, always use requested kernel. - it->second(XQ, WQ, x_scale, w_scale, kernel_args, Y); - return Y_full; - } - // Otherwise, use heuristics to find the best kernel options. // We use the largest of each shape for heuristics. int MaxM = 0; int MaxN = 0; @@ -516,21 +440,9 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic( MaxN = max(MaxN, WQ[i].size(0)); MaxK = max(MaxK, XQ[i].size(1)); } - RowwiseGroupedKernel selected_kernel = - rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); - // Run kernel to populate Y. - selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); - // Return unified view of Y_full. 
- return Y_full; -} - -std::vector get_f8f8bf16_rowwise_grouped_kernels() { - /* Helper function to get the names of avaialable grouped gemm kernels.*/ - std::vector kernel_names; - for (const auto& pair : kernel_name_map) { - kernel_names.push_back(pair.first); - } - return kernel_names; + RowwiseGroupedKernel selected_kernel = + rowwise_grouped_heuristic_dispatch(MaxM, MaxN, MaxK); + return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y); } } // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 4c99757b80..83d70aad13 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index 21e92b3997..1b8d82c754 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index e24f8b6fd3..f790dfb139 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // A kernel that works well on small but not super tiny shapes. 
using DeviceGemmInstance = DeviceGemmHelper< 128, @@ -35,5 +36,24 @@ fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_in ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>; // Run kernel instance. - return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 3e5d0f249c..e8bfeadd25 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 9c871f64c9..d2e8078dd6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 128, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index f2a90ac2b0..52f0e4f7f3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 64f6311648..be17df0382 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 128, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 128, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index b9ed2888b3..dc24ed8ab1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 34a7367f5b..b64e366cd1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 0637fde603..0eb6ae848a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 868cd8275e..3562c07fa8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index 8c324b7989..378612faa3 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index eeb3ecfa99..7a93109888 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 2d4faaaa7a..d0d245410f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index f5c9aa7795..1ffc8cbe5d 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 256, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 256, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip index 6f8da7b7e7..b1e7a9d454 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 123638a331..7bc4105125 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip index 28a14967cc..0e547efba8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 0a20e29cd7..bd5354f11c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 32, - 512, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index fe794f8533..ffb5c4a355 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 6ef2ed503b..452236c088 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 16, - 64, - 128, - 16, - 16, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 16, + 64, + 128, + 16, + 16, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index 3cd1a219ce..ca07db247a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip index a3b3b37ec3..7b4c9c1f03 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 128, - 128, - 32, - 32, - 1, - 2, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip index fac1cd90b6..16b056222f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 5059ba7370..2502a2e3d6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip index 1750fe9158..30a412b5f9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index cd5cc29f61..7bf51a1533 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip index 67d7d7d778..fe0058076c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip index 63bb549f4f..7e92b9ac4f 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
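One nuance visible in every pad check in this patch: in the TensorList branch, pad is reassigned on each iteration, so when groups have different K only the last group decides whether the KPadding specialization is chosen, while the removed code set pad = true as soon as any group had a K that was not a multiple of the tile's K block (128 here, 64 for the kernels with a 64-wide K block). If every group shares the same K the two behaviors coincide. A defensive variant that keeps the any-group semantics, shown as a sketch and not as what this patch does:

  bool pad = false;
  if constexpr (std::is_same_v<InputType, at::Tensor>) {
    pad = XQ.size(2) % 128 != 0;
  } else {
    for (int i = 0; i < XQ.size(); i++) {
      if (XQ[i].size(1) % 128 != 0) {
        pad = true;  // One unaligned group is enough to require KPadding.
        break;
      }
    }
  }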
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 32, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip index 329e3a0e06..df610ef17b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip index 9d7908f441..5fa67949fa 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 16, - 128, - 16, - 16, - 2, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 16, + 128, + 16, + 16, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip index 7330ad30c5..b16371594e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
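The pair of template blocks appended at the end of each kernel file (for example the fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2 pair above) are explicit instantiations that keep both entry points compiled into the library: one for grouped TensorList inputs that returns a vector of per-group outputs, and one for stacked Tensor inputs that returns a single tensor. With the stripped angle brackets restored, each pair presumably reads as follows; the <at::Tensor> spellings are inferred from the parameter lists and are not copied from the repository:

// TensorList entry point: per-group inputs, per-group outputs.
template std::vector<at::Tensor>
fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2(
    at::TensorList XQ,
    at::TensorList WQ,
    at::TensorList x_scale,
    at::TensorList w_scale,
    at::Tensor kernel_args,
    std::vector<at::Tensor> Y);

// Stacked entry point: one contiguous tensor per operand, one output tensor.
template at::Tensor
fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor kernel_args,
    at::Tensor Y);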
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip index 5c07c9b991..2108cd1b08 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 128, - 64, - 32, - 128, - 32, - 32, - 1, - 1, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index 270c4b6b47..d7cf20529e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index c90498392b..ebaca18200 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 9282b9d898..a9dade813e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 128, - 32, - 32, - 2, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 1a09ef2a40..4c6afcadd4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 128, - 64, - 32, - 32, - 2, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index e05304e317..bd78020f66 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
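The two call modes differ only in how the groups are laid out, which is why the pad check reads XQ[i].size(1) in one branch and XQ.size(2) in the other. A short shape sketch, assuming the natural [num_groups, M, K] packing for the stacked activations and matching output shapes; the exact layouts of WQ, the scales, and Y are not spelled out in these hunks:

// TensorList mode: G separate 2-D per-group tensors.
//   XQ[i] : [M_i, K]   fp8 rowwise-quantized activations (K = XQ[i].size(1))
//   Y[i]  : [M_i, N_i] bf16 output for group i
//
// Stacked Tensor mode: groups packed into one contiguous 3-D tensor.
//   XQ    : [G, M, K]  fp8 activations for all groups (K = XQ.size(2))
//   Y     : [G, M, N]  single bf16 output tensor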
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 256, - 64, - 32, - 32, - 2, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 256, - 64, - 32, - 32, - 2, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 256, + 64, + 32, + 32, + 2, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 005b98e6f2..749d6d0cc4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 128, - 64, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip index 43440d9f09..bef3ac6b36 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip index 59892b1829..78c12db8cc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 16, - 256, - 128, - 16, - 16, - 1, - 4, - S<8, 16, 1>, - S<8, 16, 1>, - S<1, 16, 1, 16>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 16, + 256, + 128, + 16, + 16, + 1, + 4, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 16>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 47be5bd05b..2b5557688b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 224, - 256, - 128, - 16, - 16, - 7, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip index b4c7f93443..2cc5e909c1 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 128, - 64, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 128, - 64, - 32, - 32, - 4, - 2, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip index c27587b65d..9fa6aa869a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip index 5410a90837..131f6fbf63 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 16, - 128, - 16, - 16, - 4, - 1, - S<8, 32, 1>, - S<8, 16, 1>, - S<1, 32, 1, 8>, - S<2, 2, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 32, 1>, + S<8, 16, 1>, + S<1, 32, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip index f728fd9cfe..ebd4aefa4b 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
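// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): every kernel file in this diff
// carries two DeviceGemmHelper instances that differ only in the final
// GemmSpecialization argument (KPadding vs Default), and the runtime `pad`
// flag picks between them. The stand-in names below are hypothetical, and the
// assumption that f8f8bf16_rowwise_grouped_impl takes the chosen instance as a
// template argument is a reconstruction (the angle brackets were stripped from
// this extracted diff).
#include <cstdio>

// Stand-ins for the Default / KPadding GEMM instances.
struct DefaultGemm  { static void run() { std::puts("Default instance"); } };
struct KPaddingGemm { static void run() { std::puts("KPadding instance"); } };

// Stand-in for the shared launcher taking the instance as a template argument.
template <typename DeviceGemmInstance>
void launch() { DeviceGemmInstance::run(); }

// Mirrors the if (pad) { ...KPadding... } else { ...Default... } structure
// repeated in each kernel file.
void launch_with_optional_k_padding(bool pad) {
  if (pad) {
    launch<KPaddingGemm>();
  } else {
    launch<DefaultGemm>();
  }
}
// ---------------------------------------------------------------------------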
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 224, - 128, - 16, - 16, - 8, - 7, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 64, 1, 4>, - S<8, 8, 1>, - 2, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index 165d3cb906..f12551c220 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 128, - 16, - 16, - 8, - 8, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip index fd63d9befd..2afb8db044 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 16, - 16, - 8, - 8, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 2, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip index 872c8d6756..5e5a80d373 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 256, - 64, - 32, - 32, - 4, - 4, - S<4, 64, 1>, - S<4, 64, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v4, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip index 6a391e1da0..3b0869c7c5 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip index 3d051451a6..f1e8106382 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 256, - 32, - 128, - 32, - 32, - 2, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 256, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip index 7815fb9334..719156d338 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
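// ---------------------------------------------------------------------------
// Illustrative sketch (not part of the patch): each kernel file in this diff
// now ends with two explicit instantiations of its templated entry point, one
// for the TensorList / std::vector signature and one for the Tensor / Tensor
// signature. The element type of that std::vector (presumably at::Tensor) is
// missing from this extracted diff, so that reading is an assumption. Below is
// a minimal, self-contained example of the same explicit-instantiation
// pattern, with hypothetical names.
#include <vector>

// Hypothetical template, standing in for a kernel entry point.
template <typename InputType, typename OutputType>
OutputType passthrough(InputType /*in*/, OutputType out) {
  return out;
}

// Explicit instantiations: emit both concrete signatures in this translation
// unit so other files can link against them without seeing the definition.
// Template arguments are deduced from the parameter and return types, just as
// in the instantiations that close each kernel file.
template std::vector<float> passthrough(std::vector<int>, std::vector<float>);
template float passthrough(int, float);
// ---------------------------------------------------------------------------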
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip index fc8bdf60d0..b4e5e543f7 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 32, - 256, - 128, - 32, - 32, - 1, - 2, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 16, 1, 16>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 32, + 256, + 128, + 32, + 32, + 1, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 16, 1, 16>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip index 35700d8098..914754bc7e 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 256, - 64, - 64, - 128, - 32, - 32, - 1, - 1, - S<8, 32, 1>, - S<8, 32, 1>, - S<1, 32, 1, 8>, - S<8, 8, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v3, - ck::tensor_operation::device::GemmSpecialization::Default>; + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index 8714c5d8e7..ee9323528a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 749d0a3c94..0f94edb867 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
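Once pad is known, every kernel still needs two nearly identical DeviceGemmInstance blocks because GemmSpecialization is a compile-time parameter of DeviceGemmHelper: the padded and unpadded configurations are distinct types, so a runtime flag can only pick between launches that were both instantiated ahead of time. A rough sketch of that shape, with a hypothetical launch_with_spec helper standing in for the repeated alias-plus-launch in each branch:

#include "fp8_rowwise_grouped_common.h"  // DeviceGemmHelper and the ck:: enums used below

namespace example {

template <ck::tensor_operation::device::GemmSpecialization Spec>
void launch_with_spec(/* tensors and kernel_args would be threaded through here */) {
  // Build the tuned instance with Spec baked in, then run it, e.g.
  // using DeviceGemmInstance = DeviceGemmHelper<..., Spec>;
}

inline void launch(bool pad) {
  if (pad) {
    launch_with_spec<
        ck::tensor_operation::device::GemmSpecialization::KPadding>();
  } else {
    launch_with_spec<
        ck::tensor_operation::device::GemmSpecialization::Default>();
  }
}

} // namespace example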
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 76388b8bba..d698dc14e6 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 5cc809bef8..b7d9d65387 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. 
and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 128 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 128 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 128 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 128, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip index bf8b430f0e..f70c7189dc 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // Secret kernel that seems good with small M but large N and K. using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -36,5 +37,24 @@ fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intr ck::BlockGemmPipelineVersion::v1, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index 3468432a90..75edf99d80 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
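The call on the next line goes through the shared launcher in fp8_rowwise_grouped_common.h, which now carries one overload per input style. Spelled out with the full template-argument lists, the pair looks roughly like this (declarations only, as a reading aid for the calls in these files):

#include <vector>
#include <ATen/ATen.h>

// Grouped path: one tensor per group in, results returned as a vector.
template <typename DeviceGemmInstance>
std::vector<at::Tensor> f8f8bf16_rowwise_grouped_impl(
    at::TensorList XQ,
    at::TensorList WQ,
    at::TensorList x_scale,
    at::TensorList w_scale,
    at::Tensor kernel_args,
    std::vector<at::Tensor> Y);

// Dynamic path: stacked contiguous tensors in, a single tensor out.
template <typename DeviceGemmInstance>
at::Tensor f8f8bf16_rowwise_grouped_impl(
    at::Tensor XQ,
    at::Tensor WQ,
    at::Tensor x_scale,
    at::Tensor w_scale,
    at::Tensor kernel_args,
    at::Tensor Y);

// Each kernel body then calls it, presumably with its local alias:
//   f8f8bf16_rowwise_grouped_impl<DeviceGemmInstance>(
//       XQ, WQ, x_scale, w_scale, kernel_args, Y);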
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 1f2339131e..6d443edd2a 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 2f28b951ea..3c0df9b287 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 673d19a0f4..d83e66795c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. 
and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 256 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 256 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 256 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 256, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
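For readers skimming these tuning files, every instance computes the same rowwise-scaled FP8 GEMM that the f8f8bf16_rowwise naming implies. The plain-float reference below is an interpretation of that naming for a single group (operand layouts and scale placement are assumptions, not taken from this patch):

#include <cstdint>
#include <vector>

std::vector<float> rowwise_scaled_gemm_ref(
    const std::vector<float>& xq,      // M x K, dequantized FP8 activations
    const std::vector<float>& wq,      // N x K, dequantized FP8 weights
    const std::vector<float>& x_scale, // M entries, one scale per row of xq
    const std::vector<float>& w_scale, // N entries, one scale per row of wq
    int64_t M, int64_t N, int64_t K) {
  std::vector<float> y(M * N, 0.0f);
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int64_t k = 0; k < K; ++k) {
        acc += xq[m * K + k] * wq[n * K + k];
      }
      // Rowwise dequantization: one scale per row of each operand.
      y[m * N + n] = acc * x_scale[m] * w_scale[n];
    }
  }
  return y;  // the real kernels store this result in bf16
}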
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index a832dcd1e3..680f364862 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -8,14 +8,15 @@ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y) { + OutputType Y) { // The smallest kernel we have available. Works well for memory bound shapes. using DeviceGemmInstance = DeviceGemmHelper< 64, @@ -36,5 +37,24 @@ fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_inte ck::BlockGemmPipelineVersion::v2, ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
- return f8f8bf16_rowwise_grouped_impl(XQ, WQ, x_scale, w_scale, kernel_args, Y); + return f8f8bf16_rowwise_grouped_impl( + XQ, WQ, x_scale, w_scale, kernel_args, Y); } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip index b9d76db772..5dc1dee8c9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
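The long file names are not arbitrary: their segments mirror the DeviceGemmHelper arguments in order, as the instance just above shows for fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1. The annotated alias below copies those values; the labels are the usual CK meanings and are an interpretation, not taken from this patch:

#include "fp8_rowwise_grouped_common.h"  // DeviceGemmHelper, S<...>, ck:: enums

// Name: 64 x16x16x512 _16x16 _1x1 _8x8x1 _8x8x1 _1x16x1x4 _4x4x1 _1x1 _interwave _v1
using ExampleInstance = DeviceGemmHelper<
    64,              // block size                    "64"
    16,              // M tile per block              "16x16x512"
    16,              // N tile per block
    512,             // K tile per block
    16,              // wave tile M                   "16x16"
    16,              // wave tile N
    1,               // waves per block in M          "1x1"
    1,               // waves per block in N
    S<8, 8, 1>,      // A block-transfer cluster      "8x8x1"
    S<8, 8, 1>,      // B block-transfer cluster      "8x8x1"
    S<1, 16, 1, 4>,  // C-shuffle thread cluster      "1x16x1x4"
    S<4, 4, 1>,      // C-shuffle vector widths       "4x4x1"
    1,               // C-shuffle M repeats           "1x1"
    1,               // C-shuffle N repeats
    ck::BlockGemmPipelineScheduler::Interwave,  // "interwave"
    ck::BlockGemmPipelineVersion::v1,           // "v1"
    ck::tensor_operation::device::GemmSpecialization::Default>;
// The trailing GemmSpecialization (Default vs. KPadding) is the one parameter
// not encoded in the name; it is selected by the padding check instead.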
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 18236d9db4..953aa3bdb2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip index 17f17cee6f..a661a0fa91 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v1, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index e9efd64a56..262c4bfd02 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. 
and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 512 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 512 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 512 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 512, - 16, - 16, - 1, - 1, - S<8, 8, 1>, - S<8, 8, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip index 6b7d7553d0..ad4f484719 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. 
return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Interwave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip index 3de229cd4d..e1b9fab701 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2.hip @@ -1,72 +1,94 @@ /* -* Copyright (c) Meta Platforms, Inc. and affiliates. -* All rights reserved. -* -* This source code is licensed under the BSD-style license found in the -* LICENSE file in the root directory of this source tree. -*/ + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ #include "fp8_rowwise_grouped_common.h" -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y) { + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y) { // Check if this input needs to be padded. 
bool pad = false; - for (int i = 0; i < XQ.size(); i++) { - int K = XQ[i].size(1); - if (K % 64 != 0) { - pad = true; + if constexpr (std::is_same_v) { + int K = XQ.size(2); + pad = K % 64 != 0; + } else { + for (int i = 0; i < XQ.size(); i++) { + int K = XQ[i].size(1); + pad = K % 64 != 0; } } if (pad) { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::KPadding>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } else { using DeviceGemmInstance = DeviceGemmHelper< - 64, - 16, - 16, - 64, - 16, - 16, - 1, - 1, - S<4, 16, 1>, - S<4, 16, 1>, - S<1, 16, 1, 4>, - S<4, 4, 1>, - 1, - 1, - ck::BlockGemmPipelineScheduler::Intrawave, - ck::BlockGemmPipelineVersion::v2, - ck::tensor_operation::device::GemmSpecialization::Default>; + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; // Run kernel instance. return f8f8bf16_rowwise_grouped_impl( - XQ, WQ, x_scale, w_scale, kernel_args, Y); + XQ, WQ, x_scale, w_scale, kernel_args, Y); } } + +template std::vector +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::TensorList XQ, + at::TensorList WQ, + at::TensorList x_scale, + at::TensorList w_scale, + at::Tensor kernel_args, + std::vector Y); + +template at::Tensor +fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h index 806adabca2..02086ea8b8 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_common.h @@ -192,3 +192,78 @@ std::vector f8f8bf16_rowwise_grouped_impl( return Y; } + +// Dynamic variant of kernel launch. +template +at::Tensor f8f8bf16_rowwise_grouped_impl( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor kernel_args, + at::Tensor Y) { + // Get input information. + int group_count = XQ.size(0); + using KernelArguments = + ck::tensor_operation::device::GroupedGemmTileLoopKernelArguments<2>; + using GemmDesc = ck::tensor_operation::device::GemmDesc; + // Create gemm shape containers. + std::vector gemm_descs; + // Create container for input arguments. + std::vector A_args; + std::vector B_args; + std::vector C_args; + std::vector> D_args = {}; + // Reserve space in argument arrays. 
+ gemm_descs.reserve(group_count); + A_args.reserve(group_count); + B_args.reserve(group_count); + C_args.reserve(group_count); + D_args.reserve(group_count); + // Populate arguments. + int M = XQ.size(1); + int K = XQ.size(2); + int N = WQ.size(1); + for (int i = 0; i < group_count; i++) { + // Set the shape arguments for this gemm. + GemmDesc gemm_desc = {M, N, K, K, K, N, {0, 0}}; + gemm_descs.push_back(gemm_desc); + // Set pointers to inputs and outputs. + A_args.push_back(reinterpret_cast( + XQ.data_ptr() + i * M * K)); + B_args.push_back(reinterpret_cast( + WQ.data_ptr() + i * N * K)); + C_args.push_back( + reinterpret_cast(Y.data_ptr() + i * M * N)); + D_args.emplace_back(std::array{ + reinterpret_cast(w_scale.data_ptr() + i * N), + reinterpret_cast(x_scale.data_ptr() + i * M)}); + } + + // Create gemm launcher and arguments. + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + // Setup Gemm arguments. + auto argument = gemm.MakeArgument( + A_args, + B_args, + D_args, + C_args, + gemm_descs, + a_element_op, + b_element_op, + cde_element_op); + + // Set gemm kernel arguments. + gemm.SetDeviceKernelArgs(argument, kernel_args.data_ptr()); + + // Get hip graph stream if it exists. + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); + + return Y; +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h index ff7de71e75..181d8b27f9 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/kernels/fp8_rowwise_grouped_kernel_manifest.h @@ -6,804 +6,714 @@ * LICENSE file in the root directory of this source tree. 
*/ -#include -#include - #include -#define KERNEL_NAME_MAP_ENTRY(name) \ - { #name, name } - -using RowwiseGroupedKernel = std::function( - at::TensorList, - at::TensorList, - at::TensorList, - at::TensorList, - at::Tensor, - std::vector)>; - -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( 
- at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, 
- at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - 
at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - 
at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + 
InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType 
WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, + InputType XQ, + InputType WQ, + InputType x_scale, + 
InputType w_scale, at::Tensor kernel_args, - std::vector Y); + OutputType Y); -std::vector +template +OutputType fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor kernel_args, - std::vector Y); - -// Map function for string name to kernel implementation for manual -// specification. -static const std::unordered_map kernel_name_map = { - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x256x64_32x32_2x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - 
fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_intrawave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v1), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x32x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x256x16x128_16x16_4x1_8x32x1_8x16x1_1x32x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - 
KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x64x16x128_16x16_2x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x256_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x16x128x128_16x16_1x4_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x16x256x128_16x16_1x4_8x16x1_8x16x1_1x16x1x16_4x4x1_1x1_interwave_v2), - KERNEL_NAME_MAP_ENTRY( - fp8_rowwise_grouped_256x32x256x128_32x32_1x2_8x32x1_8x32x1_1x16x1x16_8x8x1_1x1_interwave_v2), -}; + InputType XQ, + InputType WQ, + InputType x_scale, + InputType w_scale, + at::Tensor kernel_args, + OutputType Y); diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index 9d24ecb9d0..2378be94b4 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -97,16 +97,13 @@ std::vector f8f8bf16_rowwise_grouped( at::TensorList WQ, at::TensorList x_scale, at::TensorList w_scale, - std::optional> output = std::nullopt, - std::optional kernel_name = std::nullopt); + std::optional> output = std::nullopt); at::Tensor f8f8bf16_rowwise_grouped_dynamic( - at::TensorList XQ, - at::TensorList WQ, - at::TensorList x_scale, - at::TensorList w_scale, - at::Tensor zero_start_index_M, - std::optional kernel_name = std::nullopt); -std::vector get_f8f8bf16_rowwise_grouped_kernels(); + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor zero_start_index_M); at::Tensor f8f8bf16_blockwise( at::Tensor XQ, at::Tensor WQ, @@ -197,14 +194,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { #endif #ifdef USE_ROCM m.def( - "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None, str? kernel_name=None) -> Tensor[]"); + "f8f8bf16_rowwise_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor[](a!)? output=None) -> Tensor[]"); m.def( - "f8f8bf16_rowwise_grouped_dynamic(Tensor[] XQ, Tensor[] WQ, Tensor[] x_scale, Tensor[] w_scale, Tensor zero_start_index_M, str? 
kernel_name=None) -> Tensor"); - - m.def("get_f8f8bf16_rowwise_grouped_kernels() -> str[]"); - m.impl( - "get_f8f8bf16_rowwise_grouped_kernels", - get_f8f8bf16_rowwise_grouped_kernels); + "f8f8bf16_rowwise_grouped_dynamic(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor zero_start_index_M) -> Tensor"); #endif m.def( "bf16bf16bf16_grouped(Tensor[] X, Tensor[] W, Tensor[](a!)? output=None) -> Tensor[]");