From 71649a64920f5b74016d34f2f13b2c238a0a2f53 Mon Sep 17 00:00:00 2001
From: Tianxing Wu
Date: Mon, 2 Dec 2024 13:28:57 +0000
Subject: [PATCH 01/21] Implemented MoE GEMM kernel, test, and benchmarking.
 The GEMM supports routed weights, and so does the benchmark. You can tune
 the GEMM with the benchmark option -tune.

---
 .../amd_perf_kernel_Integration_tests.yml | 1 +
 python/perf-kernels/README.md | 3 +
 .../perf-kernels/fused_moe/benchmark_utils.py | 211 +++++++++
 ...14336,device_name=AMD_Instinct_MI300X.json | 200 ++++++++
 ...=1792,device_name=AMD_Instinct_MI300X.json | 200 ++++++++
 ...=3584,device_name=AMD_Instinct_MI300X.json | 200 ++++++++
 ...=7168,device_name=AMD_Instinct_MI300X.json | 200 ++++++++
 python/perf-kernels/fused_moe/moe-gemm.py | 441 ++++++++++++++++++
 8 files changed, 1456 insertions(+)
 create mode 100644 python/perf-kernels/fused_moe/benchmark_utils.py
 create mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json
 create mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json
 create mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json
 create mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json
 create mode 100644 python/perf-kernels/fused_moe/moe-gemm.py

diff --git a/.github/workflows/amd_perf_kernel_Integration_tests.yml b/.github/workflows/amd_perf_kernel_Integration_tests.yml
index b6a12841a646..0a263f68efb7 100644
--- a/.github/workflows/amd_perf_kernel_Integration_tests.yml
+++ b/.github/workflows/amd_perf_kernel_Integration_tests.yml
@@ -129,6 +129,7 @@ jobs:
         pytest -vvvv ./python/perf-kernels/softmax.py
         pytest -vvv ./python/perf-kernels/rmsnorm.py
         pytest -vvv ./python/perf-kernels/layernorm.py
+        pytest -vvv ./python/perf-kernels/fused_moe/moe-gemm.py
         sh ./python/perf-kernels/streamk/utils/unittest.sh
         pytest -vvv ./python/perf-kernels/multreduce_matmul_kernel.py
      - name: Run Perf Kernels Benchmark
diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md
index d590aa06d573..91283a129b0c 100644
--- a/python/perf-kernels/README.md
+++ b/python/perf-kernels/README.md
@@ -99,3 +99,6 @@ Kernel that implements RMS Norm over a row of tensor.
## `layernorm.py` Kernel that implements Layer Normalization over a row on tensor + +## `fused_moe/moe-gemm.py` +Kernel that implements moe gemm diff --git a/python/perf-kernels/fused_moe/benchmark_utils.py b/python/perf-kernels/fused_moe/benchmark_utils.py new file mode 100644 index 000000000000..907e5c836b86 --- /dev/null +++ b/python/perf-kernels/fused_moe/benchmark_utils.py @@ -0,0 +1,211 @@ +from typing import TypedDict, List, Optional +from itertools import product +import json +import torch +import os + + +def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + device_name = torch.cuda.get_device_name(0).replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + + +def get_tuning_space(use_fp16): + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [16, 32, 64, 128, 256] + if not use_fp16: + block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + num_stage_range = [1, 2] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] + kpack_range = [1, 2] if use_fp16 else [] + + param_ranges = { + "BLOCK_SIZE_M": block_mn_range, + "BLOCK_SIZE_N": block_mn_range, + "BLOCK_SIZE_K": block_k_range, + "GROUP_SIZE_M": group_m_range, + "num_warps": num_warps_range, + "num_stages": num_stage_range, + "waves_per_eu": waves_per_eu_range, + } + if use_fp16: + param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range + param_ranges["kpack"] = kpack_range + + return param_ranges + + +def prune_configs(M, N, K, configs, is_fp16=True): + pruned_configs = [] + elemBytes_a = 2 if is_fp16 else 1 + elemBytes_b = 2 if is_fp16 else 1 + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + + if is_fp16: + matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") + if matrix_instr_nonkdim > mfma: + continue + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + # some layouts could not work properly in case + # number elements per thread is less 1 + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = config.get("SPLIT_K", 1) + GROUP_M = config.get("GROUP_SIZE_M") + if is_fp16: + if (matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N): + continue + if (matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M): + continue + if (matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N): + continue + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < 
BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 128 or BLOCK_SIZE_N < 128: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def merge_unique_dicts(list1, list2): + result = [] + combined_list = list1.copy() + combined_list.extend(list2) + for dictionary in combined_list: + if dictionary not in result: + result.append(dictionary) + return result + + +def prune_search_space(num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16): + N1, K1 = shard_intermediate_size, hidden_size + + pruned_space_1 = prune_configs(num_tokens * 2, N1, K1, search_space, is_fp16) + # NOTE, we are only tunning thr gemm here so only one pass of moe + # pruned_space_2 = prune_configs(num_tokens * 8, N2, K2, search_space, + # is_fp16) + # search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) + return pruned_space_1 + + +def update_configs(M: int, config: BenchmarkConfig, num_experts: int, shard_intermediate_size: int, hidden_size: int, + topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) + + # NOTE(woosuk): The current naming convention uses w2.shape[2], which + # is the intermediate size after silu_and_mul. 
+ # NOTE, we are only tunning thr gemm here so no // 2 + # filename = get_config_file_name(num_experts, shard_intermediate_size // 2, + # dtype_str) + + filename = get_config_file_name(num_experts, shard_intermediate_size, dtype_str) + print(f"Best config: {config}") + print(f"Writing best config to {filename}...") + + config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", filename) + # 1) Read the existing JSON file if it exists + old_configs = {} + if os.path.isfile(config_file_path): + with open(config_file_path, "r") as f: + try: + old_configs = json.load(f) + except json.JSONDecodeError: + # If the file is empty or corrupt, we just ignore it + old_configs = {} + + # 2) Update existing data with new configs + # If they share any keys, the new 'configs' will overwrite + old_configs[str(M)] = config + # old_configs[configs.keys()[0]] = configs[configs.keys()[0]] + + # 3) Write back to the same file + with open(config_file_path, "w") as f: + json.dump(old_configs, f, indent=2) + f.write("\n") + + +def get_tuning_configs(M, N, K, use_fp16): + param_ranges = get_tuning_space(use_fp16) + configs: List[BenchmarkConfig] = [] + + keys, values = zip(*param_ranges.items()) + for config_values in product(*values): + config = dict(zip(keys, config_values)) + configs.append(config) + + configs = prune_search_space(num_tokens=M, shard_intermediate_size=N, hidden_size=K, search_space=configs, + is_fp16=use_fp16) + + return configs diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..2cc0f41254eb --- /dev/null +++ b/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 1, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + 
"GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..2d799bc0f4e9 --- /dev/null +++ b/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + 
"kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..89ece30f9c15 --- /dev/null +++ b/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 
1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + } +} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..20529c52ee2d --- /dev/null +++ b/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,200 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + 
}, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1024": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 1, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/python/perf-kernels/fused_moe/moe-gemm.py 
b/python/perf-kernels/fused_moe/moe-gemm.py new file mode 100644 index 000000000000..44ac71d17c8e --- /dev/null +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -0,0 +1,441 @@ +import triton +import torch +import triton.language as tl +import pytest +from typing import Any, Dict, Optional, Tuple +import os +import json +import functools +import argparse +import sys +from benchmark_utils import get_tuning_configs, get_config_file_name, update_configs + + +@triton.jit +def moe_gemm_kernel( + A, + B, + Out, + stride_am, + stride_ak, + stride_be, + stride_bn, + stride_bk, + stride_cm, + stride_cn, + top_k: tl.constexpr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + EM: tl.constexpr, + N: tl.constexpr, + K: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + + # Here we assume that valid tokens are in the range [0, M). 
+ token_mask = (offs_token >= 0) & (offs_token < EM) + + off_experts = tl.load(expert_ids_ptr + pid_m) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + off_experts * stride_be + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Masking ensures we don't load from invalid tokens or indices + a = tl.load(a_ptrs, mask=(token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K)), other=0.0) + b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K), other=0.0) + + accumulator = tl.dot(a, b, acc=accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + out_ptrs = Out + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(out_ptrs, accumulator.to(Out.dtype.element_ty), mask=c_mask) + + +def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int, + sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, + num_tokens_post_pad: torch.Tensor) -> None: + M, top_k = topk_ids.shape + + expert_to_tokens = [[] for _ in range(num_experts)] + # For each token, for each selected expert, we append (token_id, expert) + for token_id in range(M): + for j in range(top_k): + e_id = topk_ids[token_id, j].item() + expert_to_tokens[e_id].append(token_id * top_k + j) + + # Reorder tokens block by block, padding if needed + reordered_token_ids = [] + reordered_expert_ids = [] + + for e_id in range(num_experts): + tokens_for_expert = expert_to_tokens[e_id] + num_tokens = len(tokens_for_expert) + + n_blocks = ((num_tokens + block_size - 1) // block_size) + # If not a multiple of block_size, pad up to the next multiple + padded_size = n_blocks * block_size + + # Reorder all actual tokens for expert e_id + reordered_token_ids.extend(tokens_for_expert) + # reordered_expert_ids.extend([e_id]*num_tokens) + reordered_expert_ids.extend([e_id] * n_blocks) + + # Pad with dummy token_id = -1 (or any sentinel), if needed + if padded_size > num_tokens: + pad_count = padded_size - num_tokens + reordered_token_ids.extend([-1] * pad_count) + + token_length = len(reordered_token_ids) + expert_length = len(reordered_expert_ids) + + sorted_token_ids[:token_length] = torch.tensor(reordered_token_ids, dtype=sorted_token_ids.dtype, + device=sorted_token_ids.device) + expert_ids[:expert_length] = torch.tensor(reordered_expert_ids, dtype=expert_ids.dtype, device=expert_ids.device) + + # Fill remainder with -1 if these arrays are bigger than total_length + if token_length < sorted_token_ids.numel(): + sorted_token_ids[token_length:] = -1 + if expert_length < expert_ids.numel(): + expert_ids[expert_length:] = -1 + + num_tokens_post_pad.fill_(token_length) + + +def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. 
+ + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. + """ + top_k = topk_ids.shape[1] + sorted_ids = torch.empty((topk_ids.numel() + num_experts * (block_size - 1), ), dtype=torch.int32, + device=topk_ids.device) + expert_ids = torch.empty((topk_ids.numel() + num_experts, ), dtype=torch.int32, device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) + _moe_align_block_size(topk_ids, num_experts, top_k, block_size, sorted_ids, expert_ids, num_tokens_post_pad) + + return sorted_ids, expert_ids, num_tokens_post_pad + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. 
+ """ + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: + config = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8} + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 1} + return config + + +def try_get_optimal_moe_config( + b_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + is_marlin: bool = False, +): + E, N, K = b_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, K, top_k, dtype, is_marlin) + + return config + + +def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, + config) -> None: + # TODO shard M dim + _, top_k = topk_ids.shape + + EM = num_tokens_post_padded.item() + _, N, K = b.shape + grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + moe_gemm_kernel[grid](a, b, c, a.stride(0), a.stride(1), b.stride(0), b.stride(1), b.stride(2), c.stride(1), + c.stride(2), top_k, topk_weights, sorted_token_ids, expert_ids, EM, N, K, + MUL_ROUTED_WEIGHT=topk_weights is not None, **config) + return c + + +def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype): + a = torch.randn((M, K), dtype=dtype, device='cuda') + b = torch.randn((E, N, K), dtype=dtype, device='cuda') + c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda') + + values = torch.randn(M, E, device='cuda') + + softmax_vals = torch.softmax(values, dim=1) + topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1) + + config_dtype = None + get_config_func = functools.partial( + try_get_optimal_moe_config, + b.shape, + topk_ids.shape[1], + config_dtype, + ) + config = get_config_func(M) + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(topk_ids, config['BLOCK_SIZE_M'], E) + + if not routed_weight: + return a, b, c, None, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config + + return a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config + + +@pytest.mark.parametrize("M, K, N, top_k, E", [ + (64, 4096, 14336, 2, 8), + (16, 1, 14336, 2, 4), + (1, 128, 14336, 2, 4), + (16, 128, 14336, 1, 4), + (16, 128, 14336, 1, 1), + (64, 128, 7186, 2, 8), + (64, 128, 3584, 2, 8), + (64, 128, 1792, 2, 8), + (64, 128, 64, 2, 8), +]) +@pytest.mark.parametrize('routed_weight', [True, False]) +def test_correctness(M: int, K: int, N: int, top_k: int, E: 
int, routed_weight: bool, dtype=torch.float16): + torch.manual_seed(20) + a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( + M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) + + tri_out = moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config) + + ref_out = torch.empty_like(c) + # Repeat a -> (M, top_k, K) + a_expanded = a.unsqueeze(1).repeat(1, top_k, 1) + # (M, top_k, N, K) + b_indexed = b[topk_ids] + ref_out = torch.einsum("mek,menk->men", a_expanded, b_indexed) + if routed_weight: + ref_out *= topk_weights.unsqueeze(-1) + + # Validate correctness + torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2) + + +def get_configs(): + configs = [ + {"M": 64, "K": 128, "N": 256, "E": 8, "top_k": 2}, + {"M": 64, "K": 1024, "N": 1792, "E": 8, "top_k": 2}, + {"M": 1024, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, + {"M": 4096, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, + {"M": 1024, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 4096, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + ] + return configs + + +def run_benchmark(custom, args): + print_time = args.return_time + routed_weight = args.routed_weight + dtype = arg_to_torch_dtype[args.dtype] + use_fp16 = args.dtype == 'fp16' + tune = args.tune + if custom: + assert args.M and args.K and args.N and args.E and args.top_k, \ + "Please provide M, K, N, E, top_k for custom runs." + configs = [{"M": args.M, "K": args.K, "N": args.N, "E": args.E, "top_k": args.top_k}] + else: + configs = get_configs() + + x_names = ['M', 'K', 'N', 'E', 'top_k'] + x_vals_list = [(cfg['M'], cfg['K'], cfg['N'], cfg['E'], cfg['top_k']) for cfg in configs] + + line_names = 'Time (ms)' if print_time else 'TFLOPS' + + benchmark = triton.testing.Benchmark( + x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], + styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark', args={ + 'dtype': dtype, 'use_fp16': use_fp16, 'tune': tune, 'print_time': print_time, 'routed_weight': routed_weight + }) + + @triton.testing.perf_report([benchmark]) + def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, tune, print_time, provider): + a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( + M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) + + flops = 2.0 * M * top_k * K * N + if routed_weight: + flops += M * top_k * N + + if tune: + configs = get_tuning_configs(M, N, K, use_fp16) + print(f"Tuning start with {len(configs)} configs") + + min_ms = None + best_config = None + for config in configs: + print(config) + fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, + num_tokens_post_padded, config) + ms = triton.testing.do_bench(fn) + print(ms) + c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda') + if min_ms is None or ms < min_ms: + min_ms = ms + best_config = config + + update_configs(M, best_config, E, N, K, top_k, dtype, False, False) + else: + fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, + config) + ms = triton.testing.do_bench(fn) + + if print_time: + return ms + else: + # Convert flops to TFLOPs + return flops / ms * 1e-9 + + bench_moe_gemm.run(save_path=".", print_data=True) + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark MoE GEMM", + allow_abbrev=False, + ) + 
parser.add_argument("-M", type=int, default=0, help="M dimension") + parser.add_argument("-K", type=int, default=0, help="K dimension") + parser.add_argument("-N", type=int, default=0, help="N dimension") + parser.add_argument("-E", type=int, default=0, help="Number of experts") + parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token") + parser.add_argument("-routed_weight", action='store_true', default=False) + parser.add_argument("-tune", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False, help='Return time instead of TFLOPs') + args = parser.parse_args() + return args + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + # If user provides all M,K,N,E,top_k we consider it custom + if args.M and args.K and args.N and args.E and args.top_k: + custom_config = True + run_benchmark(custom_config, args) + + +if __name__ == '__main__': + sys.exit(main()) From 5adb9715d0469a387fcad6de213c60851387570e Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 31 Dec 2024 16:17:34 +0000 Subject: [PATCH 02/21] removed benchmark files --- moe-gemm-benchmark.csv | 7 +++++++ moe-gemm-benchmark.png | Bin 0 -> 20192 bytes results.html | 3 +++ 3 files changed, 10 insertions(+) create mode 100644 moe-gemm-benchmark.csv create mode 100644 moe-gemm-benchmark.png create mode 100644 results.html diff --git a/moe-gemm-benchmark.csv b/moe-gemm-benchmark.csv new file mode 100644 index 000000000000..b37672c102a0 --- /dev/null +++ b/moe-gemm-benchmark.csv @@ -0,0 +1,7 @@ +M,K,N,E,top_k,TFLOPS +64.000000,128.000000,256.000000,8.000000,2.000000,0.106786 +64.000000,1024.000000,1792.000000,8.000000,2.000000,5.040689 +1024.000000,4096.000000,7168.000000,8.000000,2.000000,250.097741 +4096.000000,4096.000000,7168.000000,8.000000,2.000000,394.569816 +1024.000000,4096.000000,14336.000000,8.000000,2.000000,313.915161 +4096.000000,4096.000000,14336.000000,8.000000,2.000000,421.260017 diff --git a/moe-gemm-benchmark.png b/moe-gemm-benchmark.png new file mode 100644 index 0000000000000000000000000000000000000000..b34d965cf98e9e5ffb216b24c6b7a5e8d8a4cfc1 GIT binary patch literal 20192 zcmeIaXH*p3)-GBkK?MN?5hN)f0+K-_qexajvP2OT0SQWy&;%7Q0ZPt6$x$Vz21P+7 zH%T&xWT8|KxoYj^a?jcAp50AOkJ}Gi z?VO#&g+xV!M9y*Eb#rrZl@S(p{O>!2oFCWVjr@rMeYw1@`rc zTkXF_JFH2)eA-*%=^90-Nf4w#@=~%8d=r9ILC6urg_)ESK>}Ln5cr#UFhYeO<6`7W z@b}nL2q}V8Qj^fZk8V>1z$GE&|NkfduV?I2vJmI1=;&y>F)Sv}svdjk(xoIJN(3=@ zY<=Q;M~Cr)2Olmek4PU^MdS`|tu5+nYCiA24ND>Z&J%Ab>Co3YC&qda{Kp-z6ioek2n}!yY_kwKUQqh942jXXjec+#`J39 zdBa<`GP*OtAWo3p5abgO=q|FTVK{y~>({SaDyphte}8c{iWnFeWROuFx!Yf4VLwt` z=D@T^jszekTF;+9H!w2Fii%=TQB#w(IKtgqXl6`LPd`}aQ-S_1TbnIgixg>~rw;59 zQ;pv~JVOTij}ykf@ZY9b{OZ-KNNkGtJm1dtN{3fW-Kl(B_!x%N*m zd|Ae0O4|7@HI80GV~fnTM5+ASEio)&x1Pcj5T+BI^WdM~zkkR5{PMzmwllqC=Pj&9 z4m~mv9ToL!B<#fteRcJxj;Y+v$i1GNYnkQx(%ys|)lim9hoUM~|D6S^x-ES5f)LM{ zGxG@!MO{hKt`fdP7iTkH-`a?Xh^VNjjD9PBi@M503p8qAZ9nIYWMBv%4)NsCfR`t7KWXtkj)5db?tM9*w_R&KC(PX z_1~>J!X}du!tR};H`HrvRe0WUaCWg`tem3ZuiV?0Dk>_^pFO*F<;u(Hj?`iW%;s5s z{yZZGN~9+Or{^ac}TO2YT7%j4p#fvhs$$DG~mtTwf^yyPiv6W&_NQlGQ?*SG` z`=~o@37PEj-Z@|$ETT7`RJ+YR2(BYVda7}Hc8v!FmIbYxXni9#YW_xOWoQXU7gc^* zm#jue=xl$Dc#f4(RKaA~`xjfE=v7!-uivYK7r1Q`?@QX@F8l9ojn$Dje|>d|MegCd z7-6d-=lUI^!%RXwFPX&&74O7a{Jk4m9wnM@(xt~)D~+isz+uo`ASpv-xs+K^4g z?cJgem{oL2N)cEyd`Jdvy(d>+E!N7fN~nJOf&4CUvU;JYhVUy==O{wbz&AyfYWoC# 
zD-e&<%3K;b=&R6_S}q*)q=#s5kCR#CMgYtCdpUa>se8E$h~nlf*A8(q+uE%WW#g&18Ldz(zP9s#lM<#VR%HC}M6=jDt=&{=U z%M^M-^pfWorlKi~E&s<46G6*blgZ}r{?2q|$$LNE+1Kr*D~GsEhDyzc$u8-}3RxJ` z`(X;Nm#B*0Y3(fQw>0Y>AweESMX4|JT@S=B6z3oD#cU0dS((IjXnlD`k6aT_eR<-V z-PmZIkEGkIE@rjYI6=lePpInGQAP%a9QTDjZ$g)v``*s_P#+=}54&}BX~awzOMK(8 zG@SeM=S^^ucVX9f|9Y)vUoDG_Tk`PWJjbs0gDD72%?yP5 z*3zPHZl3hO)s^w&Ne2gkvu8WO+V|wVzIpSe#I3J%w{PDbz&dxN$fM@`=Iz_)SFh6kOg86dWMrs1OrCMr`?P`3 zEkA$W`fB}m2`+$Sny{$Cq!F&bFKBuzxL!3hcjCsO2VIJ4u_@Ts^>3WbpLp7)xEUTeaV!8v8Q|QzrQ#<*m+if*T&|u{g*lDcO4x5r?m;aGmb-^iun1MxvNl-VxuFcb+goTWjPBNznl` z0r=XLFznWnY@uacwVv6XJ6Wf>xEMJ&oQG25;&NJBui!k7^{ux^ zl$-l9iU=-UdIT$(zk|tPgJ6Eh<~kXoEpRT%sHn7WVr%+jZd!ximB=$|d?QELxbM8y zeM~BbEREC{ua9`9LInEN-RyNWK|*7)DRjjzIW{&MoUXo~m5IrTk7USjURv5;DyuII zv9E5V;vC0P8bVANUy-;h4wijyZ#N2Om#^K@)Y2-uHz2ly-`g!#*!nHb$jHdY%iDbZ zP+WBMRPwHtrY0`WF#jHYp-PK*S^vqkYu6yq=1q(u$Ro{*fzqFh6_%rGeAOeEiQjeR z${p0kAL~%uLVl0c*Fzw?HRQRqSdnj0asIQfuseN+e8>FEq?dWlfKwYP zC&xzeor}WmS$=Pi*<6qPL?%{5EDP8^RtV5g#(nTq;74^~@z{NSi-t0b2VYY_8lJH!It9yh zNctL5U^BS4Bz?)T$7Fp?2H#yYtsnKq?A86c9=jeYl`-VxT7>0<*?lKt>H9||C^*<2 z2FZW!Tu(zF+3KJ}vDZ-Jr%$CfK9YvA%bVKS+2t7*r>KUp=JM%hT`-YE8t@NEtv2Yf zZ%%KCrZmiWmDwqe=YJK+u&DWc(Yf(}T}R3>{|*(V`0xVd(8IItNOF_RC4}xt7>88x z&_nr&JXT`6bhE-I&rz?XBfbuUx~CQv z7uexSGnQu zW6ZYK%)nzy7-Z?gFN!6@=5ueuTVG>46iPy{?a76m{pF5=ZighF!KW=dtA(ix7?}h* zwtg*B6XY@2w&;#K9Y(PFn^!P!)Mrr5y0r&e?JUPex!&7WnTf|4uww@9*$ED{cnO8 z*B;3)vt77BQ9FNS!ehzY>rjQ?j`-GZF13ts#qvv57Xy3yvO*vYu(tIrZIci6g9`T?`;6X@YsBr<>*vgVWursW}0~4Ze{xWVZPM6hbr71 zT<4-n;$ebpnlpk9w{=$B2;3-Z_L{lgU1L#mbkiLDqzkfDRk*)Xef!>Fr5&s>ih=S_ zAj@F8V@B}Qf~F?}O7BQ?i1&7Y#rk(ZWU zqsO{0@P{ND75?sDrcfT2jpQ4CeweSq&UH<{l2>2%#%P(WZI~C6^9FshYe$*trZ<6; zP{j&6)8&D*^!~)yo>;Sy3+yK2D|JCYFL+Ux zE9F$+du3~F&4CyZ;cNcOPR8TPU}-Lx$p*zs`Ae5t9T@TxKCRmF@Jw$ch^ql0T#@tU zKBK7|9G8y&`6Bb3Nb~p0F0f=uY0?inVTv5{e6NHYfyW&^hJnX@Z7tY7H+<~sL|>uV zDuYv>{IN`YKnf3^EK#ao=-%63~kSKS`gOF^DRs$6w*b8WR_&W@bN4~I02v0k`w zKpp~RqLi^hmW`UvxR(0z{nyzpguAf(&Eduh=ea-E9?bOZTo~-#1J9LJ+T&~boR<(y zTNTQpAvJs>7^@}BL-AOo<_^$gEEG#AkKz5) zW*=LD2UBuv`Tl`)4427!#CIy&?WEM&{nw1F(t2KpJiD`CqbVy3j0K&XMK`O_8O&B8 z^X7CrK;M$VpOcvjMdHdC!cwTxJ!%9gsR%syStQ8(Ha8A34zW_=5w)D1?80&1PGbg$2RHcAmF%^qdP7W zbv8ED9%WXI{W-70UqoRJRcs5$*33^Po-dRO@)`B|C<*Q-_*9|k?NGwvh(p zM(#{Sib6d4L)6Nqb%WHtuG8v1YbK7OqDHePMH+kySn1Nx@8aKD8_0IkVL)-=`Fpd8 zgVc@G8H{SH9=~(E4^>#om-PMZu5^m~B+uOVs;>zAncwb!{flR>i>jKaF=UiSH|Zqc zF}kl|!>AJdh#yDUkKR0x{Wq}q;^V<$$gCg&{yls+Ha5w8l|iT@LQ;9G+WpRKqPW^% zS&MzXu}<}-y*Pn}<(2}5oC>91iL1sw<6D2!e09SWTud!^h1weIyw>)wD z?SVXq@SMb*zb_cqV$Qt}LB1XGP2n4PAzYN)V61o3Qqe^8H=CTv>YpnUDa?%`9H}!0 zF;m&g3q7w7(4JQw8~^x3q4GEhozn6^O+UD6zdRi=GJ|mTwFV|8;}XeVGeh`Kt5+tD zA2(7sf39Gj&)omvrW)tKhcAaPQ?WmvJ&!1=#ZhBmOrt+YDj|F*jStoI!*F8M6?bkq z)6ztHuay=}_W_-B^E1OYlT(iD&@A)=9dz zFgUdhVMnTcJiE_0_+Nffr7-r+dfFQ!DIZNs&IjIXZ)K$eJmGiYR)P4yh3%EaK>U?# zB6VY}KxJNyOTJ#cAAHw2I`b93?JkzNp~|j74ti^aXOm4Q+jw1j`q&1Q=)mv0jCLpv z6;BM+4UJmr2NqZUNKzPnm+q}!*_MDhjIKv~!-J(ZA(FTXzq&F-xnbYwrFYhMW>>y^ z8ruK;vS8yoN^GVUP%up9Hz=@&h2?e{ViQq!eC^tl4Qy?QJQepRyMME`f7MaEo{q1< z7wlMD>v>l*Zx+?GdxrOunFkUaK1ClfJ%stX5(!Y^^0j#Wv zE_`!UAbj3^wt=CjcyaJ{5Vp*r5Pi_z?lfU(wO=cgMcSB)?lbHqd!r3VEMW61o{=-3 z?`<)S$*cz75o8ATreMDMhmQ|3hN>B{poOTBn!?Ts@3JyQ34_fMJb=`NV%uMg%OtSU zB{nCf@0x8ithWfS*?gyj^&ZPjX?T_BV39wwf&Z4}BI}kT7t}KQdS&hR>rtN$YK(xf z>6Wr`TPlUl>}1uK03E;yN$iax_*a68nP#HPDwltdS(y$#BTKc8<9Y%ff?sI$=SPO3 zJ=88SV}1^cvyRZkq?6YY9`l%xVy=t9-sz>by8RpTvDdU`$lDA+%n#YWxp z7j4;+2Gl;lC$8FNXAP*Fv7CW402wFjB*uJUlQlh%lIpY@)u9BoV5O!SPVNAn<_2;y{G%zvSaZD_9Hy`p1wzixxRG4-A;qsGj_{R^3 z7H{9!r@Q9d-&#*h2fuE&@nB`!$K3b5be-(2ryD2a5Wg1)Rb{BPy5jZ?2ormLTcWRD 
z_fe(tF83BXgpE3m*zf)o<{&cZgf$4%&ptkxiBp;1IlATOr~6z zwgebtK=&nt{E-E^`nT^4hCh2Lk$qo2u-MaiCB7nsZ`gmgJK>X{P|-H?#KiPu=Xaek zv7nPbK9pqADu^gUywkfL-`kfR&b$feDs~WtNEM*;PL$ei%fn|JLL$d8#@iHFET+w& zwA@)rd9+11aj%fLZ4q8s`Tec;;*faIN$Ml4@TB1?oUzq_vK4;RXjbd0dyy(42_TwW z1}02R$e(8H87X{o!a|{n%Js1ooE9(av(efIKO9Dr_=ayhSWc60PLal~_ZmZD*z56; z0&=-^^JYzQ-4*zrG*RjmncBbo%eQTZ!hzDY@7K2PRC*pRmHn{7G4f7(f%hS)({JW$ zTuw^lERHrD3^}5)G6Pfp=cdo>!az4QMr!nI+1&|~XNWx$3Ei?MFu%M6Luz04^mLlSZPl2OE(;MA7 zyr4jdBGdY?t3XIyk3p8p$j^$g3|d7Awaw##A3}fzF6ivfi?-n*bvXbNYkECiep%4W zXs2O@p5D6ta6il)hgEm)r?IJD>KvF`!PpMP3Y`tlxtl@Q^Ny?n^Z_vTfa2)r_?Fy< zs_QXAlER;hg|b|(U9(JK+*H_V-7Ss&C$lUB^IR4nS4u07ru79BIIFvgtx$RK-CCF@ zh|e5g<$vXX%2v%`Fq*N^Di#u=AStJA_xJA^ZHYiEY`c3IK70B~#{C0josdP`ryvVI zZ0U_a=FQ_m5=;wgcgIE#Vc2A4z}y0q5b}1T>+z*_kr>VOTWH8eEn!eJ=tUlz^l8lY zd*)4rjm!LJ-wh6yM+ehZT;kwrM9{F96dG|f*T>WKb1}c5S%Q>}NYIP!`-B{M!hR0t z zNjZLw9tkYDR#8Og)|rlEIDI;)=s+MOd@H6{Y1h`1GGS~bgQn%BS)MiSTm+JwYvfAf z_kw~GgwtezVZjfsKEb|KD=a@QSnwJPg#Zjc>_i+BlBQ`a5a z5@6fbdDCwq7oNl4Y~SXnZ~i>QYW~Ky9%NDU ztJT!h^c`VD0&^;V)V`h!lg;>i*>3s9T$@zea%|E$pce_N{kRl)+zVk%JG&=}h#wE9 z8j>jKkTutxoevDK7`3u9b|b5(Zq2mNtU?>DVUW?vrG+c*nz+ui1O1wWmN8B`RbCV@ zDtN^WbyEA~@sIFUoYa`_+nufTg+4Q9!^@YSq+r&rhxkx72s1+N}uqciM2DpB-}_Zy-@IT6t~m<>dvetG<&{(M=m0heBtBRZs6U zW;GfNli0BPv>xVSR#zNyx7Q5Er_GZQ?Zfr;WNzf|YJp^LOO`cO3ui|e#J(bne6%i- zo$V8D*I$&Zr>|c}BErR`oGk0P7Obeja=|4{HJrUr(5yViXMM@g&&b+3{ZNa%uYB<~ z#VGPfS>xG;oUP8R?88obE%jHy)Ih}W#L{gdBd(PC3~JvREVWCJ^Lo(cK_0N8shM%RrJ&PuJOIu4 z7$INdA&^iVA~}fwQJ)WGZ~-YPV?5+LC=r>d727v&tNkBjvgythq%5Pb5bNe|?EdrNF}ppbF*&(ANT<^RbbL#g1q(7ONji{pZNuqS}l9bXwF z>3Q=(%ryf)>eKcmPPop%dr?w+Tkg@-gI0Dj`UhREH*>q#` z`$4dj(W$A$6{Fr-S~2IkY*R5=P&6Fznn`Ana!Nuge^8+t@xj*tf2|KD$E`tjfLa#_ zYqu7`io37{zqhvb*o4bxiP|)Ye4-?O{6UP5tC4li2QP0orjReFU|GHE?V3^S@6UJo z(|a+fHj*$SB_$>93xyT-d%HWRc|wVgFkqWH2Q~Lt?mFVn7l&wQdS9JVHZU>CffB|% z)cjWb{7p>aj-NA0>w2puASifsb-uUz>uWB_+us?otI5ajMMrCl`3bZoJ76Jgv48O* z63`_zW^PVTGS-_k;0B|_0wNRW{sakO4laF=DUFlzECSNg$n0K$JVCUYcD$ar!pHLAzDEtcJZd@rDr65$1RKfVXqU`f+{Hdy0n2HD_Rr*U?et2?Pu z0QkaX9f*nR$;`y**4JpBhvsL~ERdhEp6xAMFo{R#o&rGK4zX5W=r>EpXQvjF&)%Jz z&u{DW{KEJ|0l9D(5`=a!YN$IqczZS<$rqF{t)Ax>AgHLx8uKFL^nl{V#HsBdC&<7z za}+|dmHF{8{CX;S+D#ZqQw7XP47OG`D>q4Z-tx#V)MgX^O>^>2aDCi~nR#-A3BEDM z0{+@v1LXJC+c&?{i~TS&_T>ATog6Q?$b}j6cU}yCDN#HHJXLLZN9N&oqygP1u5Mae zG@2c%9T?SlOH*Rp*ach6-gU}ALo@j z06!0z^IAwyd^*1AB#;b|ERqknYrx8{g{nSN zW0Zr{H8jvyEK8Ge(%oCb$NmdW3o-{+DRN(~Cf;=Q_WVl6Cp|DALxP0CKt6Jc13q(| z(QJvio~{!iRQ;7@PIiVBT^u3}JS6AChYUcfYMU`aWg*y739u?Qeo2RE^dssoBjlvt z0?CYKfd}FjC>iu#2Xfgy4*(ak&4B_vO4r%`dF-$W@e_*(8de|c$LHGtI5xia*hih;fR^;0FsEw>W8w`ziJh^x3vJ3BIJ~C zL&lh&_&!p6qoSwrthcYOhH`Z(mtHQOvB+6j6K4ZluwV0d!yh6aG0P8P;NR3??a$O3Mrm-V@4QT zGTadGfGVJ2;OqAYp-QE}n*Ko`a;0%VLYz9)-b;xnF`|EslN+G+znZA&Icyya#S1`0 zNWChczMgVNoYkJaQ~gePPjPC=DuU==s@31VM{Q{=8c ze|Eo=6>epK*(sVyi9Et4=Db-sc_cfZ1Q~w`2yy;Ljj{I}kLy;|OomIJE+s-P1QX2` zt|-LF5O-L?%00Tf7ZiR*@cp^>esJ!rq)xB3#>>qxL3tm&b|vnH(5~FgCvG1p_Q!M` z7Lj*lV`f_5^rZxu-Z^_vo6*DcU)IzrRKtd{36$NNyC;tzJ#>I(Pw#EUmfnBMz#$BB z9P~Z9HjvpEhH+F}8siFAhrMq`3o=S57#K`V#b#Rhz1Gvu^ufDVvml5)I;7bK61zbO zSD$%HG#n1G6OJucY?CvrpI?}Wzxcb-Q_bW^Pk>M0+{Xo?dgT&`6e% zHa&tMW#48fA4sIdU6n|_x(XkmN_6j+Nx%4087U`XVe?1#zpa=v!hh_ z*9-Ao8?ZI3aI(&#fd71o|NhXEIGdL0LiU1Nx0B8%r2!ZGsC(+rkJ3u zO|tc+k$j^1T40Yap@p8hEWx3`l@Xru*yhI?}GpO>v{P~89> z6{xMaH8w&0=Dsl%s}%u7=eMA-f$yt?e1_1YKfM;Iqxt@QSPU5wAWZZnUxwh@Z{axt z0-&0O5DvdiP{{FVpvr*+Dd(ky2jd>%cj9z*?{=$JZ#kiJ*1+qlKuVm?hZIKeDk_xQ z!q*oA9M_0I(W2sV4+Rm2#o+d~!>i;m)eQ6#r~wewpIV;?E>2BP@92G9}`|n5siE2I!^wPt(aV=P?D4RSsA_Q6#p0 zeg^qIP#RG&({7URVNo4qHz0#(3O$KxIIR37ZfyS5nv2xQSatzp+aK_ZFFOLjCyz%f 
zCaKrwO5*veOy(|u^ZmuePo9v&)VU90F*S>Ej~e{7L}{N=5oK!gbDHVa7c!%ee$&i+GJ^x&;96gGs3B^Q zjqJMha`iKlSV8$U7h?t)J%Xrh z_>#XXEj2(DaVTULf?2OV_zU$yYd)cDvPOrGpX>Z^yQQb#x;TC}BZK?j^)4W9jB>VK~)1?pp+@xU^yBvWEViyNW)a>>w2vXoEVD1cI z(;q#0lw0F}(ih#S8P$Ihq|M|Y|A&NYxjE{PtJsYz?QgvMVRDCw_jdk-`}deR*JaWw zZtH?MtuBQ_?l4mA&J*9#;QRL*Z-Tr}Ii$Jn@J*T#*p2Cv8hv6S-@biI5VZ-0^>K%U zWwxkh=_YgBjSqK1A$Co_9mQ^G&-b;>Jx7oulk1)V2O5rrGGM`1R4zQ0Crp`l*hWWb&MZATs@$=LGsr%_>_f5e5qTm z+d?5O(OCvHUE1X>StYWx`f|`i8N-PI>rg*l+uO?x#oNpcG^y zq5C)!#FWu`1)!N7;qLyec(~PEOgxiWymis8JL9W58e1A@>EY@botQB8Xz$esr`j$w z-Ysj9E0Kpg4rT0ZDY|ZL%(Q(;B|(&U(2LKGW#aVa1|2e?e20eNZ79rqv>p%xWk<$4 z5zvVI^I7mzBDFG$;^;8Dq^5m;J8J36uJ8bCm20M9eunpIMs~`B{Re>5Cb00GZi2+O z;S`aJnpfOVY(omzCeX3(95@s((Yb$t8=%=PkAMI6I_rvPMfMfZ>;M=5AH1qE>uMwb zQyI5cPW2<`0H{vDtFq*MGGeXHSETxAAENc#e<~F`b?$w0>q#Y{Is+EDAN%*7g?m%n z7c%^JTF_Xbh9yB<5~u>G&IMZwt=~L>!}|*-`(wKqNf49>raG7EFNFCM#K6pVHF+Kl z806ffoTO@YxvO6*-rHmz7fdRo1o^rOd@6B9G4pC-2Jzusm)br$3xKOPE_ht=-oEd; z^}abHm?nAuW>vWSaE@oL-zsyuEmPD6-Pr*Fs>qW)Sr?@Qg^^*uwd?;vIGhb2?cU3} z^7`)zXq7YQx&FCV4SC|fH9=|DcMe^!OkKZQB*+D`?ejtYKELQug!5MuQl9;L)0E)3 zQ1_kNS#P^m&n*7W#f$I&F=~gd8QG;*G9LSj@Cc@B*KSnKm0U!|G(m1)f42C^m58Z7 zZ_Z7Gv@KF{L`6jIkAPDdp1HALmC4~Bb@XU~-Hm^++_JEjS6>ufUq@qgK^BEf9|6nl ziP;2i{0cwsm{GH|Zvu}1yYEiI*qJOv(A9eW)CFCwM}Y8lC!g_VOg3k}5RNIT>)ZbT z5VH_NU{<9@MJjCDqo5-KokC8SEdRI*UFO>xkQgdp{uvZpTsbZr1AJ}qhR@n{)HXS1 zVg3p}h&dG2cfHLtDnx&cX2YMZJJ<0Pn(;!JI9zcFUDrxd61sN2v#;1OJ;Nyy8ur6K z_WkuiQqH868AUChHTAXxjrEb4{dtwONyHuzGT$;o?)T^rW`ipRF zQnDzRFS>+0!S4_v3G?~WYcHAg{^ePrWJ??c)cfTNs+TG=y+bo$N1kf(#|+beoT zF+F+FZXNRb;cx*&)irB)DY7)~tBGe+RQC}&C-Npnel^OCkkgYL4oC`pDgrkqmODl7 zu1UH2OkY9K12TSrl+(UrA8)yhxZ}Ps?av1hDpK~%xJ;b?E>}>HMtAo9)w3bU(*$Fq zGxCI2Rhab=JxBuq5Fz)d0&={d0PoPJE-Ko$=?{HH&UDDNl0%0x)!0QVc(tRB9z!h> zF}Q?WZUg4au3sqEPJb?t6151#AP}CqGUSZw)=3=FQ55XWpFzQBnkGy^l3jNz-H-Qg z-G&^|p(5G=aPNDc3c_~*be*RE3d75UVOZkqKCQJRzdp2NK#d$-6qlpqN=bEiY~=2$ zuxldfK*rZWwJHS73een977-cAc=YI-jg^_9L7K+SV>)u_SyvPDoJOiOb#(fHKE5kE zy*$xaJ(rcpaPnk6^n#$Io6Lh>PicMVx&Lva*ts*gwmPU$A=wPCNR|QE%@3C4L;ldq zKV;h)mu?JlY7{(uj}}0o9}VjTT?z_{Q5&b~xigLfA5lmwgSjhI5AIl*`pi?fax{Bve&Nd+qAY;Cq)BDo6iDW z{DMh%ZK1Qv+ayVNNig*2A7ZM>*fcDNX$>^^FbO z#*9{D$o}5BaFxW1X8bjmTB%L*5qv+OVjMK<^g<;eGb?Kzl7oKn4!O6WmC6UR@ZSE@ z9G#T(yH39n&^n96-B*CIVuwwmb|*Y}{M zkGw8`I>4c?pzC(5#q^H%aOM4o40x+OWE*HFQy9B^UJV{=CM_*p4OH$b1oJ%TvKqy2 zK>kn&Sxuh*-tNB$Uc&R4>*W{Thj7FIg`Zy(0o)tv?Hib>h#>v_Rqg%+N-9FwAFQiG zr5f{{A5olk!Lev-$AV%SuPZ;rvEeb33^f~s^aKfA@5)C?J=I_aHo>bYzft!*evS^g zZ?B>*su{NkeA*3uMbiB}Xb9F;uHOinKi+as~ z5;f%WH~IA>x8<~+@QqC)sw5C$NSCaXGm)^g#5AO7a z{6tu1zV9k{4>W2bWe~XSX0+$Z`db&{#Q|TO1C(`VJ1YcoP9X9mT+*Ua_M>K63_UX9 zG^Gg|QLFG5OQ4PNpFa3ElPotZe7eyjd}C-pYNxD%ZW0*)Pcz| z8eD2J`_p`kyy3>`N(U}=@19df+^L|TeV%HW1HC-^R`1&Dp89&&;QJG;cnP6ab?Dc0 zuDk8jk@Ssg42DlbBJ$(JKg+bzrnkPdKlj^5;OaT^vrsH#V7xs-0#l+3m*Sx9>e!jM zA^*+Dp4uujIYh~4>N6#_T4v?I-I?m_!KIu4f^$- zoQq?^Wp__OHn+b_+b2jwPTaa3|3Y31S86v+!wGltDk0PdB6+ny9L{9UNpo_ap=s!# zN61Fj7WaAAl38(h%K_3GX+A{j+)@Kj!R1;h?%OO0y!&g6aO`wP>?P&QnUwuSqT{3r z2+Z5&8e6K?vL72oA)QhKCvqTfeKgXjC~Ce}+?f>YJ?9MR4NiXkh%;BkDw@M%BO>;f zhnxYfCP-=xR&`VR(JJhl0%rpqLUz@uNaFVMDtWID49E5topT(nwh`uV>B`erIXs6a zP^;ktP(=1^v<{~w9KY`}LiRy1)fMsvrN6{Uy^Dc86~%77hWpb&%0MEb`(KpnO2^7q zDX;}l+y~y=_ZReB8Ha`T(Fa`vTpVBNXfwY2M8Z7}UEBaf4Iy}DX7+K7HHAf+mhZ!P zl(5}hvf7I{?-cd(XNDTbK95RA4p)B@@1tqaq8dju6jV_3N5_C({!h}CjvKDd|0n5+ zT!f8WoZ_YWzmcxU#V5#YI7+(uwEaW6N{}D{36q<*{s2`tKkV9Jm@MPoQ{(9j4UFIC zA~GN|1TBQW{|W^d*kf}xWDK)z0mP`bj>1L;+H2&#IA{oc!YEZ(62!1OO%oj(YwH6mPtMYRpaEKL^MF)&_Uu{t4*UlINx@$q0BvYR6feNT z^MxrB8ryqvT(S>SaYI=i2fa46Rx?vmz1mc)B1+FgqpRI=H`W1h7R?MIKPvCrpzYEE 
z0xji=0s>k95)A!x3=E>)ym@0^Us)=#Kjg#avJ|Y$%(*G@KK5v_7XwU_Y3;-JI4baV zoKLGf$aN!7QW7HUJU}hz^{FF(CCvd8tln80N)`e~o7Pt;gB*b}G)lvwKojcZXzyC8 z-=nRYjxwTe;KsOp+5#hnx4K? z%MwQ4P(lY_5SBZ+*XMIn(}5`~7ux^2J1}&&SEwRv+pAK14c4m?umOw>4k9@6?D=zC zq>e%^G;?Qb$WCNY9!^h42c@4C+$rW|LZQC@xyapU8s%HS!{w$cl8vFl|0LCso=e8U8aJzbQ zJ|FN{0SduW703FVP6Ves{CQB)anJx((hzX5-}WE1A<+<8j=rl>BN7N?$Hm772EzW#?4s!Q|HxV=Y{r`d ziRVq2O^4Yjv+va`vF-4vb&j?4w#!`c?2un4V#c`VR{O1tp(jjk!l zJF{ocTXCk=Z#w~&&;;G>p>UwV3Ct!zRH$|tWIq}j&$@E; zdAVROn^ZU@|M|w7dIG$~`7>7&eqHu_2~OklLul>|g~X-d)veYo0l7c5A>QXNNLmvus@CmeY($haJRauG?kx2u;(G(|YQ)?9)6T1Bok!& zuyxOH$FC;`kG+Jp-a!Qi;_^fjv>Q6qXjwAC4BdHpcJFY-8>P8KwZ{i&Ue@m{I$v}j z{k}rz5P38+41VJz2)E8dsY|6NMsjYfud;lPL6L4PQu;_FL?X=u<52BT#yDDX!2 z84oWndfoz>o0r;luxsG{X|%n!U<_GRR6@f1B3j3g@mwl^S`l6p4Nu_x%I^*Q912VN z1lAN4A8)sy^G_n{xS#{p-mbXkix-8^gQh-^=G8S5;B*4e6OGy&IKr9(rI7_NF96Tl z_r68&XsWC0f_6#Ln{XHU`M(Ywzq`}Y))vhT{qeebZ97-)+`9E+w9)mY@E+FacVct> z-j1`R)36DIR$PWgq}%TwqRwV%ghB{rXv;2|l9iU0rpY+#P3?M-H&&e#zw`0{ExUPc z*bUi`kNm@Pq{Gutn}SMS*-odH9@I0+9fziM4ST!67q_Lzo52G0?(S?!di}j+@!%T| zTC0F;n-Ss`@D@Ys+-J`4i;L?=z_731u9tV*{j9HJxEh1k`L_hTl&$-X$_q{Wgq=fY z+Qq^iD>hgx9F5d)8m!FSnf7F~-QMjBqw!hLm04R`ll0buagF%z?Nracc>(7a@Zx9s zVFNu2gKqn7Xgh~;>d4@1gs~weGDC5Df)+iL0!Wv#`;)3ylP+9?jnYra;p>8{mUrZp zZDPUTj=ko_!@O88y8njPkt~1)n&>)HxS}84a95PzyFE$3B z18uweB?VADx>pylxW3tiBR)=zcZ|S+2JQfl(5_)S=oK6HSF&j6fL8ITAXf!+0xu-kupL;>Nc_@81D1PZw>Q2-BRyNU|D^Ja#? zO9*O|!vPSO7cXBLKv}5=v}_yAVE@-*1 zxc2Y!PmPVhFP&hR5&1!nV9tq&>B;-7^??7c+kyfyT22~TJry_(im$;c*X#sn1p~mu z4?eK3_+RR{Q9rb~8l-XMy6^9HoreBr?`+j=lmUYYxV4}I8`P-C9Cqb@@+#{bgwj$8 z81@waMx!=hjTvD5GBP`6(N=yNDp2v&gQxOr`0r6Sc0w%`)?W?-MfK>cHxBc*VuFHg zYwRgD_AB^>+)blsP#bDREN*YujA>^S6kGuSNP-#FZtAYjBOk0<_@`ey18(xMno0G? zqC6emLnNg!ZIWPd)`Ml0c5nNW0#jp**qex&4(h&gxt=5J{I zF?oM8=yghplZ4aokFV|R=b`aqQ4!kziBQJh51G()AU`oy%D$e%xdnm`3`7jQ*a<=UV_ zzWISHpt734bA=Yy?+)IDBYPa~DyL%1Uc!z#iFzA)Hnzpc+j`o4_3)0# z^bmG5j=Th8V7Bu869n=-+P45{R4-nlZ4}7&(oJLCEBHDSZ&a_0~@M#u{oSWYM0oo$U+0> zZookpp%B(ae?T#PQQmKN!*=qyK#|CNU!eoKIuDmWQig(f*2?pX2z4VWn4s2(w&YR^ zlb~)iKQs+>dMS|Zh63n%3&a>`?1mii_hG+JqKC=QA3J6P@RkKaL(S{dRA<>cjlncc zaIDZh5R@{b$Mc8+x@q3>1dngi5I~{|S;@<%PpME9Eea#}cXXaAt{3NFHN-|APSW5- zC0E5a;P@S8SbA}g*M`Ojnhh#6KEpzdwt04_qPPp_LN0#(?{`7vD(biWaLDw*?x0ko z$V=I!%Er6kVbMvFu)}?x1yh#VW>y_fkqVDo74Ga9d1gQzfKDmrXBemFWGG}3E-d{} z)ce0b3M>H58K}96BFUK)AO12_PJ#XdpscUe38V+eHe?&+XQ5us#s(;^rWi1{bKPnT zs9D(xs1bXpPa?bdpkCJ0(CC6FFg~&lyXUN=q#=lOJl}~SIoCmHd3FX4nCS$(HnFpb zqEnD_c~ArV zeF(Pq@*;j4rdq1VXy{|JC7Z admO2=?KGE7w$8(|5e*gX%SD%NJo;bl#D6*f literal 0 HcmV?d00001 diff --git a/results.html b/results.html new file mode 100644 index 000000000000..e0ab06e853b4 --- /dev/null +++ b/results.html @@ -0,0 +1,3 @@ + + + From ce4c4e82f1f46685bb36d4400d5be63c06dcc429 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 31 Dec 2024 16:25:12 +0000 Subject: [PATCH 03/21] removed all the benchmark files --- moe-gemm-benchmark.csv | 7 ------- moe-gemm-benchmark.png | Bin 20192 -> 0 bytes results.html | 3 --- 3 files changed, 10 deletions(-) delete mode 100644 moe-gemm-benchmark.csv delete mode 100644 moe-gemm-benchmark.png delete mode 100644 results.html diff --git a/moe-gemm-benchmark.csv b/moe-gemm-benchmark.csv deleted file mode 100644 index b37672c102a0..000000000000 --- a/moe-gemm-benchmark.csv +++ /dev/null @@ -1,7 +0,0 @@ -M,K,N,E,top_k,TFLOPS -64.000000,128.000000,256.000000,8.000000,2.000000,0.106786 -64.000000,1024.000000,1792.000000,8.000000,2.000000,5.040689 -1024.000000,4096.000000,7168.000000,8.000000,2.000000,250.097741 
-4096.000000,4096.000000,7168.000000,8.000000,2.000000,394.569816 -1024.000000,4096.000000,14336.000000,8.000000,2.000000,313.915161 -4096.000000,4096.000000,14336.000000,8.000000,2.000000,421.260017 diff --git a/moe-gemm-benchmark.png b/moe-gemm-benchmark.png deleted file mode 100644 index b34d965cf98e9e5ffb216b24c6b7a5e8d8a4cfc1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 20192 [base85-encoded binary data for moe-gemm-benchmark.png omitted]
zqC6emLnNg!ZIWPd)`Ml0c5nNW0#jp**qex&4(h&gxt=5J{I zF?oM8=yghplZ4aokFV|R=b`aqQ4!kziBQJh51G()AU`oy%D$e%xdnm`3`7jQ*a<=UV_ zzWISHpt734bA=Yy?+)IDBYPa~DyL%1Uc!z#iFzA)Hnzpc+j`o4_3)0# z^bmG5j=Th8V7Bu869n=-+P45{R4-nlZ4}7&(oJLCEBHDSZ&a_0~@M#u{oSWYM0oo$U+0> zZookpp%B(ae?T#PQQmKN!*=qyK#|CNU!eoKIuDmWQig(f*2?pX2z4VWn4s2(w&YR^ zlb~)iKQs+>dMS|Zh63n%3&a>`?1mii_hG+JqKC=QA3J6P@RkKaL(S{dRA<>cjlncc zaIDZh5R@{b$Mc8+x@q3>1dngi5I~{|S;@<%PpME9Eea#}cXXaAt{3NFHN-|APSW5- zC0E5a;P@S8SbA}g*M`Ojnhh#6KEpzdwt04_qPPp_LN0#(?{`7vD(biWaLDw*?x0ko z$V=I!%Er6kVbMvFu)}?x1yh#VW>y_fkqVDo74Ga9d1gQzfKDmrXBemFWGG}3E-d{} z)ce0b3M>H58K}96BFUK)AO12_PJ#XdpscUe38V+eHe?&+XQ5us#s(;^rWi1{bKPnT zs9D(xs1bXpPa?bdpkCJ0(CC6FFg~&lyXUN=q#=lOJl}~SIoCmHd3FX4nCS$(HnFpb zqEnD_c~ArV zeF(Pq@*;j4rdq1VXy{|JC7Z admO2=?KGE7w$8(|5e*gX%SD%NJo;bl#D6*f diff --git a/results.html b/results.html deleted file mode 100644 index e0ab06e853b4..000000000000 --- a/results.html +++ /dev/null @@ -1,3 +0,0 @@ - - - From 826cf5d7ac331ace8fe3b4332cbecce25923302a Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 31 Dec 2024 16:30:28 +0000 Subject: [PATCH 04/21] updated readme --- python/perf-kernels/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index 91283a129b0c..8df2830f958c 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -101,4 +101,4 @@ Kernel that implements RMS Norm over a row of tensor. Kernel that implements Layer Normalization over a row on tensor ## `fused_moe/moe-gemm.py` -Kernel that implements moe gemm +Kernel that implements moe gemm. You can tune the gemm config in the benchmark with the -tune option From c06c61ec41840271125841a12dc49d435c6da77b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 3 Jan 2025 16:29:14 +0000 Subject: [PATCH 05/21] remove the -tune option, and consolidated the config files --- python/perf-kernels/README.md | 2 +- .../perf-kernels/fused_moe/benchmark_utils.py | 211 ------------------ ...14336,device_name=AMD_Instinct_MI300X.json | 200 ----------------- ...=1792,device_name=AMD_Instinct_MI300X.json | 200 ----------------- ...=3584,device_name=AMD_Instinct_MI300X.json | 200 ----------------- ...=7168,device_name=AMD_Instinct_MI300X.json | 200 ----------------- .../device_name=AMD_Instinct_MI300X.json | 24 ++ python/perf-kernels/fused_moe/moe-gemm.py | 75 +++---- 8 files changed, 58 insertions(+), 1054 deletions(-) delete mode 100644 python/perf-kernels/fused_moe/benchmark_utils.py delete mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json delete mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json delete mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json delete mode 100644 python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json create mode 100644 python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json diff --git a/python/perf-kernels/README.md b/python/perf-kernels/README.md index 8df2830f958c..1dbe1ddf2e58 100644 --- a/python/perf-kernels/README.md +++ b/python/perf-kernels/README.md @@ -101,4 +101,4 @@ Kernel that implements RMS Norm over a row of tensor. Kernel that implements Layer Normalization over a row on tensor ## `fused_moe/moe-gemm.py` -Kernel that implements moe gemm. You can tune the gemm config in the benchmark with the -tune option +Kernel that implements moe gemm. 
diff --git a/python/perf-kernels/fused_moe/benchmark_utils.py b/python/perf-kernels/fused_moe/benchmark_utils.py deleted file mode 100644 index 907e5c836b86..000000000000 --- a/python/perf-kernels/fused_moe/benchmark_utils.py +++ /dev/null @@ -1,211 +0,0 @@ -from typing import TypedDict, List, Optional -from itertools import product -import json -import torch -import os - - -def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, - use_fp8_w8a8: Optional[bool] = False): - if use_fp8_w8a8: - return "fp8_w8a8" - elif use_int8_w8a16: - return "int8_w8a16" - elif dtype == torch.float: - # avoiding cases where kernel fails when float32 MoE - # use fp16/bfloat16 configs - return "float32" - return None - - -def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: - device_name = torch.cuda.get_device_name(0).replace(" ", "_") - dtype_selector = "" if not dtype else f",dtype={dtype}" - return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" - - -class BenchmarkConfig(TypedDict): - BLOCK_SIZE_M: int - BLOCK_SIZE_N: int - BLOCK_SIZE_K: int - GROUP_SIZE_M: int - num_warps: int - num_stages: int - - -def need_split_k(SIZE_M, SIZE_N, SIZE_K): - return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 - - -def get_tuning_space(use_fp16): - block_mn_range = [16, 32, 64, 128, 256] - block_k_range = [16, 32, 64, 128, 256] - if not use_fp16: - block_k_range.remove(16) # BLOCK_K=16 not supported for fp8 - num_warps_range = [1, 2, 4, 8] - group_m_range = [1, 4, 8, 16, 32] - num_stage_range = [1, 2] - waves_per_eu_range = [0] - matrix_instr_nonkdim_range = [16, 32] if use_fp16 else [] - kpack_range = [1, 2] if use_fp16 else [] - - param_ranges = { - "BLOCK_SIZE_M": block_mn_range, - "BLOCK_SIZE_N": block_mn_range, - "BLOCK_SIZE_K": block_k_range, - "GROUP_SIZE_M": group_m_range, - "num_warps": num_warps_range, - "num_stages": num_stage_range, - "waves_per_eu": waves_per_eu_range, - } - if use_fp16: - param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range - param_ranges["kpack"] = kpack_range - - return param_ranges - - -def prune_configs(M, N, K, configs, is_fp16=True): - pruned_configs = [] - elemBytes_a = 2 if is_fp16 else 1 - elemBytes_b = 2 if is_fp16 else 1 - - mfma = 16 if M < 32 or N < 32 else 32 - - # TODO (zhanglx): figure out the boundary between large and small gemms - large_gemm = False - if M >= 2048 and N >= 2048: - large_gemm = True - - for config in configs: - BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") - BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") - BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") - num_warps = config.get("num_warps") - - if is_fp16: - matrix_instr_nonkdim = config.get("matrix_instr_nonkdim") - if matrix_instr_nonkdim > mfma: - continue - if mfma == 4 and BLOCK_SIZE_K < 64: - continue - # some layouts could not work properly in case - # number elements per thread is less 1 - if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: - continue - SPLIT_K = config.get("SPLIT_K", 1) - GROUP_M = config.get("GROUP_SIZE_M") - if is_fp16: - if (matrix_instr_nonkdim > BLOCK_SIZE_M or matrix_instr_nonkdim > BLOCK_SIZE_N): - continue - if (matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M): - continue - if (matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N): - continue - # Skip BLOCK_SIZE that is too large compare to M/N - # unless BLOCK_SIZE is already small enough - if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: - continue - if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: - continue - # skip large split_k when not 
necessary - if SPLIT_K != 1 and not need_split_k(M, N, K): - continue - # skip split_k that leads to EVEN_K = false - leap = SPLIT_K * BLOCK_SIZE_K - modv = K % leap - if modv != 0: - continue - # skip large GROUP_M - if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: - continue - # out of shared memory resource - # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) - if LDS > 65536: - continue - # Skip small block sizes and num_warps for large gemm - # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 - if large_gemm: - if BLOCK_SIZE_M < 128 or BLOCK_SIZE_N < 128: - continue - if BLOCK_SIZE_K < 64: - continue - if num_warps < 4: - continue - - pruned_configs.append(config) - - return pruned_configs - - -def merge_unique_dicts(list1, list2): - result = [] - combined_list = list1.copy() - combined_list.extend(list2) - for dictionary in combined_list: - if dictionary not in result: - result.append(dictionary) - return result - - -def prune_search_space(num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16): - N1, K1 = shard_intermediate_size, hidden_size - - pruned_space_1 = prune_configs(num_tokens * 2, N1, K1, search_space, is_fp16) - # NOTE, we are only tunning thr gemm here so only one pass of moe - # pruned_space_2 = prune_configs(num_tokens * 8, N2, K2, search_space, - # is_fp16) - # search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) - return pruned_space_1 - - -def update_configs(M: int, config: BenchmarkConfig, num_experts: int, shard_intermediate_size: int, hidden_size: int, - topk: int, dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: - dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) - - # NOTE(woosuk): The current naming convention uses w2.shape[2], which - # is the intermediate size after silu_and_mul. 
- # NOTE, we are only tunning thr gemm here so no // 2 - # filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - # dtype_str) - - filename = get_config_file_name(num_experts, shard_intermediate_size, dtype_str) - print(f"Best config: {config}") - print(f"Writing best config to {filename}...") - - config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", filename) - # 1) Read the existing JSON file if it exists - old_configs = {} - if os.path.isfile(config_file_path): - with open(config_file_path, "r") as f: - try: - old_configs = json.load(f) - except json.JSONDecodeError: - # If the file is empty or corrupt, we just ignore it - old_configs = {} - - # 2) Update existing data with new configs - # If they share any keys, the new 'configs' will overwrite - old_configs[str(M)] = config - # old_configs[configs.keys()[0]] = configs[configs.keys()[0]] - - # 3) Write back to the same file - with open(config_file_path, "w") as f: - json.dump(old_configs, f, indent=2) - f.write("\n") - - -def get_tuning_configs(M, N, K, use_fp16): - param_ranges = get_tuning_space(use_fp16) - configs: List[BenchmarkConfig] = [] - - keys, values = zip(*param_ranges.items()) - for config_values in product(*values): - config = dict(zip(keys, config_values)) - configs.append(config) - - configs = prune_search_space(num_tokens=M, shard_intermediate_size=N, hidden_size=K, search_space=configs, - is_fp16=use_fp16) - - return configs diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json deleted file mode 100644 index 2cc0f41254eb..000000000000 --- a/python/perf-kernels/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +++ /dev/null @@ -1,200 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 1, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 1, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, 
- "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "1024": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "4096": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } -} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json deleted file mode 100644 index 2d799bc0f4e9..000000000000 --- a/python/perf-kernels/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +++ /dev/null @@ -1,200 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "16": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 
16, - "kpack": 1 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - } -} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json deleted file mode 100644 index 89ece30f9c15..000000000000 --- a/python/perf-kernels/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +++ /dev/null @@ -1,200 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - 
"GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "24": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "512": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, - "kpack": 2 - }, - "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - } -} diff --git a/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json deleted file mode 100644 index 20529c52ee2d..000000000000 --- a/python/perf-kernels/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +++ /dev/null @@ -1,200 +0,0 @@ -{ - "1": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 
16, - "kpack": 2 - }, - "2": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "4": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "8": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "16": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "24": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "32": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "48": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 2, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "64": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "96": { - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 32, - "kpack": 2 - }, - "512": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "1024": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "1536": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "2048": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 1 - }, - "3072": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - }, - "4096": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 1, - "waves_per_eu": 0, - "matrix_instr_nonkdim": 16, - "kpack": 2 - } -} diff --git 
a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json new file mode 100644 index 000000000000..634617547730 --- /dev/null +++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json @@ -0,0 +1,24 @@ +{ + "small_M": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 2, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "large_M": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} \ No newline at end of file diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 44ac71d17c8e..00914563709c 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -8,8 +8,8 @@ import functools import argparse import sys -from benchmark_utils import get_tuning_configs, get_config_file_name, update_configs +M_THRESHOLD = 1024 @triton.jit def moe_gemm_kernel( @@ -194,8 +194,27 @@ def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, return sorted_ids, expert_ids, num_tokens_post_pad +def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def get_config_file_name(dtype: Optional[str]) -> str: + device_name = torch.cuda.get_device_name(0).replace(" ", "_") + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"device_name={device_name}{dtype_selector}.json" + + @functools.lru_cache -def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]: +def get_moe_configs(dtype: Optional[str]) -> Optional[Dict[int, Any]]: """ Return optimized configurations for the fused MoE kernel. 
@@ -206,13 +225,13 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, """ # First look up if an optimized configuration is available in the configs # directory - json_file_name = get_config_file_name(E, N, dtype) + json_file_name = get_config_file_name(dtype) config_file_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + return {key: val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration @@ -222,10 +241,6 @@ def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, def get_default_config( M: int, E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], is_marlin: bool, ) -> Dict[str, int]: config = {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8} @@ -236,22 +251,20 @@ def get_default_config( def try_get_optimal_moe_config( - b_shape: Tuple[int, ...], - top_k: int, + E: int, dtype: Optional[str], M: int, is_marlin: bool = False, ): - E, N, K = b_shape - configs = get_moe_configs(E, N, dtype) + configs = get_moe_configs(dtype) if configs: # If an optimal configuration map has been found, look up the # optimal config - config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + config = configs["small_M" if M < M_THRESHOLD else "large_M"] else: # Else use the default config - config = get_default_config(M, E, N, K, top_k, dtype, is_marlin) + config = get_default_config(M, E, is_marlin) return config @@ -285,8 +298,7 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool config_dtype = None get_config_func = functools.partial( try_get_optimal_moe_config, - b.shape, - topk_ids.shape[1], + E, config_dtype, ) config = get_config_func(M) @@ -347,7 +359,6 @@ def run_benchmark(custom, args): routed_weight = args.routed_weight dtype = arg_to_torch_dtype[args.dtype] use_fp16 = args.dtype == 'fp16' - tune = args.tune if custom: assert args.M and args.K and args.N and args.E and args.top_k, \ "Please provide M, K, N, E, top_k for custom runs." 
@@ -363,11 +374,11 @@ def run_benchmark(custom, args): benchmark = triton.testing.Benchmark( x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark', args={ - 'dtype': dtype, 'use_fp16': use_fp16, 'tune': tune, 'print_time': print_time, 'routed_weight': routed_weight + 'dtype': dtype, 'use_fp16': use_fp16, 'print_time': print_time, 'routed_weight': routed_weight }) @triton.testing.perf_report([benchmark]) - def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, tune, print_time, provider): + def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time, provider): a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) @@ -375,28 +386,9 @@ def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, tune, prin if routed_weight: flops += M * top_k * N - if tune: - configs = get_tuning_configs(M, N, K, use_fp16) - print(f"Tuning start with {len(configs)} configs") - - min_ms = None - best_config = None - for config in configs: - print(config) - fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, - num_tokens_post_padded, config) - ms = triton.testing.do_bench(fn) - print(ms) - c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda') - if min_ms is None or ms < min_ms: - min_ms = ms - best_config = config - - update_configs(M, best_config, E, N, K, top_k, dtype, False, False) - else: - fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, - config) - ms = triton.testing.do_bench(fn) + fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, + config) + ms = triton.testing.do_bench(fn) if print_time: return ms @@ -418,7 +410,6 @@ def parse_args(): parser.add_argument("-E", type=int, default=0, help="Number of experts") parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token") parser.add_argument("-routed_weight", action='store_true', default=False) - parser.add_argument("-tune", action='store_true', default=False) parser.add_argument("-dtype", default='fp16') parser.add_argument("-return_time", action='store_true', default=False, help='Return time instead of TFLOPs') args = parser.parse_args() From 34130b9b3395772c29cc0f5fb5ef5ec8320895ca Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 3 Jan 2025 16:34:49 +0000 Subject: [PATCH 06/21] pre commit --- .../configs/device_name=AMD_Instinct_MI300X.json | 2 +- python/perf-kernels/fused_moe/moe-gemm.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json index 634617547730..02aa8f69c993 100644 --- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json +++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json @@ -21,4 +21,4 @@ "matrix_instr_nonkdim": 16, "kpack": 2 } -} \ No newline at end of file +} diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 00914563709c..3168a740cd4b 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -2,7 +2,7 @@ import torch import triton.language 
as tl import pytest -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional import os import json import functools @@ -11,6 +11,7 @@ M_THRESHOLD = 1024 + @triton.jit def moe_gemm_kernel( A, @@ -373,9 +374,8 @@ def run_benchmark(custom, args): benchmark = triton.testing.Benchmark( x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], - styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark', args={ - 'dtype': dtype, 'use_fp16': use_fp16, 'print_time': print_time, 'routed_weight': routed_weight - }) + styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark', + args={'dtype': dtype, 'use_fp16': use_fp16, 'print_time': print_time, 'routed_weight': routed_weight}) @triton.testing.perf_report([benchmark]) def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time, provider): @@ -387,7 +387,7 @@ def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time flops += M * top_k * N fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, - config) + config) ms = triton.testing.do_bench(fn) if print_time: From dc68ce1f2bf3be279240bc3d4dfa79b0448a568c Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 7 Jan 2025 13:20:49 +0000 Subject: [PATCH 07/21] updated M_THRESHOLD and configs after tunnig --- .../configs/device_name=AMD_Instinct_MI300X.json | 10 +++++----- python/perf-kernels/fused_moe/moe-gemm.py | 6 +++++- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json index 02aa8f69c993..1e5bfd190664 100644 --- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json +++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json @@ -1,15 +1,15 @@ { "small_M": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 2, + "GROUP_SIZE_M": 4, + "num_warps": 8, "num_stages": 2, "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 - }, +}, "large_M": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 3168a740cd4b..a7a61942cfc5 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -9,7 +9,7 @@ import argparse import sys -M_THRESHOLD = 1024 +M_THRESHOLD = 128 @triton.jit @@ -347,8 +347,12 @@ def get_configs(): configs = [ {"M": 64, "K": 128, "N": 256, "E": 8, "top_k": 2}, {"M": 64, "K": 1024, "N": 1792, "E": 8, "top_k": 2}, + {"M": 64, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, + {"M": 128, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, {"M": 1024, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, {"M": 4096, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, + {"M": 64, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 128, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, {"M": 1024, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, {"M": 4096, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, ] From 53fa702c76d940318a939c384af5c47c92bc6aad Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 7 Jan 2025 14:49:32 +0000 Subject: [PATCH 08/21] mistral model benchmarking --- python/perf-kernels/fused_moe/moe-gemm.py | 43 +++++++++++++++++++---- python/perf-kernels/model_configs.json | 
17 +++++++++ python/perf-kernels/utils/__init__.py | 0 3 files changed, 54 insertions(+), 6 deletions(-) create mode 100644 python/perf-kernels/utils/__init__.py diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index a7a61942cfc5..7d1887968d91 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -8,6 +8,12 @@ import functools import argparse import sys +SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) +PARENT_DIR = os.path.dirname(SCRIPT_DIR) # This goes one level up from fused-moe/ +if PARENT_DIR not in sys.path: + sys.path.append(PARENT_DIR) + +from utils.benchmark_utils import get_available_models, get_model_configs M_THRESHOLD = 128 @@ -359,20 +365,40 @@ def get_configs(): return configs +def model_benchmark_configs(args): + config_file = args.model_configs + configs = get_model_configs(config_path=config_file, model_families=["mistral"], model=args.model) + fa_configs = [] + M = args.M if args.M else 1024 # check size + # M, K, N, E, top_k + + for model_name, config in configs.items(): + N = config["intermediate_size"] + K = config["hidden_size"] + + E = 8 + top_k = 2 + fa_configs.append((model_name, M, K, N, E, top_k)) + + return fa_configs + def run_benchmark(custom, args): print_time = args.return_time routed_weight = args.routed_weight dtype = arg_to_torch_dtype[args.dtype] use_fp16 = args.dtype == 'fp16' + x_names = ['M', 'K', 'N', 'E', 'top_k'] if custom: assert args.M and args.K and args.N and args.E and args.top_k, \ "Please provide M, K, N, E, top_k for custom runs." - configs = [{"M": args.M, "K": args.K, "N": args.N, "E": args.E, "top_k": args.top_k}] + x_vals_list = [(args.M, args.K, args.N, args.E, args.top_k)] else: - configs = get_configs() - - x_names = ['M', 'K', 'N', 'E', 'top_k'] - x_vals_list = [(cfg['M'], cfg['K'], cfg['N'], cfg['E'], cfg['top_k']) for cfg in configs] + if args.model: + x_vals_list = model_benchmark_configs(args) + x_names = ['model', 'M', 'K', 'N', 'E', 'top_k'] + else: + configs = get_configs() + x_vals_list = [(cfg['M'], cfg['K'], cfg['N'], cfg['E'], cfg['top_k']) for cfg in configs] line_names = 'Time (ms)' if print_time else 'TFLOPS' @@ -382,7 +408,7 @@ def run_benchmark(custom, args): args={'dtype': dtype, 'use_fp16': use_fp16, 'print_time': print_time, 'routed_weight': routed_weight}) @triton.testing.perf_report([benchmark]) - def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time, provider): + def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time, provider, model=None): a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) @@ -408,6 +434,11 @@ def parse_args(): prog="Benchmark MoE GEMM", allow_abbrev=False, ) + parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.") + available_models = get_available_models(model_families=["mistral"]) # Dynamically load model names + model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) + + "]. 
Use 'all' to benchmark all models or leave blank for the default benchmark script.") + parser.add_argument('-model', type=str, default=None, help=model_help) parser.add_argument("-M", type=int, default=0, help="M dimension") parser.add_argument("-K", type=int, default=0, help="K dimension") parser.add_argument("-N", type=int, default=0, help="N dimension") diff --git a/python/perf-kernels/model_configs.json b/python/perf-kernels/model_configs.json index 52c44b52478b..1839cc5830ac 100644 --- a/python/perf-kernels/model_configs.json +++ b/python/perf-kernels/model_configs.json @@ -24,5 +24,22 @@ "intermediate_size": 53248, "vocab_size": 128256 } + }, + "mistral": { + "7B": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "vocab_size": 32000 + } + , + "22B": { + "hidden_size": 6144, + "intermediate_size": 16384, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "vocab_size": 32000 + } } } diff --git a/python/perf-kernels/utils/__init__.py b/python/perf-kernels/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 From 1698d79a744023960b3a8229e5817801e4ed3176 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 7 Jan 2025 14:50:32 +0000 Subject: [PATCH 09/21] pre commit --- .../fused_moe/configs/device_name=AMD_Instinct_MI300X.json | 2 +- python/perf-kernels/fused_moe/moe-gemm.py | 6 ++++-- python/perf-kernels/model_configs.json | 3 +-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json index 1e5bfd190664..c4c9fc045e4a 100644 --- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json +++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json @@ -9,7 +9,7 @@ "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 2 -}, + }, "large_M": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 7d1887968d91..7a0270bbc951 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -8,6 +8,7 @@ import functools import argparse import sys + SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) PARENT_DIR = os.path.dirname(SCRIPT_DIR) # This goes one level up from fused-moe/ if PARENT_DIR not in sys.path: @@ -369,7 +370,7 @@ def model_benchmark_configs(args): config_file = args.model_configs configs = get_model_configs(config_path=config_file, model_families=["mistral"], model=args.model) fa_configs = [] - M = args.M if args.M else 1024 # check size + M = args.M if args.M else 1024 # check size # M, K, N, E, top_k for model_name, config in configs.items(): @@ -382,6 +383,7 @@ def model_benchmark_configs(args): return fa_configs + def run_benchmark(custom, args): print_time = args.return_time routed_weight = args.routed_weight @@ -437,7 +439,7 @@ def parse_args(): parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.") available_models = get_available_models(model_families=["mistral"]) # Dynamically load model names model_help = ("Model name to benchmark. Select from: [" + ", ".join(available_models) + - "]. Use 'all' to benchmark all models or leave blank for the default benchmark script.") + "]. 
Use 'all' to benchmark all models or leave blank for the default benchmark script.") parser.add_argument('-model', type=str, default=None, help=model_help) parser.add_argument("-M", type=int, default=0, help="M dimension") parser.add_argument("-K", type=int, default=0, help="K dimension") diff --git a/python/perf-kernels/model_configs.json b/python/perf-kernels/model_configs.json index 1839cc5830ac..c6b6fd0bdedb 100644 --- a/python/perf-kernels/model_configs.json +++ b/python/perf-kernels/model_configs.json @@ -32,8 +32,7 @@ "num_attention_heads": 32, "num_key_value_heads": 8, "vocab_size": 32000 - } - , + }, "22B": { "hidden_size": 6144, "intermediate_size": 16384, From 42d8dbc6e8f3e26e54bea2390ed49dcb0bdd8b4b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 7 Jan 2025 14:53:35 +0000 Subject: [PATCH 10/21] noqa: E402 --- python/perf-kernels/fused_moe/moe-gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 7a0270bbc951..3b179c8df912 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -14,7 +14,7 @@ if PARENT_DIR not in sys.path: sys.path.append(PARENT_DIR) -from utils.benchmark_utils import get_available_models, get_model_configs +from utils.benchmark_utils import get_available_models, get_model_configs # noqa: E402 M_THRESHOLD = 128 From 3a04e9036d7c8ba8ebfb20ff2371afb4347678e5 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Tue, 7 Jan 2025 14:54:02 +0000 Subject: [PATCH 11/21] pre commit --- python/perf-kernels/fused_moe/moe-gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 3b179c8df912..35f50abd56ce 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -14,7 +14,7 @@ if PARENT_DIR not in sys.path: sys.path.append(PARENT_DIR) -from utils.benchmark_utils import get_available_models, get_model_configs # noqa: E402 +from utils.benchmark_utils import get_available_models, get_model_configs # noqa: E402 M_THRESHOLD = 128 From 9c83fd73a7a1e285c531914320080e2c5d429378 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 8 Jan 2025 12:12:23 +0000 Subject: [PATCH 12/21] more fine tuned model config. 
show mem throught put in benchmark --- .../device_name=AMD_Instinct_MI300X.json | 17 +++++- python/perf-kernels/fused_moe/moe-gemm.py | 61 ++++++++++++++----- 2 files changed, 61 insertions(+), 17 deletions(-) diff --git a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json index c4c9fc045e4a..5955cc32c481 100644 --- a/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json +++ b/python/perf-kernels/fused_moe/configs/device_name=AMD_Instinct_MI300X.json @@ -1,8 +1,8 @@ { "small_M": { - "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 8, "num_stages": 2, @@ -10,9 +10,20 @@ "matrix_instr_nonkdim": 16, "kpack": 2 }, - "large_M": { + "medium_M": { "BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "large_M": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 35f50abd56ce..0edd66df63d8 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -16,7 +16,8 @@ from utils.benchmark_utils import get_available_models, get_model_configs # noqa: E402 -M_THRESHOLD = 128 +M_THRESHOLD_SMALL = 256 +M_THRESHOLD_MEDIUM = 1024 @triton.jit @@ -267,9 +268,13 @@ def try_get_optimal_moe_config( configs = get_moe_configs(dtype) if configs: - # If an optimal configuration map has been found, look up the - # optimal config - config = configs["small_M" if M < M_THRESHOLD else "large_M"] + if configs: + if M < M_THRESHOLD_SMALL: + config = configs["small_M"] + elif M < M_THRESHOLD_MEDIUM: + config = configs["medium_M"] + else: + config = configs["large_M"] else: # Else use the default config config = get_default_config(M, E, is_marlin) @@ -360,7 +365,10 @@ def get_configs(): {"M": 4096, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, {"M": 64, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, {"M": 128, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 256, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 512, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, {"M": 1024, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 2048, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, {"M": 4096, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, ] return configs @@ -388,7 +396,6 @@ def run_benchmark(custom, args): print_time = args.return_time routed_weight = args.routed_weight dtype = arg_to_torch_dtype[args.dtype] - use_fp16 = args.dtype == 'fp16' x_names = ['M', 'K', 'N', 'E', 'top_k'] if custom: assert args.M and args.K and args.N and args.E and args.top_k, \ @@ -402,15 +409,30 @@ def run_benchmark(custom, args): configs = get_configs() x_vals_list = [(cfg['M'], cfg['K'], cfg['N'], cfg['E'], cfg['top_k']) for cfg in configs] - line_names = 'Time (ms)' if print_time else 'TFLOPS' + line_names = ['Time (ms)', 'Bandwidth (GB/s)'] if print_time else ['TFLOPS', 'Bandwidth (GB/s)'] - benchmark = triton.testing.Benchmark( - x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'], line_names=[line_names], - styles=[('red', '-'), ('blue', '-')], ylabel='ms', plot_name='moe-gemm-benchmark', - args={'dtype': dtype, 'use_fp16': use_fp16, 'print_time': print_time, 'routed_weight': 
routed_weight}) + if print_time: + # We'll have 2 lines: 'time' and 'bandwidth' + line_vals = ['time', 'bandwidth'] + line_names = ['Time (ms)', 'Bandwidth (GB/s)'] + else: + line_vals = ['tflops', 'bandwidth'] + line_names = ['TFLOPS', 'Bandwidth (GB/s)'] + benchmark = triton.testing.Benchmark( + x_names=x_names, + x_vals=x_vals_list, + line_arg='metric', # <--- important + line_vals=line_vals, # <--- a list of 2 metrics + line_names=line_names, # <--- matching 2 metrics + styles=[('red', '-'), ('blue', '-')], + ylabel='ms / TFLOPS / GB/s', # or a more generic label + plot_name='moe-gemm-benchmark', + args={'dtype': dtype, 'routed_weight': routed_weight} + ) @triton.testing.perf_report([benchmark]) - def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time, provider, model=None): + def bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, metric, model=None): + # metric will be either 'time'/'tflops' or 'bandwidth' a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) @@ -418,15 +440,26 @@ def bench_moe_gemm(M, K, N, E, top_k, dtype, use_fp16, routed_weight, print_time if routed_weight: flops += M * top_k * N + bytes_ = torch.tensor([], dtype=dtype).element_size() + mem_read = (M * K + top_k * N * K) * bytes_ + mem_write = (M * top_k * N) * bytes_ + mem = mem_read + mem_write fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config) ms = triton.testing.do_bench(fn) - if print_time: + bandwidth = mem / (ms * 1e-3) * 1e-9 # GB/s + tflops = flops / ms * 1e-9 + + # Return exactly one scalar depending on which metric is active + if metric == 'time': return ms + elif metric == 'tflops': + return tflops + elif metric == 'bandwidth': + return bandwidth else: - # Convert flops to TFLOPs - return flops / ms * 1e-9 + raise ValueError("Unknown metric: " + metric) bench_moe_gemm.run(save_path=".", print_data=True) From e9d3dc28794bf38b72a360cbf1b7a9c5fd4a3f4c Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 8 Jan 2025 12:12:50 +0000 Subject: [PATCH 13/21] pre-commit --- python/perf-kernels/fused_moe/moe-gemm.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 0edd66df63d8..bab9cb3e2f36 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -419,17 +419,14 @@ def run_benchmark(custom, args): line_vals = ['tflops', 'bandwidth'] line_names = ['TFLOPS', 'Bandwidth (GB/s)'] - benchmark = triton.testing.Benchmark( - x_names=x_names, - x_vals=x_vals_list, - line_arg='metric', # <--- important - line_vals=line_vals, # <--- a list of 2 metrics - line_names=line_names, # <--- matching 2 metrics - styles=[('red', '-'), ('blue', '-')], - ylabel='ms / TFLOPS / GB/s', # or a more generic label - plot_name='moe-gemm-benchmark', - args={'dtype': dtype, 'routed_weight': routed_weight} - ) + benchmark = triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='metric', # <--- important + line_vals=line_vals, # <--- a list of 2 metrics + line_names=line_names, # <--- matching 2 metrics + styles=[('red', '-'), + ('blue', '-')], ylabel='ms / TFLOPS / GB/s', # or a more generic label + plot_name='moe-gemm-benchmark', + args={'dtype': dtype, 'routed_weight': routed_weight}) + @triton.testing.perf_report([benchmark]) def 
bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, metric, model=None): # metric will be either 'time'/'tflops' or 'bandwidth' From ad2daad1e0ce4e16f5c0f07c3eb3fe020bcbdc6f Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Wed, 8 Jan 2025 14:32:27 +0000 Subject: [PATCH 14/21] fixed bandwidth computation --- python/perf-kernels/fused_moe/moe-gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index bab9cb3e2f36..12263c493af0 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -438,7 +438,7 @@ def bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, metric, model=None): flops += M * top_k * N bytes_ = torch.tensor([], dtype=dtype).element_size() - mem_read = (M * K + top_k * N * K) * bytes_ + mem_read = (M * K + E * N * K) * bytes_ mem_write = (M * top_k * N) * bytes_ mem = mem_read + mem_write fn = lambda: moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, From 30a488c73ee2438faaabbd64ec5f9c4145e50bf8 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 09:09:49 +0000 Subject: [PATCH 15/21] First and second gemm odel benchmarking --- python/perf-kernels/fused_moe/moe-gemm.py | 30 +++++++++++++---------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 12263c493af0..1e9eb8920859 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -298,7 +298,7 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to return c -def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype): +def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype): a = torch.randn((M, K), dtype=dtype, device='cuda') b = torch.randn((E, N, K), dtype=dtype, device='cuda') c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda') @@ -378,16 +378,20 @@ def model_benchmark_configs(args): config_file = args.model_configs configs = get_model_configs(config_path=config_file, model_families=["mistral"], model=args.model) fa_configs = [] - M = args.M if args.M else 1024 # check size + M = args.M if args.M else 4096 # check size # M, K, N, E, top_k for model_name, config in configs.items(): - N = config["intermediate_size"] - K = config["hidden_size"] + N1 = config["intermediate_size"] * 2 + K1 = config["hidden_size"] + + N2 = config["hidden_size"] + K2 = config["intermediate_size"] E = 8 top_k = 2 - fa_configs.append((model_name, M, K, N, E, top_k)) + fa_configs.append((model_name, M, N1, K1, E, top_k)) + fa_configs.append((model_name, M, N2, K2, E, top_k)) return fa_configs @@ -396,18 +400,18 @@ def run_benchmark(custom, args): print_time = args.return_time routed_weight = args.routed_weight dtype = arg_to_torch_dtype[args.dtype] - x_names = ['M', 'K', 'N', 'E', 'top_k'] + x_names = ['M', 'N', 'K', 'E', 'top_k'] if custom: - assert args.M and args.K and args.N and args.E and args.top_k, \ - "Please provide M, K, N, E, top_k for custom runs." - x_vals_list = [(args.M, args.K, args.N, args.E, args.top_k)] + assert args.M and args.N and args.K and args.E and args.top_k, \ + "Please provide M, N, K, E, top_k for custom runs." 
+ x_vals_list = [(args.M, args.N, args.K, args.E, args.top_k)] else: if args.model: x_vals_list = model_benchmark_configs(args) - x_names = ['model', 'M', 'K', 'N', 'E', 'top_k'] + x_names = ['model', 'M', 'N', 'K', 'E', 'top_k'] else: configs = get_configs() - x_vals_list = [(cfg['M'], cfg['K'], cfg['N'], cfg['E'], cfg['top_k']) for cfg in configs] + x_vals_list = [(cfg['M'], cfg['N'], cfg['K'], cfg['E'], cfg['top_k']) for cfg in configs] line_names = ['Time (ms)', 'Bandwidth (GB/s)'] if print_time else ['TFLOPS', 'Bandwidth (GB/s)'] @@ -428,10 +432,10 @@ def run_benchmark(custom, args): args={'dtype': dtype, 'routed_weight': routed_weight}) @triton.testing.perf_report([benchmark]) - def bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, metric, model=None): + def bench_moe_gemm(M, N, K, E, top_k, dtype, routed_weight, metric, model=None): # metric will be either 'time'/'tflops' or 'bandwidth' a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( - M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) + M, N, K, top_k, E, routed_weight=routed_weight, dtype=dtype) flops = 2.0 * M * top_k * K * N if routed_weight: From 39eca098a5ebd5ec85185fad3bdc82340784e430 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:10:17 +0000 Subject: [PATCH 16/21] reversed k n --- python/perf-kernels/fused_moe/moe-gemm.py | 48 +++++++++++------------ 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 1e9eb8920859..1f75cbafa08a 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -324,21 +324,21 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool @pytest.mark.parametrize("M, K, N, top_k, E", [ - (64, 4096, 14336, 2, 8), - (16, 1, 14336, 2, 4), - (1, 128, 14336, 2, 4), - (16, 128, 14336, 1, 4), - (16, 128, 14336, 1, 1), - (64, 128, 7186, 2, 8), - (64, 128, 3584, 2, 8), - (64, 128, 1792, 2, 8), - (64, 128, 64, 2, 8), + (64, 14336, 4096, 2, 8), + (16, 14336, 1, 2, 4), + (1, 14336, 128, 2, 4), + (16, 14336, 128, 1, 4), + (16, 14336, 128, 1, 1), + (64, 7186, 128, 2, 8), + (64, 3584, 128, 2, 8), + (64, 1792, 128, 2, 8), + (64, 64, 128, 2, 8), ]) @pytest.mark.parametrize('routed_weight', [True, False]) -def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16): +def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16): torch.manual_seed(20) a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper( - M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype) + M, N, K, top_k, E, routed_weight=routed_weight, dtype=dtype) tri_out = moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config) @@ -357,19 +357,19 @@ def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: def get_configs(): configs = [ - {"M": 64, "K": 128, "N": 256, "E": 8, "top_k": 2}, - {"M": 64, "K": 1024, "N": 1792, "E": 8, "top_k": 2}, - {"M": 64, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, - {"M": 128, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, - {"M": 1024, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, - {"M": 4096, "K": 4096, "N": 7168, "E": 8, "top_k": 2}, - {"M": 64, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, - {"M": 128, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, 
- {"M": 256, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, - {"M": 512, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, - {"M": 1024, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, - {"M": 2048, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, - {"M": 4096, "K": 4096, "N": 14336, "E": 8, "top_k": 2}, + {"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, + {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, + {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 2048, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "K": 14336, "N": 4096, "E": 8, "top_k": 2} ] return configs From 0da016e8d44f4958e950dfa00ec10e3ed5a1bb4b Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:14:02 +0000 Subject: [PATCH 17/21] pre commit --- python/perf-kernels/fused_moe/moe-gemm.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index 1f75cbafa08a..f5aab2d9ea86 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -357,16 +357,16 @@ def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: def get_configs(): configs = [ - {"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, - {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, - {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, + {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, + {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, {"M": 1024, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, {"M": 2048, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, {"M": 4096, "K": 14336, "N": 4096, "E": 8, "top_k": 2} From be6520b40ec5faa9723222d478cf4e9d2e14132d Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:23:10 +0000 Subject: [PATCH 18/21] pre-commit fix format --- python/perf-kernels/fused_moe/moe-gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index f5aab2d9ea86..a6115fda1e9e 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -357,7 +357,6 @@ def test_correctness(M: int, N: int, K: int, top_k: int, E: int, 
routed_weight: def get_configs(): configs = [ - {"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, From 54d207dd3fd136c190c16710b8c7e39ae2809793 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:25:17 +0000 Subject: [PATCH 19/21] pre commit fix --- python/perf-kernels/fused_moe/moe-gemm.py | 27 +++++++++++------------ 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index a6115fda1e9e..d78b731aefd7 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -356,20 +356,19 @@ def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: def get_configs(): - configs = [ - {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, - {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 2048, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "K": 14336, "N": 4096, "E": 8, "top_k": 2} - ] + configs = [{"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, + {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, + {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, + {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 2048, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "K": 14336, "N": 4096, "E": 8, "top_k": 2},] return configs From 2fe49dd97ea433fd5af15ded7964e8f14685c42d Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:30:09 +0000 Subject: [PATCH 20/21] pre commit fix --- python/perf-kernels/fused_moe/moe-gemm.py | 29 ++++++++++++----------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index d78b731aefd7..e7d922d2f97b 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -323,7 +323,7 @@ def input_helper(M: int, N: int, K: int, top_k: int, E: int, routed_weight: bool return a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config -@pytest.mark.parametrize("M, K, N, top_k, E", [ +@pytest.mark.parametrize("M, N, K, top_k, E", [ (64, 14336, 4096, 2, 8), (16, 14336, 1, 2, 4), (1, 14336, 128, 2, 4), @@ -356,19 +356,20 @@ def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: def get_configs(): - configs = [{"M": 64, "K": 256, "N": 128, "E": 8, "top_k": 2}, - {"M": 64, "K": 1792, "N": 1024, "E": 8, "top_k": 2}, - {"M": 64, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, 
- {"M": 128, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "K": 7168, "N": 4096, "E": 8, "top_k": 2}, - {"M": 64, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 128, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 256, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 512, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 2048, "K": 14336, "N": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "K": 14336, "N": 4096, "E": 8, "top_k": 2},] + configs = [{"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2}, + {"M": 64, "N": 1792, "K": 1024, "E": 8, "top_k": 2}, + {"M": 64, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 128, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 64, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 128, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 256, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 512, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 2048, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + ] return configs From 695922239d693e7758631d51a42a84654dd62e18 Mon Sep 17 00:00:00 2001 From: Tianxing Wu Date: Fri, 10 Jan 2025 11:31:00 +0000 Subject: [PATCH 21/21] pre commit --- python/perf-kernels/fused_moe/moe-gemm.py | 27 ++++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/python/perf-kernels/fused_moe/moe-gemm.py b/python/perf-kernels/fused_moe/moe-gemm.py index e7d922d2f97b..025521462aad 100644 --- a/python/perf-kernels/fused_moe/moe-gemm.py +++ b/python/perf-kernels/fused_moe/moe-gemm.py @@ -356,19 +356,20 @@ def test_correctness(M: int, N: int, K: int, top_k: int, E: int, routed_weight: def get_configs(): - configs = [{"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2}, - {"M": 64, "N": 1792, "K": 1024, "E": 8, "top_k": 2}, - {"M": 64, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, - {"M": 128, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, - {"M": 64, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 128, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 256, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 512, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 1024, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 2048, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, - {"M": 4096, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + configs = [ + {"M": 64, "N": 256, "K": 128, "E": 8, "top_k": 2}, + {"M": 64, "N": 1792, "K": 1024, "E": 8, "top_k": 2}, + {"M": 64, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 128, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "N": 7168, "K": 4096, "E": 8, "top_k": 2}, + {"M": 64, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 128, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 256, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 512, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 1024, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 2048, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, + {"M": 4096, "N": 14336, "K": 4096, "E": 8, "top_k": 2}, ] return configs