diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py
index 98c6f1476417..f9b9f315013c 100644
--- a/python/perf-kernels/fused_moe/moe-gemm.py
+++ b/python/perf-kernels/fused_moe/moe-gemm.py
@@ -2,7 +2,7 @@
 import torch
 import triton.language as tl
 import pytest
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, TypedDict
 import os
 import json
 import functools
@@ -101,7 +101,7 @@ def moe_gemm_kernel(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     out_ptrs = Out + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
-    tl.store(out_ptrs, accumulator, mask=c_mask)
+    tl.store(out_ptrs, accumulator.to(Out.dtype.element_ty), mask=c_mask)
 
 
 def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int,
@@ -290,10 +290,10 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to
     return c
 
 
-def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
-    a = torch.randn((M, K), dtype=torch.float32, device='cuda')
-    b = torch.randn((E, N, K), dtype=torch.float32, device='cuda')
-    c = torch.zeros((M, top_k, N), dtype=torch.float32, device='cuda')
+def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype):
+    a = torch.randn((M, K), dtype=dtype, device='cuda')
+    b = torch.randn((E, N, K), dtype=dtype, device='cuda')
+    c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')
 
     values = torch.randn(M, E, device='cuda')
 
@@ -329,10 +329,10 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
     (64, 128, 64, 2, 8),
 ])
 @pytest.mark.parametrize('routed_weight', [True, False])
-def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
+def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
     torch.manual_seed(20)
     a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper(
-        M, K, N, top_k, E, routed_weight=routed_weight)
+        M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype)
     # TODO Quantization support
 
     tri_out = moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config)
@@ -349,6 +349,62 @@ def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight:
     # Validate correctness
     torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)
 
+class BenchmarkConfig(TypedDict):
+    BLOCK_SIZE_M: int
+    BLOCK_SIZE_N: int
+    BLOCK_SIZE_K: int
+    GROUP_SIZE_M: int
+    num_warps: int
+    num_stages: int
+
+def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
+                 shard_intermediate_size: int, hidden_size: int, topk: int,
+                 dtype: torch.dtype, use_fp8_w8a8: bool,
+                 use_int8_w8a16: bool) -> None:
+    dtype_str = get_config_dtype_str(dtype,
+                                     use_int8_w8a16=use_int8_w8a16,
+                                     use_fp8_w8a8=use_fp8_w8a8)
+
+    # NOTE(woosuk): The current naming convention uses w2.shape[2], which
+    # is the intermediate size after silu_and_mul.
+    filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
+                                    dtype_str)
+
+    print(f"Writing best config to {filename}...")
+    with open(filename, "w") as f:
+        json.dump(configs, f, indent=4)
+        f.write("\n")
+
+def get_rocm_tuning_space(use_fp16):
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    if not use_fp16:
+        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 4, 8, 16, 32]
+    # For now we see better perf with num_stages=0 for all gemm configs we care
+    # But keep this explicit so that we do not forget we may need to set it to
+    # other values in the future
+    num_stage_range = [0]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
+    kpack_range = [1, 2] if use_fp16 else []
+
+    param_ranges = {
+        "BLOCK_SIZE_M": block_mn_range,
+        "BLOCK_SIZE_N": block_mn_range,
+        "BLOCK_SIZE_K": block_k_range,
+        "GROUP_SIZE_M": group_m_range,
+        "num_warps": num_warps_range,
+        "num_stages": num_stage_range,
+        "waves_per_eu": waves_per_eu_range,
+    }
+    if use_fp16:
+        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
+        param_ranges["kpack"] = kpack_range
+
+    return param_ranges
+
 
 def get_configs():
     configs = [
@@ -367,6 +423,8 @@ def get_configs():
 def run_benchmark(custom, args):
     print_time = args.return_time
     routed_weight = args.routed_weight
+    dtype = arg_to_torch_dtype[args.dtype]
+    tune = args.tune
     if custom:
         assert args.M and args.K and args.N and args.E and args.top_k, \
             "Please provide M, K, N, E, top_k for custom runs."
@@ -382,12 +440,12 @@ def run_benchmark(custom, args):
     benchmark = triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider',
                                          line_vals=['triton'], line_names=[line_names],
                                          styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark',
-                                         args={'print_time': print_time, 'routed_weight': routed_weight})
+                                         args={'dtype': dtype, 'print_time': print_time, 'routed_weight': routed_weight})
 
     @triton.testing.perf_report([benchmark])
-    def bench_moe_gemm(M, K, N, E, top_k, routed_weight, print_time, provider):
+    def bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, print_time, provider):
         a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper(
-            M, K, N, top_k, E, routed_weight=routed_weight)
+            M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype)
 
         flops = 2.0 * M * top_k * K * N
         if routed_weight:
@@ -417,10 +475,13 @@ def parse_args():
     parser.add_argument("-E", type=int, default=0, help="Number of experts")
     parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
     parser.add_argument("-routed_weight", action='store_true', default=False)
+    parser.add_argument("-tune", action='store_true', default=False)
+    parser.add_argument("-dtype", default='fp16')
     parser.add_argument("-return_time", action='store_true', default=False, help='Return time instead of TFLOPs')
     args = parser.parse_args()
     return args
 
+arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
 
 def main():
     args = parse_args()
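
The hunks above add a `-tune` flag and `get_rocm_tuning_space`, but the search loop that consumes the returned parameter ranges falls outside the shown context. As a rough sketch only (the helper below is hypothetical, not part of the patch), the ranges would typically be expanded into concrete candidate configs with a Cartesian product, benchmarked one by one, and the winner handed to `save_configs`:

```python
import itertools

def enumerate_candidate_configs(param_ranges):
    # Hypothetical helper: expand the per-parameter ranges returned by
    # get_rocm_tuning_space() into concrete config dicts, one per combination.
    keys = list(param_ranges)
    for values in itertools.product(*(param_ranges[k] for k in keys)):
        yield dict(zip(keys, values))

# Example usage: count the fp16 search space.
# space = get_rocm_tuning_space(use_fp16=True)
# print(sum(1 for _ in enumerate_candidate_configs(space)))
```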