Tianxing/moe gemm #685

Open · Chi-Chu319 wants to merge 24 commits into main_perf
Conversation

@Chi-Chu319 Chi-Chu319 commented Dec 18, 2024

Implemented moe gemm, with tests and benchmarking.
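For context, here is a minimal PyTorch reference sketch of the MoE GEMM computation the kernel and benchmark exercise; the tensor layout, argument names, and routed-weight scaling are assumptions inferred from the benchmark parameters (M, K, N, E, top_k, routed_weight), not the PR's exact API.

```python
import torch


def moe_gemm_ref(a, w, topk_ids, topk_weights=None):
    """Reference MoE GEMM: a is (M, K) tokens, w is (E, N, K) expert weights,
    topk_ids is (M, top_k) expert indices, topk_weights is (M, top_k) routing
    weights (used when routed_weight=True). Returns an (M, top_k, N) output."""
    M, K = a.shape
    E, N, _ = w.shape
    top_k = topk_ids.shape[1]
    out = torch.zeros((M, top_k, N), dtype=a.dtype, device=a.device)
    for e in range(E):
        token, slot = torch.where(topk_ids == e)  # tokens routed to expert e
        if token.numel() == 0:
            continue
        out[token, slot] = a[token] @ w[e].T      # per-expert GEMM
    if topk_weights is not None:
        out = out * topk_weights.unsqueeze(-1)    # routed-weight scaling
    return out
```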

Run `python python/perf-kernels/flash-attention.py -model all` to benchmark with the mistral models.
Run `python python/perf-kernels/flash-attention.py -routed_weight` to benchmark with the routed weight.

The benchmark reports TFLOPS and memory bandwidth (GB/s).

Benchmark results:

| M | K | N | E | top_k | TFLOPS | Bandwidth (GB/s) |
|---|---|---|---|-------|--------|------------------|
| 64 | 128 | 256 | 8 | 2 | 0.092229 | 7.422335 |
| 64 | 1024 | 1792 | 8 | 2 | 5.079523 | 339.276790 |
| 64 | 4096 | 7168 | 8 | 2 | 27.586739 | 1725.612375 |
| 128 | 4096 | 7168 | 8 | 2 | 53.671204 | 1677.165903 |
| 1024 | 4096 | 7168 | 8 | 2 | 179.189020 | 755.743558 |
| 4096 | 4096 | 7168 | 8 | 2 | 365.320720 | 485.117674 |
| 64 | 4096 | 14336 | 8 | 2 | 35.497301 | 2211.471846 |
| 128 | 4096 | 14336 | 8 | 2 | 71.006905 | 2160.232271 |
| 256 | 4096 | 14336 | 8 | 2 | 126.936021 | 1970.342294 |
| 512 | 4096 | 14336 | 8 | 2 | 206.943225 | 1605.343648 |
| 1024 | 4096 | 14336 | 8 | 2 | 246.641795 | 986.634938 |
| 2048 | 4096 | 14336 | 8 | 2 | 351.595861 | 798.413673 |
| 4096 | 4096 | 14336 | 8 | 2 | 415.472299 | 511.619561 |

Mistral benchmark results:

| Model | M | N | K | E | top_k | TFLOPS | Bandwidth (GB/s) |
|-------|---|---|---|---|-------|--------|------------------|
| mistral-7B | 4096 | 28672 | 4096 | 8 | 2 | 409.995953 | 517.772213 |
| mistral-7B | 4096 | 4096 | 14336 | 8 | 2 | 408.037832 | 477.285160 |
| mistral-22B | 4096 | 32768 | 6144 | 8 | 2 | 411.230913 | 449.228758 |
| mistral-22B | 4096 | 6144 | 16384 | 8 | 2 | 410.775487 | 465.042536 |

References:

#435

@Chi-Chu319 Chi-Chu319 self-assigned this Dec 18, 2024
@Chi-Chu319 Chi-Chu319 requested a review from vgokhale December 20, 2024 15:37
@Chi-Chu319 Chi-Chu319 marked this pull request as ready for review December 20, 2024 15:40
The gemm supports routed weights, as does the benchmark. You can tune the gemm with the benchmark option -tune.
@@ -0,0 +1,211 @@
from typing import TypedDict, List, Optional
Collaborator:
Hmm I think I should have clarified this a bit more - my bad. I don't think we want to take the benchmarking stuff. We want this to remain with autotune like the other kernels here.

We do eventually want to go down the route of tuning like this, but not this comprehensively.
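For reference, a minimal sketch of the autotune pattern being suggested, mirroring how the other perf-kernels are written; the config values, tuning key, and kernel signature below are illustrative assumptions, not code from this PR.

```python
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64},
                      num_warps=4, num_stages=2),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64},
                      num_warps=8, num_stages=2),
    ],
    key=['N', 'K'],  # re-tune whenever the problem shape changes
)
@triton.jit
def moe_gemm_kernel(a_ptr, b_ptr, c_ptr, N, K,  # hypothetical signature
                    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    pass  # kernel body elided
```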

Collaborator:
Is it possible to revert these and keep just the moe kernel and the helpers it needs to set up the inputs and autotune?

@Chi-Chu319 Chi-Chu319 requested a review from vgokhale January 3, 2025 16:36
{
"small_M": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
Collaborator:
Was this the best config it picked?

import argparse
import sys

M_THRESHOLD = 1024
Collaborator:
I think the threshold should be much smaller. Like 128.

Author:
I will change it and check for the best config.
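As a sketch of what the threshold-based selection might look like once the value is re-checked (bucket names follow the "small_M" key visible in the diff; the "large_M" bucket and its values are assumptions):

```python
M_THRESHOLD = 128  # reviewer's suggested value; the diff currently uses 1024

tuned_configs = {
    "small_M": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16},   # from the checked-in config
    "large_M": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128},  # hypothetical large-batch bucket
}


def get_config(M: int) -> dict:
    # Small batches favor small tiles; larger M amortizes the expert-weight
    # reads and favors bigger tiles.
    return tuned_configs["small_M"] if M < M_THRESHOLD else tuned_configs["large_M"]
```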

return args


arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
Collaborator:
Can we also add int8/fp8, similar to FA? Maybe in a follow-up PR?

Author:
Yes, there is a draft PR, #693. I will open it once this is merged.

allow_abbrev=False,
)
parser.add_argument("-M", type=int, default=0, help="M dimension")
parser.add_argument("-K", type=int, default=0, help="K dimension")
Collaborator:
Since Juuso checked in his PR for the models, can you also add mixtral 22B and 7B? You can find their configs online.
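One possible way to express those presets, using the layer shapes already listed in the benchmark table above (the dict layout and key names here are assumptions):

```python
# (N, K) pairs for the two GEMMs in each MoE layer, plus expert count and routing.
model_configs = {
    "mistral-7B": {"shapes": [(28672, 4096), (4096, 14336)], "E": 8, "top_k": 2},
    "mistral-22B": {"shapes": [(32768, 6144), (6144, 16384)], "E": 8, "top_k": 2},
}
```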

@Chi-Chu319 Chi-Chu319 requested a review from vgokhale January 7, 2025 14:54
@vgokhale vgokhale requested a review from zhanglx13 January 10, 2025 22:35
M, N, K, top_k, E, routed_weight=routed_weight, dtype=dtype)

flops = 2.0 * M * top_k * K * N
if routed_weight:
Collaborator:
Maybe add a comment here on why this increases.

Also below for the bytes calculation, since it is not intuitive.
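For example, a commented sketch of the counting being asked about; the flop terms follow the code above, while the byte terms are an assumption about what the benchmark reads and writes.

```python
def perf_counts(M, N, K, E, top_k, routed_weight, bytes_per_elem=2):
    # Each of the M tokens runs a (1, K) x (K, N) GEMM for each of its top_k
    # experts: 2 * K * N flops per (token, expert) pair.
    flops = 2.0 * M * top_k * K * N
    if routed_weight:
        # Scaling every element of the (M, top_k, N) output by its routing
        # weight adds one multiply per output element.
        flops += M * top_k * N

    # Assumed traffic: read the (M, K) activations, read the expert weights
    # (E, N, K), and write the (M, top_k, N) output.
    num_bytes = bytes_per_elem * (M * K + E * N * K + M * top_k * N)
    if routed_weight:
        # The (M, top_k) routing weights (fp32 here) are read once each.
        num_bytes += 4 * M * top_k
    return flops, num_bytes
```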
