Use enable_gqa in place of repeat_kv
ghstack-source-id: e8781d3e797737c073bc487197ce804bce15502c
Pull Request resolved: #641
awgu committed Oct 22, 2024
1 parent 0edd2fb commit 92ec349
Showing 2 changed files with 10 additions and 38 deletions.
24 changes: 5 additions & 19 deletions torchtitan/models/llama/model.py
@@ -114,18 +114,6 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
-def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
-    bs, slen, n_kv_heads, head_dim = x.shape
-    if n_rep == 1:
-        return x
-    return (
-        torch.unsqueeze(x, dim=3)
-        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
-        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
-    )
-
-
 class Attention(nn.Module):
     """
     Multi-head attention module.
@@ -198,16 +186,14 @@ def forward(

         xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use causal mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=True, enable_gqa=self.n_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
24 changes: 5 additions & 19 deletions torchtitan/models/llama_multimodal/model.py
@@ -130,18 +130,6 @@ def apply_rotary_emb(
     return xq_out.type_as(xq), xk_out.type_as(xk)
 
 
-def repeat_kv(x: torch.Tensor, num_rep: int) -> torch.Tensor:
-    """torch.repeat_interleave(x, dim=2, repeats=num_rep)"""
-    bsz, seq_len, num_kv_heads, head_dim = x.shape
-    if num_rep == 1:
-        return x
-    return (
-        torch.unsqueeze(x, dim=3)
-        .expand(bsz, seq_len, num_kv_heads, num_rep, head_dim)
-        .reshape(bsz, seq_len, num_kv_heads * num_rep, head_dim)
-    )
-
-
 class Attention(nn.Module):
     """
     Multi-head attention module.
@@ -222,16 +210,14 @@ def forward(
         ):  # Only used in the self attention layers for text decoder
             xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
 
-        # repeat k/v heads if num_kv_heads < n_heads
-        keys = repeat_kv(xk, self.num_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(xv, self.num_rep)  # (bs, seqlen, n_local_heads, head_dim)
-
         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        xk = xk.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
+        xv = xv.transpose(1, 2)  # (bs, n_local_kv_heads, seqlen, head_dim)
 
         # we use causal mask for training
-        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=self.is_causal)
+        output = F.scaled_dot_product_attention(
+            xq, xk, xv, is_causal=self.is_causal, enable_gqa=self.num_rep > 1
+        )
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
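Not part of the diff above: a minimal equivalence sketch for the change, assuming a PyTorch version whose F.scaled_dot_product_attention accepts the enable_gqa flag (2.5 or newer) and using arbitrary toy shapes. With enable_gqa=True, SDPA broadcasts the smaller set of key/value heads across the query heads internally, which is what the removed repeat_kv helper did by materializing n_rep copies.

import torch
import torch.nn.functional as F

# Toy sizes; enable_gqa requires n_heads to be divisible by n_kv_heads.
bs, seqlen, n_heads, n_kv_heads, head_dim = 2, 16, 8, 2, 64
n_rep = n_heads // n_kv_heads

xq = torch.randn(bs, n_heads, seqlen, head_dim)     # queries, already in (bs, heads, seqlen, head_dim) layout
xk = torch.randn(bs, n_kv_heads, seqlen, head_dim)  # keys with fewer heads (GQA)
xv = torch.randn(bs, n_kv_heads, seqlen, head_dim)  # values with fewer heads (GQA)

# Old path: repeat each k/v head n_rep times before calling SDPA
# (same effect as the removed repeat_kv, just applied after the transpose instead of before).
keys = torch.repeat_interleave(xk, n_rep, dim=1)
values = torch.repeat_interleave(xv, n_rep, dim=1)
out_repeat = F.scaled_dot_product_attention(xq, keys, values, is_causal=True)

# New path: hand SDPA the un-repeated k/v and let enable_gqa handle the broadcast.
out_gqa = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True, enable_gqa=True)

# The two paths should agree up to backend-level numerical noise.
torch.testing.assert_close(out_repeat, out_gqa)

Besides being shorter, the enable_gqa path can skip materializing the repeated key/value tensors, which is the practical benefit of dropping repeat_kv.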
