Tianxing/moe gemm #685
base: main_perf
Conversation
Force-pushed from cd06c06 to b9a9504.
The GEMM supports routed weights, and so does the benchmark. You can tune the GEMM with the benchmark option -tune.
Force-pushed from 3d044a2 to 5adb971.
@@ -0,0 +1,211 @@
from typing import TypedDict, List, Optional
Hmm I think I should have clarified this a bit more - my bad. I don't think we want to take the benchmarking stuff. We want this to remain with autotune like the other kernels here.
We do eventually want to go down the route of tuning like this, but not this comprehensively.
Is it possible to revert these and have just the moe kernel and the friends it needs to set up the inputs and autotune?
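For reference, a minimal sketch of the autotune-style approach used by the other perf-kernels, assuming an illustrative config list and key (none of the values or parameter names below are from this PR):

```python
import triton
import triton.language as tl


# Minimal sketch: let triton.autotune pick block sizes at launch time instead
# of reading them from a pre-tuned config file. Configs and key are
# illustrative placeholders, not tuned values.
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_warps=4),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_warps=8),
    ],
    key=['EM', 'N', 'K'],  # re-tune when the padded token count or weight shape changes
)
@triton.jit
def moe_gemm_kernel(a_ptr, b_ptr, c_ptr, EM, N, K,
                    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
                    BLOCK_SIZE_K: tl.constexpr):
    # ...kernel body unchanged; block sizes now arrive from the autotuner...
    pass
```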
{
    "small_M": {
        "BLOCK_SIZE_M": 16,
        "BLOCK_SIZE_N": 16,
Was this the best config it picked?
import argparse
import sys

M_THRESHOLD = 1024
I think the threshold should be much smaller. Like 128.
I will change it and check for the best config.
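As a point of comparison, a hedged sketch of what the threshold-based selection might look like with the smaller cutoff (the 128 value, the bucket names, and the config contents are placeholders pending re-tuning):

```python
# Illustrative only: choose a tuned config bucket from the token count M.
# The cutoff and block sizes are placeholders until re-tuning confirms them.
M_THRESHOLD = 128

TUNED_CONFIGS = {
    "small_M": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16},
    "large_M": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64},
}


def get_config(M: int) -> dict:
    # Few tokens favour small tiles; large batches amortize bigger tiles.
    return TUNED_CONFIGS["small_M"] if M < M_THRESHOLD else TUNED_CONFIGS["large_M"]
```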
return args


arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32}
Can we also add int8/fp8 similar to FA? Maybe in a follow-up PR?
Yes, there is a draft PR #693. I will open it once this is merged.
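For the follow-up, the dtype map could grow along these lines; the exact fp8 dtype attribute is an assumption and differs between PyTorch/ROCm builds:

```python
import torch

# Sketch for the follow-up PR. The fp8 entry is an assumption: some ROCm
# builds expose torch.float8_e4m3fnuz, others torch.float8_e4m3fn.
arg_to_torch_dtype = {
    'fp16': torch.float16,
    'bf16': torch.bfloat16,
    'fp32': torch.float32,
    'int8': torch.int8,
    'fp8': torch.float8_e4m3fnuz if hasattr(torch, 'float8_e4m3fnuz') else torch.float8_e4m3fn,
}
```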
allow_abbrev=False,
)
parser.add_argument("-M", type=int, default=0, help="M dimension")
parser.add_argument("-K", type=int, default=0, help="K dimension")
Since Juuso checked in his PR for the models, can you also add Mixtral 22B and 7B? You can find their configs online.
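A hedged sketch of what the extra model presets might look like; the Mixtral shapes below are taken from the publicly available configs and should be verified against the checked-in model table before landing:

```python
# Illustrative presets only; verify against the public Mixtral configs.
# K is the hidden size, N the per-expert intermediate size of the first
# GEMM in the MoE block, E the number of experts, top_k the experts per token.
MOE_MODEL_CONFIGS = {
    'mixtral-8x7B':  {'N': 14336, 'K': 4096, 'E': 8, 'top_k': 2},
    'mixtral-8x22B': {'N': 16384, 'K': 6144, 'E': 8, 'top_k': 2},
}


def model_benchmark_shapes(model: str, M: int):
    # Returns the (M, N, K, top_k, E) tuple the benchmark would feed the kernel.
    cfg = MOE_MODEL_CONFIGS[model]
    return M, cfg['N'], cfg['K'], cfg['top_k'], cfg['E']
```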
M, N, K, top_k, E, routed_weight=routed_weight, dtype=dtype)

flops = 2.0 * M * top_k * K * N
if routed_weight:
Maybe add a comment here on why this increases. Also add one below for the bytes calculation, since it is not intuitive.
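Something along these lines could serve as the explanatory comment. This is only a sketch of one plausible flop/byte model, assuming fp16 tensors, fp32 routing weights, and that every expert's weight matrix is read once; the example shapes are illustrative:

```python
# Example shapes (Mixtral-8x7B-like), purely illustrative.
M, N, K, top_k, E, routed_weight = 1024, 14336, 4096, 2, 8, True

# Each output element of the expanded (M * top_k) x N result costs one
# multiply-add against the expert weights, hence 2 * M * top_k * K * N flops.
flops = 2.0 * M * top_k * K * N
if routed_weight:
    # Scaling every output element by its routing weight adds one extra
    # multiply per element.
    flops += M * top_k * N

# One plausible byte model, assuming fp16 (2 bytes/element): each token row
# is gathered once per selected expert, every expert's weights are touched,
# and the expanded output is written back.
bytes_moved = 2 * (M * top_k * K      # A: gathered activations
                   + E * K * N        # B: expert weight matrices
                   + M * top_k * N)   # C: expanded output
if routed_weight:
    # One routing weight per (token, selected expert); fp32 assumed.
    bytes_moved += 4 * M * top_k
```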
Implemented MoE GEMM, test, and benchmarking.

Run
python python/perf-kernels/flash-attention.py -model all
to benchmark with the mistral models.

Run
python python/perf-kernels/flash-attention.py -routed_weight
to benchmark with the routed weight.

The benchmark shows the memory bandwidth.
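For readers of the numbers below, a sketch of how such a bandwidth figure is typically derived from a measured time; the plain matmul stands in for the MoE GEMM launch, and the byte count reuses the model discussed in the review comments:

```python
import torch
from triton.testing import do_bench

# Sketch: convert a measured kernel time into an effective bandwidth figure.
a = torch.randn((1024, 4096), device='cuda', dtype=torch.float16)
b = torch.randn((4096, 14336), device='cuda', dtype=torch.float16)

ms = do_bench(lambda: a @ b)                          # measured time in milliseconds
bytes_moved = 2 * (a.numel() + b.numel() + 1024 * 14336)
gbps = bytes_moved * 1e-9 / (ms * 1e-3)               # bytes / seconds -> GB/s
print(f"{gbps:.1f} GB/s")
```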
benchmark result:
mistral benchmark result:
references:
https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/fused_moe/fused_moe.py#L261
https://github.com/vllm-project/vllm/blob/main/csrc/moe/moe_align_sum_kernels.cu
- I am not making a trivial change, such as fixing a typo in a comment.
- I have written a PR description following these rules.
- I have run `pre-commit run --from-ref origin/main --to-ref HEAD`.
- Select one of the following.
  - I have added tests:
    - `/test` for `lit` tests
    - `/unittest` for C++ tests
    - `/python/test` for end-to-end tests
  - This PR does not need a test because FILL THIS IN.
- Select one of the following.
  - I have not added any `lit` tests.
  - The `lit` tests I have added follow these best practices, including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
#435