Skip to content

Commit

Permalink
Add performance reference for important matmul kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglx13 committed Sep 16, 2024
1 parent c4bd738 commit 87cf022
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 10 deletions.
14 changes: 14 additions & 0 deletions python/perf-kernels/tools/tune_gemm/[email protected]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
trans,M,N,K,TFLOPS,us
TN,4864,4096,4096,467.39,349.19
TN,4864,4096,4160,567.17,292.26
TN,4864,4096,4224,557.49,301.90
TN,4864,4096,4288,569.55,299.99
TN,4864,4096,4097,501.58,325.47
TN,4864,4096,4098,491.96,331.92
TN,4864,4096,4100,503.51,324.46
TN,4864,4096,4104,515.70,317.10
TN,4864,4096,4112,525.66,311.70
TN,4864,8192,4096,519.95,627.79
TN,4864,8192,4160,579.14,572.43
TN,4864,8192,8192,543.30,1201.6
TN,4864,8192,8256,563.43,1167.7
18 changes: 18 additions & 0 deletions python/perf-kernels/tools/tune_gemm/database.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# M // BLOCK_M * N // BLOCK_N % 304 == 0
## 1 workgroup / CU
- {'M': 4864, 'N': 4096, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4224, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4288, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
## 1 workgroup / CU masked loadK
- {'M': 4864, 'N': 4096, 'K': 4097, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4098, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4100, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4104, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 4096, 'K': 4112, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}

## 2 workgroups / CU
- {'M': 4864, 'N': 8192, 'K': 4096, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 4160, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 8192, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
- {'M': 4864, 'N': 8192, 'K': 8256, 'rowMajorA': 'T', 'rowMajorB': 'N', 'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 4, 'SPLIT_K': 1, 'num_warps': 8, 'num_stages': 0, 'waves_per_eu': 0, 'matrix_instr_nonkdim': 16, 'kpack': 2}
24 changes: 17 additions & 7 deletions python/perf-kernels/tools/tune_gemm/matmul_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,16 +46,26 @@ def matmul_kernel(a_ptr, b_ptr, c_ptr, bias_ptr, M, N, K, stride_am, stride_ak,
bias = tl.load(bias_ptrs, mask=offs_am < M, other=0.0)
acc_dtype = tl.float32 if a_ptr.type.element_ty != tl.int8 else tl.int32
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=acc_dtype)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
if EVEN_K:
a = tl.load(a_ptrs)
b = tl.load(b_ptrs)
else:
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)

max_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) - 1
if EVEN_K:
max_k += 1
for k in range(0, max_k):
a = tl.load(tl.multiple_of(a_ptrs, (1, 16)))
b = tl.load(tl.multiple_of(b_ptrs, (16, 1)))
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak
b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk

if not EVEN_K:
k = max_k
offs_k = k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
a_ptrsX = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak
b_ptrsX = b_ptr + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
a = tl.load(a_ptrsX, mask=offs_k[None, :] < K, other=0.0)
b = tl.load(b_ptrsX, mask=offs_k[:, None] < K, other=0.0)
accumulator += tl.dot(a, b)

c = accumulator.to(c_ptr.type.element_ty)
if BIAS:
c += bias[:, None]
Expand Down
8 changes: 5 additions & 3 deletions python/perf-kernels/tools/tune_gemm/tune_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
num_warps = config.get("num_warps")
num_stages = config.get("num_stages")
matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
EVEN_K = (K % BLOCK_SIZE_K == 0)
if matrix_instr_nonkdim > mfma:
continue
if mfma == 4 and BLOCK_SIZE_K < 64:
Expand Down Expand Up @@ -149,10 +150,11 @@ def prune_configs(M, N, K, configs, elemBytes_a, elemBytes_b):
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
# We only want to use a small BLOCK_SIZE_K if not EVEN_K
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
if BLOCK_SIZE_K < 64 and EVEN_K:
continue
if num_warps < 4:
continue
Expand Down Expand Up @@ -657,14 +659,14 @@ def main():

# write best config to tuning_results.yaml
if run_bench:
print(f"{formatted_tflops} {minTime}")
print(f"{formatted_tflops} {minTime} {bestConfig_compact_str}")
f_results.write(f"{formatted_tflops},{minTime}\n")

sizeDict = {'M': M, 'N': N, 'K': K, 'rowMajorA': row_a_str, 'rowMajorB': row_b_str}
sizeDict.update(bestConfig)
if not run_bench:
f_results.write("- " + str(sizeDict) + " ")
f_results.write(f'# TFLOPS: {formatted_tflops} time(us): {minTime}\n')
f_results.write(f'# {bestConfig_compact_str}\n')

# remove generated files if asked to
if not keepTmp:
Expand Down

0 comments on commit 87cf022

Please sign in to comment.