Support transformers 4.47 (#1088)
* test 4.47

* update optimum

* patch gemma attn functions

* style

* force attn model

* latest qwen2 vl position_ids formula

* latest qwen2 vl position_ids formula

* revert
IlyasMoutawwakil authored Jan 6, 2025
1 parent 753f84d commit bb1c68a
Showing 4 changed files with 26 additions and 23 deletions.
2 changes: 1 addition & 1 deletion optimum/exporters/openvino/__main__.py
@@ -49,7 +49,7 @@
 )
 
 
-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}
 
 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
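
The new `"gemma2": "sdpa"` entry pins the attention implementation that transformers selects when the Gemma2 model is loaded for export. The exporter's actual wiring is not shown in this hunk, so the following is only a minimal sketch, assuming the mapping is keyed by `config.model_type` and consumed through the standard `attn_implementation` argument of `from_pretrained`:

```python
# Illustrative sketch only -- the model id and loading call are assumptions,
# not the exporter's real code path.
from transformers import AutoConfig, AutoModelForCausalLM

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}

model_id = "google/gemma-2-2b-it"  # hypothetical checkpoint
config = AutoConfig.from_pretrained(model_id)

load_kwargs = {}
if config.model_type in FORCE_ATTN_MODEL_CLASSES:
    # transformers accepts "eager", "sdpa" or "flash_attention_2" here
    load_kwargs["attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[config.model_type]

model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
```
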
20 changes: 0 additions & 20 deletions optimum/exporters/openvino/model_patcher.py
@@ -2827,26 +2827,6 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,
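
The deleted `__enter__`/`__exit__` pair monkey-patched every Gemma2 eager attention layer to use the SDPA forward during export and restored the original bound methods afterwards. With this commit the same effect is obtained up front by forcing `"gemma2": "sdpa"` at load time (see the `__main__.py` hunk above), so the per-layer swap is no longer needed. For reference, a self-contained sketch of that save/rebind/restore pattern, with illustrative names rather than the patcher's real ones:

```python
import types
from contextlib import AbstractContextManager


class SwapAttnForward(AbstractContextManager):
    """Temporarily rebind `forward` on every decoder layer's attention module."""

    def __init__(self, model, new_forward):
        self.model = model
        self.new_forward = new_forward  # plain function taking (self, ...)

    def __enter__(self):
        for layer in self.model.model.layers:
            # keep the original bound method so it can be restored on exit
            layer.self_attn._orig_forward = layer.self_attn.forward
            layer.self_attn.forward = types.MethodType(self.new_forward, layer.self_attn)
        return self.model

    def __exit__(self, exc_type, exc_value, traceback):
        for layer in self.model.model.layers:
            if hasattr(layer.self_attn, "_orig_forward"):
                layer.self_attn.forward = layer.self_attn._orig_forward
        return False
```
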
25 changes: 24 additions & 1 deletion optimum/intel/openvino/modeling_visual_language.py
@@ -53,7 +53,7 @@
 
 
 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image
 
 
 logger = logging.getLogger(__name__)
@@ -2100,6 +2100,8 @@ def __init__(
             quantization_config=quantization_config,
             **kwargs,
         )
+        self.rope_deltas = None  # cache rope_deltas here
+
         if is_transformers_version(">=", "4.45.0"):
             from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                 Qwen2VLForConditionalGeneration,
@@ -2197,6 +2199,7 @@ def get_multimodal_embeddings(
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        cache_position=None,
         **kwargs,
     ):
         inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
@@ -2209,6 +2212,26 @@
             video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
             video_mask = input_ids == self.config.video_token_id
             inputs_embeds[video_mask] = video_embeds
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
         return inputs_embeds, attention_mask, position_ids
 
     def forward(
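
The added block mirrors the latest upstream Qwen2-VL position_ids formula: `get_rope_index` is only called during pre-fill (`cache_position[0] == 0`), and the resulting `rope_deltas` is cached on the model; on later decode steps the 3-axis M-RoPE position ids are rebuilt arithmetically from `cache_position` plus the cached deltas. A toy numeric sketch of that decode-step arithmetic (shapes and values are made up, and the `repeat_interleave` broadcast is skipped because batch and delta sizes already match here):

```python
import torch

batch_size, seq_length = 2, 1           # decode step: one new token per sequence
cache_position = torch.tensor([7])      # index of the token being generated
rope_deltas = torch.tensor([[3], [5]])  # cached per-sequence offsets, shape (batch, 1)

delta = cache_position[0] + rope_deltas                        # shape (2, 1)
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
position_ids = position_ids + delta                            # (2, 1) -> [[10], [12]]
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)     # (3, 2, 1): the 3 M-RoPE axes

print(position_ids.shape)  # torch.Size([3, 2, 1])
print(position_ids[0])     # tensor([[10], [12]])
```
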
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@
 INSTALL_REQUIRE = [
     "torch>=1.11",
     "optimum@git+https://github.com/huggingface/optimum.git",
-    "transformers>=4.36,<4.47",
+    "transformers>=4.36,<4.48",
     "datasets>=1.4.0",
     "sentencepiece",
     "setuptools",
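
The only setup.py change is widening the upper bound so transformers 4.47.x is accepted. A purely illustrative way to check a local environment against the new range:

```python
from importlib.metadata import version

from packaging.version import Version  # third-party "packaging" package

v = Version(version("transformers"))
assert Version("4.36") <= v < Version("4.48"), f"unsupported transformers {v}"
print(f"transformers {v} satisfies >=4.36,<4.48")
```
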
