From c6c4a2558a3fa957b299e1b30414001b60bbbeb8 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 6 Jan 2025 10:45:32 +0100 Subject: [PATCH] latest qwen2 vl position_ids formula --- optimum/exporters/openvino/model_patcher.py | 12 ++++---- .../openvino/modeling_visual_language.py | 28 ++++++++++++++++--- 2 files changed, 30 insertions(+), 10 deletions(-) diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index b19525810..7a6b2998c 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" @@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy( offset = 0 mask_shape = attention_mask.shape mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( - mask_slice - ) + causal_mask[ + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] + ] = mask_slice if ( self.config._attn_implementation == "sdpa" diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index fe85f9212..d0b281e19 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -53,7 +53,7 @@ if TYPE_CHECKING: - from PIL import Image + from PIL.Image import Image logger = logging.getLogger(__name__) @@ -166,9 +166,6 @@ def prepare_inputs( if past_len: position_ids = position_ids[:, -inputs_embeds.shape[1] :] - if self.config.model_type == "qwen2_vl" and position_ids.ndim != 3: - position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0) - inputs["position_ids"] = position_ids if "beam_idx" in self.input_names: @@ -2100,6 +2097,8 @@ def __init__( quantization_config=quantization_config, **kwargs, ) + self.rope_deltas = None # cache rope_deltas here + if is_transformers_version(">=", "4.45.0"): from transformers.models.qwen2_vl.modeling_qwen2_vl import ( Qwen2VLForConditionalGeneration, @@ -2197,6 +2196,7 @@ def get_multimodal_embeddings( pixel_values_videos=None, image_grid_thw=None, video_grid_thw=None, + cache_position=None, **kwargs, ): inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids)) @@ -2209,6 +2209,26 @@ def get_multimodal_embeddings( video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw)) video_mask = input_ids == self.config.video_token_id inputs_embeds[video_mask] = video_embeds + + # if we get 4D attention mask we cannot calculate rope deltas anymore. + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + return inputs_embeds, attention_mask, position_ids def forward(