prefill attn
Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Jan 14, 2025
1 parent 4dd2e44 commit 372d3f8
Showing 1 changed file with 28 additions and 24 deletions.
52 changes: 28 additions & 24 deletions optimum/exporters/ipex/modeling_utils.py
@@ -634,16 +634,31 @@ def has_flash_attn(self, query):
         elif query.device.type == "xpu":
             return is_torch_version(">", "2.5.99")

-    def varlen_attn(self, query, key, value, past_key_value, input_lens):
-        # prefill, remove padding
-        attn_output = torch.empty_like(query)
-        seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-        if self.has_flash_attn(query):
+    def prefill_attn(self, query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens):
+        if past_key_value is None:
+            n_rep = query.shape[1] // key.shape[1]
+            attn_output = torch.nn.functional.scaled_dot_product_attention(
+                query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
+                key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
+                .transpose(1, 2)
+                .repeat_interleave(n_rep, 1),
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=True,
+            )
+            self.use_sdpa = True
+        elif self.has_flash_attn(query):
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query,
-                key,
-                value,
+                key_cache,
+                value_cache,
                 seq_len_tensor,
                 seq_len_tensor,
                 input_lens.max(),
@@ -654,6 +669,9 @@ def varlen_attn(self, query, key, value, past_key_value, input_lens):
                 None,
             )
         else:
+            # prefill, remove padding
+            attn_output = torch.empty_like(query)
+            seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
             varlen_attention(
                 query.contiguous() if query.device.type == "xpu" else query,
                 key.contiguous() if key.device.type == "xpu" else key,
@@ -697,23 +715,9 @@ def forward(

         if past_len == 0:
             # prefill
-            if past_key_value is None:
-                n_rep = query.shape[1] // key.shape[1]
-                attn_output = torch.nn.functional.scaled_dot_product_attention(
-                    query.reshape(input_lens.shape[0], input_lens.max().item(), -1, query.shape[-1]).transpose(1, 2),
-                    key.reshape(input_lens.shape[0], input_lens.max().item(), -1, key.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    value.reshape(input_lens.shape[0], input_lens.max().item(), -1, value.shape[-1])
-                    .transpose(1, 2)
-                    .repeat_interleave(n_rep, 1),
-                    attn_mask=attention_mask,
-                    dropout_p=0.0,
-                    is_causal=True,
-                )
-                self.use_sdpa = True
-            else:
-                attn_output = self.varlen_attn(query, key_cache, value_cache, past_key_value, input_lens)
+            attn_output = self.prefill_attn(
+                query, key_cache, value_cache, key, value, past_key_value, attention_mask, input_lens
+            )
         else:
             # decode
             attn_output = torch.empty_like(query)
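For readers tracing the new prefill_attn SDPA branch: when there is no KV cache yet, the packed prompt tokens are reshaped back into a padded [batch, seq, heads, head_dim] layout and the key/value heads are repeated so grouped-query models match the query head count before calling scaled_dot_product_attention. A minimal, self-contained sketch of that pattern follows; the shapes and head counts are assumptions for illustration, not the library code, and it avoids passing attn_mask and is_causal together.

import torch

def sdpa_prefill_sketch(query, key, value, attention_mask=None):
    # query: [batch, seq, num_q_heads, head_dim]; key/value: [batch, seq, num_kv_heads, head_dim]
    n_rep = query.shape[2] // key.shape[2]  # query heads per KV head (GQA group size)
    q = query.transpose(1, 2)                                # [batch, q_heads, seq, dim]
    k = key.transpose(1, 2).repeat_interleave(n_rep, dim=1)  # expand KV heads to match q_heads
    v = value.transpose(1, 2).repeat_interleave(n_rep, dim=1)
    out = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=attention_mask, dropout_p=0.0, is_causal=attention_mask is None
    )
    return out.transpose(1, 2)  # back to [batch, seq, q_heads, dim]

# Example: 2 padded prompts of length 16, 8 query heads sharing 2 KV heads.
q = torch.randn(2, 16, 8, 64)
k = torch.randn(2, 16, 2, 64)
v = torch.randn(2, 16, 2, 64)
print(sdpa_prefill_sketch(q, k, v).shape)  # torch.Size([2, 16, 8, 64])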

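The flash/varlen branches instead keep the prompts packed with padding removed, and seq_len_tensor records the cumulative sequence boundaries the kernel uses to find where each prompt starts and ends. A small illustration of that bookkeeping with made-up lengths (not the kernel call itself):

import torch

input_lens = torch.tensor([5, 3, 7], dtype=torch.int32)  # per-prompt token counts (example values)
# Same construction as in the diff: offsets of each prompt in the packed batch.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)  # tensor([ 0,  5,  8, 15], dtype=torch.int32)

# Query/key/value are packed as [total_tokens, heads, head_dim];
# prompt i occupies rows seq_len_tensor[i] : seq_len_tensor[i + 1].
total_tokens = int(input_lens.sum())  # 15
packed_query = torch.randn(total_tokens, 8, 64)
for i in range(input_lens.shape[0]):
    start, end = int(seq_len_tensor[i]), int(seq_len_tensor[i + 1])
    print(f"prompt {i}: rows {start}..{end - 1} of the packed tensor")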