FA3 KV Cache is slower than FA2 KV Cache #1465

Open
DD-DuDa opened this issue Jan 26, 2025 · 2 comments

Comments


DD-DuDa commented Jan 26, 2025

The GPU I am using is an NVIDIA H100 80GB HBM3, and I am trying to benchmark the performance of flash-decoding.

The figure below shows results with the following parameters:

batch_size = 1
q_len = 1
nheads_q = 32
nheads_kv = 8
dim = 128

[Figure: decoding benchmark plot — FA2 vs. FA3 latency across sequence lengths]

I've also checked other settings and cannot reproduce the results of PR #1236.

tridao (Contributor) commented Jan 26, 2025

Can you post a short script to benchmark the two?

DD-DuDa (Author) commented Jan 26, 2025

Hi, thanks for the quick response!

Here is the code. (Actually, I don't know how to install FA2 and FA3 in the same conda env, so I swap the import at the top and run the script once per version.)

import torch
import triton

# Toggle between these two imports to benchmark FA3 vs. FA2:
# from flash_attn_interface import flash_attn_with_kvcache  # FA3
from flash_attn import flash_attn_with_kvcache  # FA2


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=[2**i for i in range(10, 20)],
        line_arg="provider",  # Argument name whose value corresponds to a different line in the plot.
        line_vals=["flash-attn-v3"],  # Possible values for `line_arg`.
        line_names=["flash-attn-v3"],  # Label name for the lines.
        styles=[("red", "-")],  # Line color and style.
        plot_name="decoding benchmark",
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    )
)
def benchmark(seq_len, provider):
    torch.random.manual_seed(0)
    device = "cuda"
    dtype = torch.float16

    # Single-token decode step with grouped-query attention:
    # batch size 1, 32 query heads, 8 KV heads, head dim 128.
    batch_size = 1
    nheads = 32
    nheads_k = 8
    d = 128

    q = torch.randn(batch_size, 1, nheads, d, device=device, dtype=dtype)
    k_cache = torch.randn(batch_size, seq_len, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seq_len, nheads_k, d, device=device, dtype=dtype)

    quantiles = [0.5, 0.2, 0.8]

    if provider == "flash-attn-v3":
        ms, min_ms, max_ms = triton.testing.do_bench(
            lambda: flash_attn_with_kvcache(q, k_cache, v_cache), quantiles=quantiles
        )

    # Report latency in milliseconds (median, 20th, and 80th percentiles).
    return ms, min_ms, max_ms


benchmark.run(show_plots=True, print_data=True)
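
In case it helps: FA2 and FA3 should be able to coexist in one env, since FA2 installs as the flash_attn package (pip install flash-attn) while FA3 is built from the repo's hopper/ directory and imports as flash_attn_interface, so the module names don't collide. Assuming both are installed and both expose the flash_attn_with_kvcache call used above, here is a sketch (untested) that benchmarks the two in a single run:

import torch
import triton

# Assumes both packages are installed side by side:
#   FA2: pip install flash-attn            -> imports as flash_attn
#   FA3: built from flash-attention/hopper -> imports as flash_attn_interface
from flash_attn import flash_attn_with_kvcache as fa2_flash_attn_with_kvcache
from flash_attn_interface import flash_attn_with_kvcache as fa3_flash_attn_with_kvcache

PROVIDERS = {
    "flash-attn-v2": fa2_flash_attn_with_kvcache,
    "flash-attn-v3": fa3_flash_attn_with_kvcache,
}


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["seq_len"],
        x_vals=[2**i for i in range(10, 20)],
        line_arg="provider",
        line_vals=list(PROVIDERS),
        line_names=list(PROVIDERS),
        styles=[("blue", "-"), ("red", "-")],
        plot_name="decoding benchmark: FA2 vs FA3",
        args={},
    )
)
def benchmark(seq_len, provider):
    torch.random.manual_seed(0)
    device, dtype = "cuda", torch.float16
    batch_size, nheads, nheads_k, d = 1, 32, 8, 128

    q = torch.randn(batch_size, 1, nheads, d, device=device, dtype=dtype)
    k_cache = torch.randn(batch_size, seq_len, nheads_k, d, device=device, dtype=dtype)
    v_cache = torch.randn(batch_size, seq_len, nheads_k, d, device=device, dtype=dtype)

    fn = PROVIDERS[provider]
    # Median / 20th / 80th percentile latency in ms for one decode step.
    ms, min_ms, max_ms = triton.testing.do_bench(
        lambda: fn(q, k_cache, v_cache), quantiles=[0.5, 0.2, 0.8]
    )
    return ms, min_ms, max_ms


benchmark.run(show_plots=True, print_data=True)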
