Skip to content

Commit

Permalink
fix falcon
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Nov 6, 2024
1 parent f70de82 commit ad18fc8
Showing 1 changed file with 105 additions and 3 deletions.
108 changes: 105 additions & 3 deletions optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,14 @@ def patch_model_with_bettertransformer(model):
return model


def patch_update_causal_mask(model, transformers_version, inner_model_name="model"):
def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
if is_transformers_version(">=", transformers_version):
inner_model = getattr(model, inner_model_name, None)
if inner_model is not None:
if hasattr(inner_model, "_update_causal_mask"):
inner_model._orig_update_causal_mask = inner_model._update_causal_mask
inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
patch_fn = patch_fn or _llama_gemma_update_causal_mask
inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)


def unpatch_update_causal_mask(model, inner_model_name="model"):
Expand Down Expand Up @@ -2431,14 +2432,115 @@ def __enter__(self):
_reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)


def _falcon_update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: "Cache",
output_attentions: bool,
head_mask: torch.Tensor,
alibi: torch.Tensor,
):
# copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
from transformers.cache_utils import StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter

# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
_prepare_4d_causal_attention_mask_with_cache_position = (
self._prepare_4d_causal_attention_mask_with_cache_position
)
else:
from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position

if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None

# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if (
self.config._attn_implementation == "sdpa"
and not using_static_cache
and not output_attentions
and head_mask is None
and alibi is None
):
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None

dtype, device = input_tensor.dtype, input_tensor.device
# difference from original, replace torch.finfo(dtype).min to float16 for prevent overflow for fp16/bf16 execution
min_dtype = torch.finfo(torch.float16).min
batch_size, sequence_length, _ = input_tensor.shape
if using_static_cache:
target_length = past_key_values.get_max_length()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length
)

# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
min_dtype=min_dtype,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)

# We take care to integrate alibi bias in the causal_mask here
if head_mask is None and alibi is not None:
alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
causal_mask = torch.masked_fill(
alibi / math.sqrt(self.config.hidden_size // self.num_heads),
causal_mask < -1,
min_dtype,
)

if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

return causal_mask


class FalconModelPatcher(DecoderModelPatcher):
def __enter__(self):
super().__enter__()
if is_transformers_version("<", "4.44.99"):
for layer in self._model.transformer.h:
_reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
else:
patch_update_causal_mask(self._model, "4.45.0", "transformer")
patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)

def __exit__(self, exc_type, exc_value, traceback):
super().__exit__(exc_type, exc_value, traceback)
Expand Down

0 comments on commit ad18fc8

Please sign in to comment.