Commit 00e6bf3: rebase

Signed-off-by: jiqing-feng <[email protected]>
jiqing-feng committed Jan 15, 2025
2 parents 06798e2 + c6d2d0f
Showing 4 changed files with 18 additions and 7 deletions.
6 changes: 5 additions & 1 deletion optimum/exporters/ipex/cache_utils.py
@@ -5,6 +5,10 @@
 from transformers import Cache, PretrainedConfig


+# May need tuning based on sequence length and model, but defaults to 16 currently.
+BLOCK_SIZE = 16
+
+
 class IPEXPagedCache(Cache):
     """
     A PagedCache that grows dynamically as more tokens are generated, in block-size increments; vendors can set the PagedCache memory layout.
@@ -44,7 +48,7 @@ def __init__(
         self.batch_size = batch_size
         # Used in `generate` to keep tally of how many tokens the cache has seen
         self._seen_tokens = torch.zeros([batch_size], dtype=torch.int32, device=device)
-        self.block_size = 64
+        self.block_size = BLOCK_SIZE
         self.num_blocks = (max_cache_len // self.block_size + (max_cache_len % self.block_size != 0)) * batch_size
         self.block_tables = -1 * torch.ones([self.num_blocks], dtype=torch.int32, device=device).reshape(
             batch_size, -1
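As an aside, a minimal sketch (not part of the commit; block_table_shape is a hypothetical helper) of how the block-table geometry above plays out with the new BLOCK_SIZE of 16:

import torch

BLOCK_SIZE = 16  # the commit's new default (was 64)

def block_table_shape(max_cache_len: int, batch_size: int):
    # Mirrors the arithmetic in IPEXPagedCache.__init__: ceil-divide the
    # cache length into blocks, then reserve that many blocks per sequence.
    num_blocks = (max_cache_len // BLOCK_SIZE + (max_cache_len % BLOCK_SIZE != 0)) * batch_size
    block_tables = -1 * torch.ones([num_blocks], dtype=torch.int32).reshape(batch_size, -1)
    return tuple(block_tables.shape)

print(block_table_shape(100, 4))  # (4, 7): ceil(100 / 16) = 7 blocks per sequence

The smaller block size gives finer-grained allocation: with the old value of 64, a 100-token cache would round up to 128 slots per sequence instead of 112.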
2 changes: 2 additions & 0 deletions optimum/exporters/ipex/model_patcher.py
@@ -15,6 +15,7 @@
 from transformers.models.bert.modeling_bert import BertIntermediate
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel
+from transformers.models.gpt2.modeling_gpt2 import GPT2MLP, GPT2Attention, GPT2Block, GPT2Model
 from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer,
     LlamaModel,
@@ -33,6 +34,7 @@
 from .modeling_utils import (
     _IPEX_MINIMUM_VERSION_FOR_PATCHING,
+    _IPEXGPT2MLP,
     _falcon_model_forward,
     _gpt2_block_forward,
     _gpt2_model_forward,
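These imports add GPT2's modules to the IPEX patching machinery; the registration logic itself sits outside this hunk. For orientation, a generic sketch of the module-swap pattern such patchers typically rely on (replace_modules and build_replacement are hypothetical names, not optimum API):

import torch

def replace_modules(model: torch.nn.Module, src_cls: type, build_replacement):
    # Walk the module tree and swap each instance of src_cls (e.g. GPT2MLP)
    # for an optimized replacement (e.g. _IPEXGPT2MLP) built from it.
    for name, child in model.named_children():
        if isinstance(child, src_cls):
            setattr(model, name, build_replacement(child))
        else:
            replace_modules(child, src_cls, build_replacement)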
8 changes: 3 additions & 5 deletions optimum/exporters/ipex/modeling_utils.py
@@ -769,19 +769,17 @@ def attention_interface(
                 is_causal=True,
             )
             self.use_sdpa = True
-        elif self.has_flash_attn(query):
+        elif self.has_flash_attn(query) and past_len == 0:
             attn_output = torch.empty_like(query)
             seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
-            query_len_tensor = seq_len_tensor if past_len == 0 else torch.arange(seq_len_tensor.shape[0]).int()
-            query_max_len = input_lens.max() if past_len == 0 else 1
             PagedAttention.flash_attn_varlen_func(
                 attn_output,
                 query.contiguous() if query.device.type == "xpu" else query,
                 key_cache.contiguous() if key_cache.device.type == "xpu" else key_cache,
                 value_cache.contiguous() if value_cache.device.type == "xpu" else value_cache,
-                query_len_tensor,
                 seq_len_tensor,
-                query_max_len,
+                seq_len_tensor,
+                input_lens.max(),
                 input_lens.max(),
                 1.0 / math.sqrt(self.head_dim),
                 True,
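The gate on past_len == 0 restricts this flash-attention path to prefill, where every sequence's query length equals its key length; that is why the deleted query_len_tensor and query_max_len are no longer needed and seq_len_tensor and input_lens.max() are simply passed twice. A small sketch (example values assumed) of what those tensors hold:

import torch

# Example batch of three prompts with lengths 3, 5 and 2.
input_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Same construction as in the diff: exclusive cumulative offsets into the
# packed variable-length query/key buffers.
seq_len_tensor = torch.cat((input_lens.new_tensor([0]), input_lens.cumsum(-1).int()))
print(seq_len_tensor)    # tensor([ 0,  3,  8, 10], dtype=torch.int32)
print(input_lens.max())  # tensor(5, dtype=torch.int32), used for both query and key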
9 changes: 8 additions & 1 deletion optimum/exporters/openvino/convert.py
@@ -45,6 +45,7 @@
     compare_versions,
     is_diffusers_version,
     is_openvino_tokenizers_version,
+    is_openvino_version,
     is_tokenizers_version,
     is_transformers_version,
 )
@@ -366,6 +367,7 @@ def export_pytorch(
     import torch
     from torch.utils._pytree import tree_map

+    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
     from optimum.exporters.utils import check_dummy_inputs_are_allowed

     logger.info(f"Using framework PyTorch: {torch.__version__}")
@@ -428,15 +430,20 @@ def ts_patched_forward(*args, **kwargs):

             patcher.patched_forward = ts_patched_forward

+        ts_decoder_kwargs = {}
+        if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
+            ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}
+
         with patcher:
             if patch_16bit_model:
                 from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

                 __make_16bit_traceable(model)
             check_dummy_inputs_are_allowed(model, dummy_inputs)
             input_info = _get_input_info(model, config, dummy_inputs)
+            ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
             ov_model = convert_model(
-                model,
+                ts_decoder,
                 example_input=dummy_inputs,
                 input=[(item.shape, item.type) for item in input_info],
             )
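The export path now constructs the TorchScriptPythonDecoder explicitly so trace options can be forwarded: for diffusers models on OpenVINO 2025.0 or newer, TorchScript's re-trace consistency check is disabled. A standalone sketch of that gate, assuming the packaging library in place of optimum's is_openvino_version helper (ts_decoder_kwargs_for is a hypothetical name):

from packaging.version import Version

def ts_decoder_kwargs_for(library_name: str, openvino_version: str) -> dict:
    # Mirrors the gate added above: only diffusers exports on recent
    # OpenVINO releases skip TorchScript's check_trace verification.
    kwargs = {}
    if library_name == "diffusers" and Version(openvino_version) >= Version("2025.0"):
        kwargs["trace_kwargs"] = {"check_trace": False}
    return kwargs

print(ts_decoder_kwargs_for("diffusers", "2025.0"))     # {'trace_kwargs': {'check_trace': False}}
print(ts_decoder_kwargs_for("transformers", "2025.0"))  # {}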
