
[ARM CPU] hgemm optimized for gqa #23107

Open · wants to merge 22 commits into base: main
Conversation

fajin-corp
Contributor

@fajin-corp fajin-corp commented Dec 14, 2024

Description

Adds fp16 kernels for the GQA matmul on ARM CPU.
The kernels implement MLAS HGEMM, computing C = alpha * A * B' + beta * C.
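For reference, the operation the kernels compute can be sketched as a naive scalar GEMM with B accessed transposed. This is only an illustration of the math, not the optimized MLAS kernel; `ReferenceHgemm` is a hypothetical name, and `float` stands in for fp16 for clarity:

```cpp
#include <cstddef>
#include <vector>

// Naive reference for C = alpha * A * B' + beta * C.
// A is M x K (row-major), B is N x K (row-major), so B' is K x N.
// The real MLAS HGEMM operates on fp16; float is used here for clarity.
void ReferenceHgemm(size_t M, size_t N, size_t K,
                    float alpha, const std::vector<float>& A,
                    const std::vector<float>& B,
                    float beta, std::vector<float>& C) {
  for (size_t m = 0; m < M; ++m) {
    for (size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (size_t k = 0; k < K; ++k) {
        acc += A[m * K + k] * B[n * K + k];  // B read transposed
      }
      C[m * N + n] = alpha * acc + beta * C[m * N + n];
    }
  }
}
```

With B the identity, alpha = 1, and beta = 0, the result is simply a copy of A, which makes the convention easy to sanity-check.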

Motivation and Context

Adds fp16 support for GQA, speeding up the operator and reducing memory usage.

Token Generation

| Shape | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
| --- | --- | --- | --- |
| M:1/N:4096/K:4096 | 251551 | 1775905 | 85.84 |
| M:1/N:11008/K:4096 | 892507 | 4649145 | 80.80 |
| M:1/N:4096/K:11008 | 866860 | 3240015 | 73.25 |
| M:1/N:11008/K:11008 | 2631615 | 8783877 | 70.04 |

Prompting

| Shape | HGEMM Runtime (ns) | SGEMM Runtime (ns) | Speed-up (%) |
| --- | --- | --- | --- |
| M:1024/N:4096/K:4096 | 90508701 | 111283029 | 18.67 |
| M:2048/N:4096/K:4096 | 181307522 | 240211107 | 24.52 |
| M:1024/N:11008/K:4096 | 241120234 | 307707933 | 21.64 |
| M:2048/N:11008/K:4096 | 481091232 | 648921367 | 25.86 |
| M:1024/N:4096/K:11008 | 241736343 | 310129880 | 22.05 |
| M:2048/N:4096/K:11008 | 480456703 | 644814999 | 25.49 |
| M:1024/N:11008/K:11008 | 642121440 | 847925766 | 24.27 |
| M:2048/N:11008/K:11008 | 1276097154 | 1731314509 | 26.29 |

Contributor

@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
@fajin-corp fajin-corp marked this pull request as ready for review January 23, 2025 23:06
@fajin-corp fajin-corp requested a review from a team as a code owner January 23, 2025 23:06
Contributor

@github-actions github-actions bot left a comment

You can commit the suggested changes from lintrunner.

onnxruntime/test/mlas/bench/bench_hgemm.cpp
```cpp
size_t k = CountK;
constexpr size_t step = 8 * 16;  // pack 8 * 16
for (; k >= 8; k -= 8, b += 8, PackedB_data += step) {
  float16x8_t v0 = MlasLoadFloat16x8(b);
```
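The excerpt above advances `PackedB_data` by `step = 8 * 16` per iteration, i.e. one 8 x 16 tile of B per 8 rows of K. A plain scalar sketch of that kind of tile packing, under the assumption of an 8-row by 16-column, row-major tile layout, might look like the following (`PackPanel8x16` is a hypothetical helper, and `float` stands in for the fp16 NEON loads in the real code):

```cpp
#include <cstddef>
#include <vector>

// Hypothetical scalar sketch: pack a CountK x 16 panel of B into
// contiguous 8 x 16 tiles, mirroring step = 8 * 16 in the excerpt.
// The panel is addressed as b[k * ldb + n] for n in [0, 16).
std::vector<float> PackPanel8x16(const float* b, size_t CountK, size_t ldb) {
  constexpr size_t kTile = 8, nTile = 16;
  std::vector<float> packed;
  packed.reserve(CountK * nTile);
  size_t k = CountK;
  for (; k >= kTile; k -= kTile, b += kTile * ldb) {
    for (size_t kk = 0; kk < kTile; ++kk) {
      for (size_t n = 0; n < nTile; ++n) {
        packed.push_back(b[kk * ldb + n]);  // emit one 8 x 16 tile row-major
      }
    }
  }
  return packed;
}
```

Packing into fixed-size contiguous tiles like this keeps the inner GEMM loop streaming through memory sequentially, which is the usual reason MLAS pre-packs B.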
Should `#pragma unroll` be used with this loop?
