From 35c47a2252e63dcb284a666c0bf9ef497d05c8c5 Mon Sep 17 00:00:00 2001
From: IlyasMoutawwakil
Date: Mon, 6 Jan 2025 09:32:57 +0100
Subject: [PATCH] force attn model

---
 optimum/exporters/openvino/__main__.py      |  2 +-
 optimum/exporters/openvino/model_patcher.py | 46 +++------------------
 2 files changed, 7 insertions(+), 41 deletions(-)

diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py
index 592cd85a4..859360e8b 100644
--- a/optimum/exporters/openvino/__main__.py
+++ b/optimum/exporters/openvino/__main__.py
@@ -49,7 +49,7 @@
 )
 
 
-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}
 
 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py
index a43b48ae3..b19525810 100644
--- a/optimum/exporters/openvino/model_patcher.py
+++ b/optimum/exporters/openvino/model_patcher.py
@@ -421,9 +421,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c
                     offset = 0
                 mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
+                causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                    mask_slice
+                )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2058,9 +2058,9 @@ def _dbrx_update_causal_mask_legacy(
                     offset = 0
                 mask_shape = attention_mask.shape
                 mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
-                causal_mask[
-                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
-                ] = mask_slice
+                causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = (
+                    mask_slice
+                )
 
     if (
         self.config._attn_implementation == "sdpa"
@@ -2710,40 +2710,6 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["original_eager"] = GEMMA2_ATTENTION_FUNCTION["eager"]
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["sdpa"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-
-        if is_transformers_version(">=", "4.47.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_FUNCTION
-
-            GEMMA2_ATTENTION_FUNCTION["eager"] = GEMMA2_ATTENTION_FUNCTION["original_eager"]
-            del GEMMA2_ATTENTION_FUNCTION["original_eager"]
-
-        elif is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,
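
For context on the change above: the new "gemma2": "sdpa" entry lets the exporter force the SDPA attention implementation when the model is loaded, which is why the Gemma2 patcher no longer needs the __enter__/__exit__ hooks that rewired eager attention to SDPA after the fact. Below is a minimal sketch of how such a mapping might be consulted at load time; the helper name load_model_for_export and the surrounding loading flow are assumptions for illustration, not the actual __main__.py code.

from transformers import AutoConfig, AutoModelForCausalLM

# Same mapping as in the patch: model_type -> attention implementation to force.
FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}


def load_model_for_export(model_id: str, **loading_kwargs):
    # Hypothetical helper: if the model type has a forced implementation,
    # pass it to from_pretrained so every attention layer is built with it
    # up front, instead of monkey-patching attention classes after loading.
    config = AutoConfig.from_pretrained(model_id)
    forced_attn = FORCE_ATTN_MODEL_CLASSES.get(config.model_type)
    if forced_attn is not None:
        loading_kwargs["attn_implementation"] = forced_attn
    return AutoModelForCausalLM.from_pretrained(model_id, **loading_kwargs)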