add dtype support, set the default dtype to fp16
Chi-Chu319 committed Dec 23, 2024
1 parent cefc74e commit cd06c06
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions python/perf-kernels/fused_moe/moe-gemm.py
@@ -2,7 +2,7 @@
import torch
import triton.language as tl
import pytest
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, TypedDict
import os
import json
import functools
@@ -101,7 +101,7 @@ def moe_gemm_kernel(
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_ptrs = Out + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(out_ptrs, accumulator, mask=c_mask)
tl.store(out_ptrs, accumulator.to(Out.dtype.element_ty), mask=c_mask)


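The accumulator stays in float32 inside the kernel and is cast to the output tensor's element type only at the store, so fp16/bf16 outputs do not lose accumulation precision. A minimal PyTorch-level sketch of the same accumulate-in-fp32, cast-on-store pattern (illustrative only, not part of this commit):

import torch

a = torch.randn(64, 128, dtype=torch.float16, device='cuda')
b = torch.randn(128, 256, dtype=torch.float16, device='cuda')

# Accumulate the matmul in float32, then cast back to the output dtype,
# mirroring accumulator.to(Out.dtype.element_ty) in the Triton kernel.
out = (a.float() @ b.float()).to(a.dtype)
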
def _moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, top_k: int, block_size: int,
@@ -290,10 +290,10 @@ def moe_gemm(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor, topk_weights: to
return c


def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
a = torch.randn((M, K), dtype=torch.float32, device='cuda')
b = torch.randn((E, N, K), dtype=torch.float32, device='cuda')
c = torch.zeros((M, top_k, N), dtype=torch.float32, device='cuda')
def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype):
a = torch.randn((M, K), dtype=dtype, device='cuda')
b = torch.randn((E, N, K), dtype=dtype, device='cuda')
c = torch.zeros((M, top_k, N), dtype=dtype, device='cuda')

values = torch.randn(M, E, device='cuda')

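The rest of input_helper is collapsed in this view. A common way to turn these per-token router logits into topk_weights and topk_ids, shown here only as a hypothetical sketch of what the collapsed code might do, is a softmax over the experts followed by torch.topk:

# Hypothetical sketch of the collapsed routing step (not shown in this diff):
softmax_vals = torch.softmax(values, dim=1)                         # (M, E) routing probabilities
topk_weights, topk_ids = torch.topk(softmax_vals, k=top_k, dim=1)   # both (M, top_k)
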
@@ -329,10 +329,10 @@ def input_helper(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool
(64, 128, 64, 2, 8),
])
@pytest.mark.parametrize('routed_weight', [True, False])
def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool):
def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight: bool, dtype=torch.float16):
torch.manual_seed(20)
a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper(
M, K, N, top_k, E, routed_weight=routed_weight)
M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype)

# TODO Quantization support
tri_out = moe_gemm(a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config)
@@ -349,6 +349,62 @@ def test_correctness(M: int, K: int, N: int, top_k: int, E: int, routed_weight:
# Validate correctness
torch.testing.assert_close(tri_out, ref_out, atol=1e-2, rtol=1e-2)

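The reference computation that tri_out is compared against is collapsed above. A minimal dense reference for this MoE GEMM, written as a sketch under the shapes produced by input_helper and not as the repository's actual reference code, could look like:

# Hypothetical reference (not the code used by this test): for every token and
# each of its top_k experts, multiply the token by that expert's weight matrix.
def ref_moe_gemm(a, b, topk_weights, topk_ids, routed_weight):
    M, K = a.shape
    E, N, _ = b.shape
    top_k = topk_ids.shape[1]
    out = torch.zeros((M, top_k, N), dtype=torch.float32, device=a.device)
    for m in range(M):
        for j in range(top_k):
            e = topk_ids[m, j]
            out[m, j] = a[m].float() @ b[e].float().T   # b is laid out as (E, N, K)
            if routed_weight:
                out[m, j] *= topk_weights[m, j]
    return out.to(a.dtype)
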
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int

def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)

# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)

print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")

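save_configs serializes a mapping from an integer key, presumably a token-count or batch-size bucket (an assumption, the key's meaning is not shown here), to one BenchmarkConfig per key. An illustrative example of the shape being written, with made-up values:

# Illustrative only: the kind of mapping save_configs receives (values made up).
example_configs: Dict[int, BenchmarkConfig] = {
    64: {
        "BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,
        "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2,
    },
}
print(json.dumps(example_configs, indent=4))
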
def get_rocm_tuning_space(use_fp16):
block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [1, 2, 4, 8]
group_m_range = [1, 4, 8, 16, 32]
# For now we see better perf with num_stages=0 for all the gemm configs we care
# about, but keep this explicit so that we do not forget we may need to set it
# to other values in the future.
num_stage_range = [0]
waves_per_eu_range = [0]
matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
kpack_range = [1, 2] if use_fp16 else []

param_ranges = {
"BLOCK_SIZE_M": block_mn_range,
"BLOCK_SIZE_N": block_mn_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range

return param_ranges

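A tuning driver would typically expand these ranges into concrete candidate configs by taking their Cartesian product. A hypothetical sketch of that expansion (the tuning loop itself is not part of this commit):

# Hypothetical driver (not in this commit): enumerate candidate configs.
import itertools

param_ranges = get_rocm_tuning_space(use_fp16=True)
keys = list(param_ranges.keys())
candidate_configs = [
    dict(zip(keys, values))
    for values in itertools.product(*(param_ranges[k] for k in keys))
]
print(f"{len(candidate_configs)} candidate configs to benchmark")
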

def get_configs():
configs = [
@@ -367,6 +423,8 @@ def get_configs():
def run_benchmark(custom, args):
print_time = args.return_time
routed_weight = args.routed_weight
dtype = arg_to_torch_dtype[args.dtype]
tune = args.tune
if custom:
assert args.M and args.K and args.N and args.E and args.top_k, \
"Please provide M, K, N, E, top_k for custom runs."
@@ -382,12 +440,12 @@ def run_benchmark(custom, args):
benchmark = triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'],
line_names=[line_names], styles=[('red', '-'), ('blue', '-')], ylabel='ms',
plot_name='moe-gemm-benchmark',
args={'print_time': print_time, 'routed_weight': routed_weight})
args={'dtype': dtype, 'print_time': print_time, 'routed_weight': routed_weight})

@triton.testing.perf_report([benchmark])
def bench_moe_gemm(M, K, N, E, top_k, routed_weight, print_time, provider):
def bench_moe_gemm(M, K, N, E, top_k, dtype, routed_weight, print_time, provider):
a, b, c, topk_weights, topk_ids, sorted_token_ids, expert_ids, num_tokens_post_padded, config = input_helper(
M, K, N, top_k, E, routed_weight=routed_weight)
M, K, N, top_k, E, routed_weight=routed_weight, dtype=dtype)

flops = 2.0 * M * top_k * K * N
if routed_weight:
@@ -417,10 +475,13 @@ def parse_args():
parser.add_argument("-E", type=int, default=0, help="Number of experts")
parser.add_argument("-top_k", type=int, default=0, help="top_k experts per token")
parser.add_argument("-routed_weight", action='store_true', default=False)
parser.add_argument("-tune", action='store_true', default=False)
parser.add_argument("-dtype", default='fp16')
parser.add_argument("-return_time", action='store_true', default=False, help='Return time instead of TFLOPs')
args = parser.parse_args()
return args

arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
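The new -dtype flag is resolved through this table before the input tensors are allocated, for example (illustrative usage, not part of the committed code):

# Illustrative only: passing `-dtype bf16` on the command line resolves to
# torch.bfloat16 through this table; the default flag value is 'fp16'.
dtype = arg_to_torch_dtype['bf16']   # -> torch.bfloat16
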

def main():
args = parse_args()
