Add support for layouts #587

Merged (10 commits, May 20, 2024)
Changes from 4 commits
207 changes: 98 additions & 109 deletions python/perf-kernels/flash-attention.py
@@ -26,12 +26,6 @@
import triton
import triton.language as tl

torch_dtype: tl.constexpr = torch.float16

TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz')
if TORCH_HAS_FP8E5:
torch_dtype: tl.constexpr = torch.float8_e5m2fnuz


class MetaData():
cu_seqlens_q = None
@@ -43,13 +37,15 @@ class MetaData():
causal = False
num_contexts = 0
varlen = False
layout = None
dropout_p, return_encoded_softmax = 0.0, False

def __init__(self, sm_scale=1.0):
self.sm_scale = sm_scale

def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k):
self.varlen = True
self.layout = 'thd'
self.cu_seqlens_q = cu_seqlens_q
self.cu_seqlens_k = cu_seqlens_k
# Without "varlen", there should still be one sequence.
@@ -108,6 +104,8 @@ def check_args(self, q, k, v, o):
assert head_size <= 256
assert o.shape == q.shape
assert (nheads_q % nheads_k) == 0
assert self.layout is not None
assert self.layout == 'thd' or not self.varlen
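
The asserts above make the new layout field mandatory and tie the 'thd' layout to varlen mode. As a minimal illustrative sketch (not code from this PR), a caller would configure MetaData roughly like this, with set_varlen_params choosing 'thd' on its own:

# Illustrative sketch only; assumes the MetaData API shown in this diff.
import torch

meta_dense = MetaData(sm_scale=64 ** -0.5)
meta_dense.max_seqlens_q = 1024
meta_dense.max_seqlens_k = 1024
meta_dense.layout = 'bshd'  # or 'bhsd'

# Packed (varlen) inputs: set_varlen_params sets varlen=True and layout='thd'.
cu_q = torch.tensor([0, 5, 12], dtype=torch.int32, device='cuda')
cu_k = torch.tensor([0, 7, 15], dtype=torch.int32, device='cuda')
meta_varlen = MetaData(sm_scale=64 ** -0.5)
meta_varlen.set_varlen_params(cu_q, cu_k)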


@triton.jit
@@ -316,60 +314,14 @@ def _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, start_m, actual_
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(
Q,
K,
V,
bias,
sm_scale,
L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
stride_bz,
stride_bh,
stride_bm,
stride_bn,
stride_az,
stride_ah,
cu_seqlens_q,
cu_seqlens_k,
dropout_p,
philox_seed,
philox_offset_base,
encoded_softmax,
alibi_slopes,
HQ: tl.constexpr,
HK: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr,
VARLEN: tl.constexpr,
IS_CAUSAL: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
PRE_LOAD_V: tl.constexpr,
BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr,
RETURN_ENCODED_SOFTMAX: tl.constexpr,
USE_ALIBI: tl.constexpr,
BATCH_SIZE: tl.constexpr,
):
def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh,
stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om,
stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, BIAS_TYPE: tl.constexpr,
ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr):
start_m = tl.program_id(0)
off_h_q = tl.program_id(1)
off_z = tl.program_id(2)
@@ -875,6 +827,34 @@ def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
empty = torch.empty(128, device="cuda")


# TODO: This can probably be optimized to have fewer lines of code.
def get_strides_from_layout(metadata, q, k, v, o):
if metadata.layout == 'thd':
batch, nheads_q = metadata.num_contexts, q.shape[1]
nheads_k = k.shape[1]
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
elif metadata.layout == 'bhsd':
batch, nheads_q, _, head_size = q.shape
nheads_k = k.shape[1]
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))
elif metadata.layout == 'bshd':
batch, _, nheads_q, head_size = q.shape
nheads_k = k.shape[2]
q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3))
k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3))
v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3))
o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3))
else:
assert False, 'Got unsupported layout.'
return batch, nheads_q, nheads_k, q_strides, k_strides, v_strides, o_strides
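
A short usage sketch of this helper (illustrative, not part of the PR): for a 'bshd' tensor the batch/head/seq/dim strides are simply reordered, while 'thd' reports a batch stride of 0 because sequences are packed and addressed through cu_seqlens.

# Illustrative sketch only; tensor shapes here are arbitrary example values.
import torch

batch, seqlen, nheads, head_size = 2, 128, 16, 64
q = torch.randn(batch, seqlen, nheads, head_size, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
o = torch.empty_like(q)

meta = MetaData(sm_scale=head_size ** -0.5)
meta.layout = 'bshd'

b, hq, hk, q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(meta, q, k, v, o)
# For 'bshd', q_strides == (q.stride(0), q.stride(2), q.stride(1), q.stride(3)):
# the seq and head strides swap places so the kernel always indexes in
# (batch, head, seq, dim) order regardless of how the tensors are stored.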


class _attention(torch.autograd.Function):

@staticmethod
@@ -886,24 +866,15 @@ def forward(ctx, q, k, v, o, metadata):
if o is None:
o = torch.empty_like(q, dtype=v.dtype)
metadata.check_args(q, k, v, o)
if metadata.varlen:
total_q, nheads_q, head_size = q.shape
total_k, nheads_k, _ = k.shape
batch = metadata.num_contexts
q_strides = (0, q.stride(1), q.stride(0), q.stride(2))
k_strides = (0, k.stride(1), k.stride(0), k.stride(2))
v_strides = (0, v.stride(1), v.stride(0), v.stride(2))
o_strides = (0, o.stride(1), o.stride(0), o.stride(2))
else:
batch, nheads_q, seqlen_q, head_size = q.shape
_, nheads_k, seqlen_k, _ = k.shape
q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3))
k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3))
v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3))
o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3))

batch, nheads_q, nheads_k, q_strides, k_strides, v_strides, o_strides = \
get_strides_from_layout(metadata, q, k, v, o)
head_size = q.shape[-1]

# Get closest power of 2 over or equal to 32.
padded_d_model = 1 << (head_size - 1).bit_length()
# Smallest head_dim supported is 16. If smaller, the tile in the
# kernel is padded - there is no padding in memory for any dims.
padded_d_model = max(padded_d_model, 16)
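# (Illustrative worked example, not part of the diff: head_size=64 gives
#  1 << 6 = 64; head_size=72 gives 1 << 7 = 128; head_size=12 gives
#  1 << 4 = 16, which is already at the minimum of 16.)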

grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch)
@@ -943,7 +914,7 @@ def forward(ctx, q, k, v, o, metadata):
MAX_SEQLENS_K=metadata.max_seqlens_k, IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen,
BLOCK_DMODEL=padded_d_model, BIAS_TYPE=0 if metadata.bias is None else 1,
USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p
> 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, BATCH_SIZE=q.shape[0])
> 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax)

ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
@@ -1035,38 +1006,47 @@ def backward(ctx, do, _):
attention = _attention.apply


def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype):
def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout):
torch.manual_seed(20)

# Initialize q, k, v
q = torch.randn((Z, HQ, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn((Z, HK, N_CTX_K, D_HEAD), dtype=dtype, device="cuda", requires_grad=True)
if layout == 'bhsd':
q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD)
k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD)
elif layout == 'bshd':
q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD)
k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD)
else:
assert False, 'Got unsupported tensor layout'
q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=True)
k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True)
v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=True)
sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = N_CTX_Q
input_metadata.max_seqlens_k = N_CTX_K
input_metadata.layout = layout
return q, k, v, input_metadata


def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype):
def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False):
torch.manual_seed(20)

# Random sequence lengths. N_CTX is used as a rough upper bound on the sum of the individual seqlens.
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
max_seqlens_q = torch.max(seqlens_q).item()
max_seqlens_k = torch.max(seqlens_k).item()
if not equal_seqlens:
max_seqlens_q = N_CTX_Q // Z
max_seqlens_k = N_CTX_K // Z
seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32)
seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32)
else:
seqlens_q = torch.full((Z, ), N_CTX_Q // Z)
seqlens_k = torch.full((Z, ), N_CTX_K // Z)

# Calculate cumulative sequence lengths
cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)])
cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)])
cu_seqlens_q = cu_seqlens_q.to(device="cuda")
cu_seqlens_k = cu_seqlens_k.to(device="cuda")
# -1 because the last entry of cu_seqlens_q specifies the end of the last seq
# num_ctxs = len(cu_seqlens_q) - 1
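# (Illustrative worked example, not part of the diff: with Z=3 and
#  seqlens_q = [5, 2, 7], cu_seqlens_q becomes [0, 5, 7, 14]; sequence i
#  then occupies rows cu_seqlens_q[i]:cu_seqlens_q[i+1] of the packed
#  'thd' tensors.)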

# Initialize q, k, v with variable lengths
total_q = cu_seqlens_q[-1].item()
@@ -1102,7 +1082,9 @@ def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype):
@pytest.mark.parametrize('use_alibi', [True, False])
def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, dtype=torch.float16):
torch.manual_seed(20)
q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
# TODO: Adapt test for bshd
layout = 'bhsd'
Collaborator: Do we need to add the layouts quickly for some reason? It would be nice if we could test at least one other layout. If we need to move quick, I am fine adding the tests later.

Collaborator (Author): No, not necessary to get it in quickly. I guess I was just being lazy. Good to have tests - Done.

q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
if causal:
input_metadata.need_causal()

@@ -1114,9 +1096,6 @@
else:
alibi_slopes = None

if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
o = torch.empty_like(q)

# triton implementation
@@ -1185,9 +1164,6 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
q = torch.randn((Z, H, N_CTX_Q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
k = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
v = torch.randn((Z, H, N_CTX_K, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()
if TORCH_HAS_FP8E5:
q = q.to(torch_dtype)
k = k.to(torch_dtype)
o = torch.empty_like(q)

# triton implementation
@@ -1218,9 +1194,8 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype=tor
(4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)])
@pytest.mark.parametrize('causal', [True, False])
def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
pytest.skip()

q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, D_HEAD, dtype)
q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype)
tri_out = torch.empty_like(q)
ref_out = torch.empty_like(q)

@@ -1409,20 +1384,19 @@ def varlen_benchmark_configs():
return configs


def run_benchmark(custom):
def run_benchmark(custom, args):

args = parse_args()
dtype = arg_to_torch_dtype[args.dtype]
# hk = args.hq if not args.hk else args.hk
# sk = args.sq if not args.sk else args.sk
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
head_size = 128 if not args.d else args.d
mode = 'fwd'
x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K']
causal = args.causal
varlen = args.varlen
varlen = args.layout == 'thd'
configs = []
if custom:
x_vals_list = [(args.b, args.hq, args.hk, args.sq, args.sk)]
x_vals_list = [(args.b, args.hq, hk, args.sq, sk)]
else:
if varlen:
x_vals_list = varlen_benchmark_configs()
@@ -1433,7 +1407,7 @@ def run_benchmark(custom):
configs.append(
triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=['triton'],
line_names=[line_names], styles=[('red', '-')], ylabel='ms',
plot_name=f'fused-attention-{mode}-d{head_size}{"-varlen" if varlen else ""}',
plot_name=f'fused-attention-{mode}-d{head_size}-layout{args.layout}',
args={'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode}))

@triton.testing.perf_report(configs)
@@ -1455,14 +1429,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal

flops_per_matmul = 0
if varlen:
q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype,
args.equal_seqlens)
for i in range(0, input_metadata.num_contexts):
seqlen_q = input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]
seqlen_k = input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]
# x2 for 2 GEMMs
flops_per_matmul += seqlen_q.item() * seqlen_k.item() * HQ * D_HEAD * 2
else:
q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype)
q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout)
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
if causal:
input_metadata.need_causal()
@@ -1487,6 +1462,15 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
bench_flash_attention.run(save_path=".", print_data=True)


def supported_layouts():
layouts = \
'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n' \
'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n' \
'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n' \
'This layout is sometimes called "varlen" or "grouped" layout.'
return layouts


def parse_args():
parser = argparse.ArgumentParser(
prog="Benchmark FlashAttention",
@@ -1497,11 +1481,14 @@ def parse_args():
parser.add_argument("-hk", type=int, default=0)
parser.add_argument("-sq", type=int, default=0)
parser.add_argument("-sk", type=int, default=0)
parser.add_argument("-equal_seqlens", action='store_true', default=False,
help='If specified, each context within the thd layout' \
' has same seqlen as sq and sk')
parser.add_argument("-d", type=int, default=0)
parser.add_argument("-causal", action='store_true', default=False)
parser.add_argument("-varlen", action='store_true', default=False)
parser.add_argument("-dtype", default='fp16')
parser.add_argument("-return_time", action='store_true', default=False)
parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts())
return parser.parse_args()
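
For reference, the -layout flag added here can be exercised with invocations like the following (illustrative commands, assuming the file path shown in the diff header and the defaults above):

# Dense 'bshd' layout with an explicit custom shape:
#   python python/perf-kernels/flash-attention.py -b 2 -hq 16 -sq 1024 -d 128 -layout bshd
# Packed 'thd' (varlen) layout over the built-in configs, equal seqlens per context:
#   python python/perf-kernels/flash-attention.py -layout thd -equal_seqlens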


@@ -1511,6 +1498,8 @@
def main():
args = parse_args()
custom_config = False
assert args.layout == 'thd' or not args.equal_seqlens, \
"Equal sequence lengths arg must be used with the thd layout."
if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
custom_config = True
assert args.b and args.hq and args.sq and args.d, \
@@ -1521,7 +1510,7 @@ def main():
assert args.dtype in arg_to_torch_dtype, \
"Only fp16, bf16 and f32 types currently supported."

run_benchmark(custom_config)
run_benchmark(custom_config, args)


if __name__ == '__main__':