Merge branch 'main_perf' into tianxing/moe-gemm
Chi-Chu319 authored Dec 20, 2024
2 parents 9a43c1c + 4a7afd2 commit cefc74e
Showing 22 changed files with 2,209 additions and 512 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/amd_perf_kernel_Integration_tests.yml
@@ -49,7 +49,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then
-            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]'
+            echo '::set-output name=matrix-HIP::[["self-hosted", "gfx942"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
@@ -100,7 +100,7 @@ jobs:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
-      image: rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.4
+      image: rocm/pytorch:latest
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
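For local debugging, the workflow's container spec corresponds roughly to the following docker invocation. This is a sketch assembled from the `image` and `options` fields above, not a command from the repository; the checkout and test steps are not shown.

```shell
# Launch the ROCm PyTorch container with GPU device access, mirroring the
# workflow's container options (assumes a host with ROCm, /dev/kfd, /dev/dri).
docker run -it --rm \
  --device=/dev/kfd --device=/dev/dri \
  --security-opt seccomp=unconfined \
  --group-add video --user root \
  rocm/pytorch:latest /bin/bash
```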
4 changes: 2 additions & 2 deletions .github/workflows/amd_perf_kernel_postmerge_tests.yml
@@ -27,7 +27,7 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"ROCm/triton" ]; then
-            echo '::set-output name=matrix-HIP::[["self-hosted", "rocm.gfx90a"]]'
+            echo '::set-output name=matrix-HIP::[["self-hosted", "gfx942"]]'
else
echo '::set-output name=matrix-HIP::[["ubuntu-latest"]]'
fi
@@ -41,7 +41,7 @@ jobs:
matrix:
runner: ${{fromJson(needs.Runner-Preparation-AMD.outputs.matrix-HIP)}}
container:
-      image: rocm/pytorch:rocm6.0.2_ubuntu22.04_py3.10_pytorch_2.1.2
+      image: rocm/pytorch:latest
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --user root
steps:
- name: Checkout
19 changes: 19 additions & 0 deletions python/perf-kernels/README.md
@@ -42,9 +42,28 @@ This script contains the Flash Attention kernel with the following support
- Multi and Grouped Query attention
- ALiBi bias
- Matrix bias
- Persistent kernels, useful for small-to-moderate sequence lengths and especially for causal attention.
- Int8 quantization

These are currently supported for the forward kernel only.

### INT8 Quantization Support

1. `q_descale`, `k_descale`, and `v_descale` provided:
   - The first QK GEMM runs in INT8, then its output is dequantized to the specified `dtype`.
   - The second PV GEMM runs in the specified `dtype`.

2. `q_descale`, `k_descale`, `p_descale`, and `v_descale` provided:
   - Both the QK and PV GEMMs run in INT8.
   - The results are dequantized to the specified `dtype` after both GEMMs.

3. Only `k_descale` and `v_descale` provided:
   - K and V are dequantized before the QK and PV GEMMs, respectively.
   - Both GEMMs run in the specified `dtype`.

Note: the softmax operation is always performed in `fp32`.
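As a rough illustration of case 1 above, here is a plain-NumPy sketch of an INT8 QK GEMM with descale factors. The helper name and the per-tensor symmetric quantization scheme are illustrative assumptions, not the actual Triton kernel, which tiles this computation on the GPU.

```python
import numpy as np

def int8_qk_gemm(q_fp32, k_fp32):
    # Per-tensor symmetric quantization: map the max magnitude to 127.
    q_scale = np.abs(q_fp32).max() / 127.0
    k_scale = np.abs(k_fp32).max() / 127.0
    q_int8 = np.clip(np.round(q_fp32 / q_scale), -127, 127).astype(np.int8)
    k_int8 = np.clip(np.round(k_fp32 / k_scale), -127, 127).astype(np.int8)

    # INT8 GEMM accumulated in int32, then dequantized with the descale
    # factors (here q_descale = q_scale, k_descale = k_scale).
    acc_i32 = q_int8.astype(np.int32) @ k_int8.T.astype(np.int32)
    return acc_i32.astype(np.float32) * (q_scale * k_scale)

rng = np.random.default_rng(0)
q = rng.standard_normal((4, 64)).astype(np.float32)
k = rng.standard_normal((8, 64)).astype(np.float32)
approx = int8_qk_gemm(q, k)   # dequantized INT8 result
exact = q @ k.T               # full-precision reference
```

The dequantized result closely tracks the full-precision GEMM; the residual error is the quantization noise that the descale factors cannot recover.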


## `06-attention-decode.py`

This contains the Flash Decoding kernel.
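Flash Decoding splits the KV sequence into chunks, computes partial attention per chunk, and combines the partials using their running max and sum-of-exponentials statistics. The following NumPy sketch shows only that combine idea and is not the Triton kernel in this file.

```python
import numpy as np

def attn_chunk(q, k, v):
    # Partial attention over one KV chunk: output, max score, sum of exp.
    s = q @ k.T
    m = s.max()
    p = np.exp(s - m)
    return p @ v, m, p.sum()

def flash_decode(q, k, v, n_chunks=4):
    parts = [attn_chunk(q, kc, vc)
             for kc, vc in zip(np.array_split(k, n_chunks),
                               np.array_split(v, n_chunks))]
    outs, ms, ls = zip(*parts)
    m = max(ms)
    # Rescale each partial to the global max, then renormalize.
    scales = [np.exp(mi - m) for mi in ms]
    num = sum(o * s for o, s in zip(outs, scales))
    den = sum(l * s for l, s in zip(ls, scales))
    return num / den

rng = np.random.default_rng(1)
q = rng.standard_normal((1, 16))
k = rng.standard_normal((32, 16))
v = rng.standard_normal((32, 16))
s = q @ k.T
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
out = flash_decode(q, k, v)
```

Because each partial is rescaled by `exp(m_i - m)` before summing, the chunked result is mathematically identical to computing softmax attention over the whole sequence at once.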
980 changes: 704 additions & 276 deletions python/perf-kernels/flash-attention.py

Large diffs are not rendered by default.

