From 74ee7eb4671cd5ab7e2af0b1c897873047695422 Mon Sep 17 00:00:00 2001 From: Ekaterina Aidova Date: Wed, 15 Jan 2025 13:07:38 +0400 Subject: [PATCH] Merge decoder and decoder with past to stateful for seq2seq (#1078) * merge decoder and decoder with past to stateful for seq2seq * fix quantization * fix loading decoder_with_past * fix quant tests * fix tests * fix more tests * make input dynamic and enable sdpa * review comments and kv cache compression disable in fp * fix task recognition * fix quantization tests * respect from_onnx * update test to check that stateful expected * Apply suggestions from code review Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --------- Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> --- optimum/exporters/openvino/__main__.py | 6 +- optimum/exporters/openvino/convert.py | 64 ++++++- optimum/exporters/openvino/model_configs.py | 70 +++++++ optimum/exporters/openvino/model_patcher.py | 53 +++++- optimum/exporters/openvino/stateful.py | 95 ++++++++- .../intel/openvino/modeling_base_seq2seq.py | 69 +++++-- optimum/intel/openvino/modeling_seq2seq.py | 180 ++++++++++++++---- optimum/intel/openvino/quantization.py | 44 +++-- tests/openvino/test_exporters_cli.py | 8 +- tests/openvino/test_modeling.py | 27 ++- tests/openvino/test_quantization.py | 27 ++- 11 files changed, 552 insertions(+), 91 deletions(-) diff --git a/optimum/exporters/openvino/__main__.py b/optimum/exporters/openvino/__main__.py index a5cfb02615..110ed48515 100644 --- a/optimum/exporters/openvino/__main__.py +++ b/optimum/exporters/openvino/__main__.py @@ -274,7 +274,11 @@ def main_export( f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. 
Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}." ) - if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED: + if ( + is_transformers_version(">=", "4.36") + and is_transformers_version("<=", "4.45.0") + and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED + ): loading_kwargs["attn_implementation"] = "eager" # some models force flash_attn attention by default that does not support load model on cpu diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index 57e94fc0b5..3ab2bbd550 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -28,10 +28,13 @@ from openvino.tools.ovc import convert_model from optimum.exporters import TasksManager from optimum.exporters.utils import ( - _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, + DECODER_NAME, + ENCODER_NAME, + _get_submodels_for_export_encoder_decoder, + get_diffusion_models_for_export, ) from optimum.exporters.utils import ( - get_diffusion_models_for_export, + _get_submodels_and_export_configs as _default_get_submodels_and_export_configs, ) from optimum.intel.utils.import_utils import ( _diffusers_version, @@ -106,10 +109,13 @@ def _set_runtime_options( "diffusers" in library_name or "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model") + or getattr(sub_export_config, "stateful", False) ): sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0" if not quantized_model and ( - "text-generation" in task or ("image-text-to-text" in task and model_name == "language_model") + "text-generation" in task + or ("image-text-to-text" in task and model_name == "language_model") + or getattr(sub_export_config, "stateful", False) ): sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16" @@ -642,10 +648,14 @@ 
def export_from_model( logger.info(f"Automatic task detection to: {task}.") + is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False) + model_type = getattr(getattr(model, "config", {}), "model_type", "") stateful = stateful and ( - ensure_export_task_support_stateful(task) - or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", "")) + ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type) ) + + if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False): + stateful = False # TODO: support onnx_config.py in the model repo if custom_architecture and custom_export_configs is None: raise ValueError( @@ -687,6 +697,11 @@ def export_from_model( if library_name == "diffusers": export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino") stateful_submodels = False + elif stateful and is_encoder_decoder and not custom_architecture: + export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export( + model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default" + ) + stateful_submodels = [False, True] else: logging.disable(logging.INFO) export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs( @@ -1221,3 +1236,42 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype): models_for_export["text_encoder_2"] = (text_encoder_2, export_config) return models_for_export + + +def _get_encoder_decoder_stateful_models_for_export( + model: Union["PreTrainedModel", "TFPreTrainedModel"], + task: str, + _variant: str, + library_name: str, + int_dtype: str = "int64", + float_dtype: str = "fp32", + preprocessors: Optional[List[Any]] = None, +): + export_config_constructor = TasksManager.get_exporter_config_constructor( + model=model, exporter="openvino", task=task, 
library_name=library_name + ) + export_config = export_config_constructor( + model.config, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + legacy=False, + ) + + export_config.variant = _variant + all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()]) + logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}") + + models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False) + + encoder_export_config = export_config.with_behavior("encoder") + models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config) + + decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True) + + decoder_export_config_with_past.stateful = True + models_for_export[DECODER_NAME] = ( + models_for_export[DECODER_NAME], + decoder_export_config_with_past, + ) + return None, models_for_export diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index 15a712cff6..4b1dbb50b8 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -20,6 +20,7 @@ from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel from transformers.utils import is_tf_available +from optimum.exporters.onnx.base import ConfigBehavior from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig from optimum.exporters.onnx.model_configs import ( CLIPOnnxConfig, @@ -38,8 +39,10 @@ MistralOnnxConfig, MPTOnnxConfig, PhiOnnxConfig, + T5OnnxConfig, UNetOnnxConfig, VisionOnnxConfig, + WhisperOnnxConfig, ) from optimum.exporters.onnx.model_patcher import ModelPatcher from optimum.exporters.tasks import TasksManager @@ -102,6 +105,7 @@ Qwen2VLVisionEmbMergerPatcher, QwenModelPatcher, RotaryEmbPatcher, + 
StatefulSeq2SeqDecoderPatcher, UpdateCausalMaskModelPatcher, XverseModelPatcher, ) @@ -2611,3 +2615,69 @@ def patch_model_for_export( self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None ) -> "ModelPatcher": return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs) + + +@register_in_tasks_manager( + "whisper", + *[ + "feature-extraction", + "feature-extraction-with-past", + "audio-classification", + "automatic-speech-recognition", + "automatic-speech-recognition-with-past", + ], + library_name="transformers", +) +class WhisperOpenVINOConfig(WhisperOnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> ModelPatcher: + if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER: + return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + @property + def inputs(self): + common_inputs = super().inputs + if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER: + common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} + return common_inputs + + +@register_in_tasks_manager( + "t5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class T5OpenVINOConfig(T5OnnxConfig): + def patch_model_for_export( + self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None + ) -> ModelPatcher: + if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER: + return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + @property + def inputs(self): + common_inputs = super().inputs + if getattr(self, "stateful", False) and self._behavior == 
ConfigBehavior.DECODER: + common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"} + return common_inputs + + +@register_in_tasks_manager( + "mt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class MT5OpenVINOConfig(T5OpenVINOConfig): + pass + + +@register_in_tasks_manager( + "longt5", + *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"], + library_name="transformers", +) +class LongT5OpenVINOConfig(T5OpenVINOConfig): + pass diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index c12c0099ae..e7a7779389 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -24,7 +24,12 @@ from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling from transformers.utils import is_tf_available -from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments +from optimum.exporters.onnx.model_patcher import ( + DecoderModelPatcher, + ModelPatcher, + Seq2SeqModelPatcher, + override_arguments, +) from optimum.intel.utils.import_utils import ( _openvino_version, _torch_version, @@ -3740,3 +3745,49 @@ def __exit__(self, exc_type, exc_value, traceback): if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa": for layer in self._model.transformer.h: layer.attn._attn = layer.attn._orig_attn + + +class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: Union["PreTrainedModel", "TFPreTrainedModel"], + model_kwargs: Optional[Dict[str, Any]] = None, + ): + model.__orig_forward = model.forward + + @functools.wraps(model.__orig_forward) + def patched_forward(*args, **kwargs): + from transformers.cache_utils import EncoderDecoderCache + + signature 
= inspect.signature(self.orig_forward) + args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs) + + return_legacy_cache = False + pkv_in_args = False + legacy_pkv = None + if "past_key_values" in kwargs: + legacy_pkv = kwargs.pop("past_key_values", None) + sign_names = list(signature.parameters.keys()) + pkv_argument_index = sign_names.index("past_key_values") + if legacy_pkv is None and len(args) > pkv_argument_index: + legacy_pkv = args[pkv_argument_index] + pkv_in_args = True + if legacy_pkv is not None: + only_self_cache = [cache_item[:2] for cache_item in legacy_pkv] + pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache) + return_legacy_cache = True + if not pkv_in_args: + kwargs["past_key_values"] = pkv + else: + args[pkv_argument_index] = pkv + + outputs = model.__orig_forward(*args, **kwargs) + if return_legacy_cache: + outputs.past_key_values = outputs.past_key_values.to_legacy_cache() + + return outputs + + model.forward = patched_forward + + super().__init__(config, model, model_kwargs) diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 39d64c2aec..a367ea8f00 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -190,18 +190,95 @@ def ensure_stateful_is_available(warn=True): return True +_ENCODER_DECODER_TASKS_WITH_PAST = ( + "automatic-speech-recognition", + "text2text-generation", +) + +_DECODER_TASKS_WITH_PAST = ("text-generation",) + + def ensure_export_task_support_stateful(task: str): from optimum.exporters import TasksManager task = TasksManager.map_from_synonym(task) - return task in ["text-generation-with-past"] + + is_stateful = ( + task.endswith("-with-past") + and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST + ) + return is_stateful def ensure_model_type_support_stateful(model_type: str): return model_type.replace("_", "-") in 
MULTI_MODAL_TEXT_GENERATION_MODELS -def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"): +def remove_parameters_by_names(model: ov.Model, names: list): + parameters = [model.input(name).get_node() for name in names] + for p in parameters: + model.remove_parameter(p) + + +def get_input_nodes(node): + return [input.get_node() for input in node.input_values()] + + +def find_dependent_nodes(model: ov.Model, sources: list): + # Finds all nodes in `model` that are directly or indirectly dependent on at least one node from the list of nodes in `sources`, including `sources` + result = set(sources) + for node in model.get_ordered_ops(): + input_nodes = set(get_input_nodes(node)) + if input_nodes & result: + result.add(node) + return result + + +def get_read_value_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"] + + +def get_shape_of_ops(model: ov.Model): + return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"] + + +def get_consumer_nodes(node): + consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()]) + return {input.get_node() for input in consumer_inputs} + + +def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list): + # Search for nodes in the model graph that depend on nodes in `starts` list but independent of other model Parameter's/ReadValue's + other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources) + other_nodes = find_dependent_nodes(model, other_inputs) + source_dependent_nodes = find_dependent_nodes(model, sources) + # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph + nodes = source_dependent_nodes - other_nodes + edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes] + return edge_nodes + + +def insert_state_for_nodes(model: ov.Model, nodes): + # 
For each output in a given list `nodes` of ov.Node's, insert ReadValue-Assign pair and use the node output as initialization sub-expression + outputs = sum((node.outputs() for node in nodes), []) + for output in outputs: + consumers = output.get_target_inputs() + # FIXME: get_any_name is not reliable as tensor may not have any names + variable_id = output.get_any_name() + read_value = ov.runtime.opset13.read_value(output, variable_id) + for consumer in consumers: + consumer.replace_source_output(read_value.output(0)) + assign = ov.runtime.opset13.assign(read_value, variable_id) + model.add_sinks([assign]) + + +def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): + if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"): + return patch_stateful_encoder_decoder(config, ov_model) + return patch_stateful_decoder(config, ov_model) + + +def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model): """ Apply stateful transformation to model to hide key values inputs inside model. 
Select transformation parameters based on model architecture @@ -236,3 +313,17 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name make_stateful( ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None ) + + +def patch_stateful_encoder_decoder(config, ov_model): + encoder_key_value_input_names = [ + key.get_any_name() + for key in ov_model.inputs + if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names()) + ] + remove_parameters_by_names(ov_model, encoder_key_value_input_names) + patch_stateful_decoder(config, ov_model) + insert_state_for_nodes( + ov_model, + find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]), + ) diff --git a/optimum/intel/openvino/modeling_base_seq2seq.py b/optimum/intel/openvino/modeling_base_seq2seq.py index d01c396a42..11ee8f89a7 100644 --- a/optimum/intel/openvino/modeling_base_seq2seq.py +++ b/optimum/intel/openvino/modeling_base_seq2seq.py @@ -25,6 +25,7 @@ from transformers.file_utils import add_start_docstrings from ...exporters.openvino import main_export +from ...exporters.openvino.stateful import model_has_state from ..utils.import_utils import is_transformers_version from .configuration import OVConfig, OVWeightQuantizationConfig from .modeling_base import OVBaseModel @@ -64,7 +65,7 @@ def __init__( **kwargs, ): self.config = config - self.use_cache = decoder_with_past is not None + self.use_cache = decoder_with_past is not None or model_has_state(decoder) self.model_save_dir = model_save_dir self._compile_only = kwargs.get("compile_only", False) self._device = device.upper() @@ -75,7 +76,8 @@ def __init__( if self.is_dynamic and not self._compile_only: encoder = self._reshape(encoder, -1, -1, is_decoder=False) decoder = self._reshape(decoder, -1, -1) - decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None + if decoder_with_past is 
not None: + decoder_with_past = self._reshape(decoder_with_past, -1, -1) if self.use_cache else None self.encoder_model = encoder self.decoder_model = decoder self.decoder_with_past_model = decoder_with_past @@ -115,7 +117,7 @@ def _save_pretrained(self, save_directory: Union[str, Path]): """ src_files = [self.encoder_model, self.decoder_model] dst_file_names = [OV_ENCODER_NAME, OV_DECODER_NAME] - if self.use_cache: + if self.decoder_with_past_model is not None: src_files.append(self.decoder_with_past_model) dst_file_names.append(OV_DECODER_WITH_PAST_NAME) @@ -204,7 +206,11 @@ def _from_pretrained( if not compile_only: encoder = cls.load_model(os.path.join(model_id, encoder_file_name), quantization_config) decoder = cls.load_model(os.path.join(model_id, decoder_file_name), quantization_config) - if use_cache: + if ( + use_cache + and not model_has_state(decoder) + and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)) + ): decoder_with_past = cls.load_model( os.path.join(model_id, decoder_with_past_file_name), quantization_config ) @@ -221,7 +227,11 @@ def _from_pretrained( kwargs.get("ov_config"), model_save_dir, ) - if use_cache: + if ( + use_cache + and not model_has_state(decoder) + and os.path.exists(os.path.join(model_id, decoder_with_past_file_name)) + ): decoder_with_past = cls._compile_model( os.path.join(model_id, decoder_with_past_file_name), kwargs.get("device", "CPU"), @@ -232,8 +242,6 @@ def _from_pretrained( # Load model from hub else: model_file_names = {"encoder": encoder_file_name, "decoder": decoder_file_name} - if use_cache: - model_file_names["decoder_with_past"] = decoder_with_past_file_name # If not ONNX then OpenVINO IR : adds binary files if not from_onnx: @@ -257,7 +265,24 @@ def _from_pretrained( if not compile_only: encoder = cls.load_model(file_names["encoder"], quantization_config) decoder = cls.load_model(file_names["decoder"], quantization_config) - if use_cache: + if use_cache and not model_has_state(decoder): + 
model_file_names["decoder_with_past"] = decoder_with_past_file_name + with_past_files = ["decoder_with_past"] + if not from_onnx: + with_past_files.append("decoder_with_past_bin") + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in with_past_files: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config) else: encoder = cls._compile_model( @@ -266,7 +291,24 @@ def _from_pretrained( decoder = cls._compile_model( file_names["decoder"], kwargs.get("device", "CPU"), kwargs.get("ov_config"), model_save_dir ) - if use_cache: + if use_cache and not model_has_state(decoder): + model_file_names["decoder_with_past"] = decoder_with_past_file_name + with_past_files = ["decoder_with_past"] + if not from_onnx: + with_past_files.append("decoder_with_past_bin") + model_file_names["decoder_with_past_bin"] = decoder_with_past_file_name.replace(".xml", ".bin") + for name in with_past_files: + model_cache_path = hf_hub_download( + repo_id=model_id, + filename=model_file_names[name], + token=token, + revision=revision, + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + subfolder=subfolder, + ) + file_names[name] = model_cache_path decoder_with_past = cls._compile_model( file_names["decoder_with_past"], kwargs.get("device", "CPU"), @@ -365,6 +407,7 @@ def _from_transformers( ov_config = None else: ov_config = OVConfig(dtype="fp32") + stateful = kwargs.get("stateful", True) main_export( model_name_or_path=model_id, @@ -378,6 +421,7 @@ def _from_transformers( force_download=force_download, trust_remote_code=trust_remote_code, ov_config=ov_config, + stateful=stateful, ) 
return cls._from_pretrained( @@ -400,7 +444,8 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng elif inputs.get_any_name().startswith("cache_position"): shapes[inputs][0] = sequence_length elif is_decoder and not inputs.get_any_name().startswith("encoder"): - shapes[inputs][1] = -1 + if not inputs.get_any_name().startswith("beam_idx"): + shapes[inputs][1] = -1 else: shapes[inputs][1] = sequence_length model.reshape(shapes) @@ -424,7 +469,7 @@ def reshape(self, batch_size: int, sequence_length: int): self.is_dynamic = True if batch_size == -1 and sequence_length == -1 else False self.encoder_model = self._reshape(self.encoder_model, batch_size, sequence_length, is_decoder=False) self.decoder_model = self._reshape(self.decoder_model, batch_size, sequence_length) - if self.use_cache: + if self.decoder_with_past_model is not None: self.decoder_with_past_model = self._reshape(self.decoder_with_past_model, batch_size, sequence_length) def half(self): @@ -439,7 +484,7 @@ def half(self): apply_moc_transformations(self.decoder_model, cf=False) compress_model_transformation(self.encoder_model) compress_model_transformation(self.decoder_model) - if self.use_cache: + if self.decoder_with_past_model is not None: apply_moc_transformations(self.decoder_with_past_model, cf=False) compress_model_transformation(self.decoder_with_past_model) return self diff --git a/optimum/intel/openvino/modeling_seq2seq.py b/optimum/intel/openvino/modeling_seq2seq.py index fa48430a77..61911fc6d4 100644 --- a/optimum/intel/openvino/modeling_seq2seq.py +++ b/optimum/intel/openvino/modeling_seq2seq.py @@ -35,6 +35,7 @@ from transformers.generation import GenerationMixin from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput +from ...exporters.openvino.stateful import model_has_state from .. 
import OVConfig, OVQuantizer from ..utils import is_transformers_version from .configuration import OVQuantizationConfig, OVQuantizationConfigBase @@ -331,7 +332,7 @@ def __init__( self.encoder = OVEncoder(self.encoder_model, parent_model=self) self.decoder = OVDecoder(self.decoder_model, parent_model=self) - if self.use_cache: + if self.use_cache and not model_has_state(self.decoder_model): self.decoder_with_past = OVDecoder(self.decoder_with_past_model, parent_model=self) if enable_compilation: self.compile() @@ -373,10 +374,14 @@ def forward( # Decode if past_key_values is None or self.decoder_with_past is None: decoder_outputs = self.decoder( - input_ids=decoder_input_ids, + input_ids=( + decoder_input_ids[:, -1:] if past_key_values is not None and self.use_cache else decoder_input_ids + ), + past_key_values=past_key_values, encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=attention_mask, decoder_attention_mask=decoder_attention_mask, + cache_position=cache_position, ) else: decoder_outputs = self.decoder_with_past( @@ -416,16 +421,8 @@ def prepare_inputs_for_generation( def get_encoder(self): return self.encoder - # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - @staticmethod - def _reorder_cache(past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: - reordered_past = () - for layer_past in past: - # Cached cross_attention states don't have to be reordered -> they are always the same - reordered_past += ( - tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:], - ) - return reordered_past + def _reorder_cache(self, past, beam_idx) -> Tuple[Tuple[torch.FloatTensor]]: + return self.decoder._reorder_cache(past, beam_idx) def reshape(self, batch_size: int, sequence_length: int): """ @@ -460,13 +457,13 @@ def clear_requests(self): ) self.encoder.request = None self.decoder.request = None - if self.use_cache: + if self.decoder_with_past is not None: 
self.decoder_with_past.request = None def compile(self): self.encoder._compile() self.decoder._compile() - if self.use_cache: + if self.decoder_with_past is not None: self.decoder_with_past._compile() @@ -576,7 +573,11 @@ def __init__(self, model: openvino.runtime.Model, parent_model: OVModelForSeq2Se self.output_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.outputs)} self.output_dtypes = {key.get_any_name(): key.get_element_type().get_type_name() for key in self.model.outputs} self.key_value_output_names = [key for key in self.output_names if "key_values" in key or "present" in key] + self.stateful = model_has_state(self.model) is_legacy = any("past_key_values" in key.get_any_name() for key in self.model.outputs) + self.use_past = len(self.key_value_input_names) > 0 or self.stateful + self.next_beam_idx = None + self._past_length = 0 if len(self.key_value_input_names) > 0 and not is_legacy: self.use_past = True @@ -623,7 +624,11 @@ def forward( # Model inputs inputs = {} - if past_key_values is not None: + if self.stateful and past_key_values is None: + self.request.reset_state() + self._past_length = 0 + + if past_key_values is not None and not self.stateful: # Flatten the past_key_values past_key_values = tuple( past_key_value for pkv_per_layer in past_key_values for past_key_value in pkv_per_layer @@ -645,34 +650,54 @@ def forward( if "decoder_attention_mask" in self.input_names and decoder_attention_mask is not None: inputs["decoder_attention_mask"] = decoder_attention_mask - if "cache_position" in self.input_names and cache_position is not None: + if "cache_position" in self.input_names: + if cache_position is None: + past_len = self._get_past_length(past_key_values) + cache_position = np.arange(past_len, past_len + input_ids.shape[1]) inputs["cache_position"] = cache_position + if "beam_idx" in self.input_names: + batch_size = input_ids.shape[0] + inputs["beam_idx"] = ( + self.next_beam_idx if self.next_beam_idx is not None else 
np.arange(batch_size, dtype=np.int32) + ) + # Run inference self.request.start_async(inputs, share_inputs=True) self.request.wait() logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device) - - # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the - # self-attention layer and 2 to the cross-attention layer) - out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) - - # Tuple of tuple of length `n_layers`, with each tuple of length equal to: - # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention) - # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant) - if self.use_past is False: - out_past_key_values = tuple( - out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv) - ) - else: - # grab the cross attention key/values from the inputs - out_past_key_values = tuple( - out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] - for i in range(0, len(out_past_key_values), self.num_pkv) - ) + self._past_length += input_ids.shape[1] + + out_past_key_values = () + + if not self.stateful: + # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the + # self-attention layer and 2 to the cross-attention layer) + out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names) + + # Tuple of tuple of length `n_layers`, with each tuple of length equal to: + # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention) + # * 2 for the decoder with cache (k/v of self-attention as cross-attention cache is constant) + if self.use_past is False: + out_past_key_values = tuple( + out_past_key_values[i : i + self.num_pkv] for i in range(0, len(out_past_key_values), self.num_pkv) + ) + else: + # 
grab the cross attention key/values from the inputs + out_past_key_values = tuple( + out_past_key_values[i : i + self.num_pkv] + past_key_values[2 * i + 2 : 2 * i + 2 + self.num_pkv] + for i in range(0, len(out_past_key_values), self.num_pkv) + ) return Seq2SeqLMOutput(logits=logits, past_key_values=out_past_key_values) + def _get_past_length(self, past_key_values=None): + if past_key_values is None: + return 0 + if self.stateful: + return self._past_length + return past_key_values[0][0].shape[-2] + def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) @@ -694,6 +719,26 @@ def _compile(self): if "OPENVINO_LOG_LEVEL" in os.environ and int(os.environ["OPENVINO_LOG_LEVEL"]) > 2: _print_compiled_model_properties(compiled_model) + def _reorder_cache( + self, past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: + """ + This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or + [`~PreTrainedModel.beam_sample`] is called. + This is required to match `past_key_values` with the correct beam_idx at every generation step. 
+ """ + if self.stateful: + self.next_beam_idx = np.array(beam_idx) + return past_key_values + else: + reordered_past = () + for layer_past in past_key_values: + # Cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(np.take(past_state, beam_idx, 0) for past_state in layer_past[:2]) + layer_past[2:], + ) + return reordered_past + @add_start_docstrings( """ @@ -785,7 +830,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model @@ -868,7 +915,9 @@ def _reshape(self, model: openvino.runtime.Model, batch_size: int, sequence_leng if is_decoder: if inputs.get_any_name().startswith("past_key_values"): shapes[inputs][2] = -1 - elif not inputs.get_any_name().startswith("encoder"): + elif not inputs.get_any_name().startswith("encoder") and not inputs.get_any_name().startswith( + "beam_idx" + ): shapes[inputs][1] = -1 model.reshape(shapes) return model @@ -1011,3 +1060,62 @@ def __init__(self, stride): # a dummy model attribute that's used in the generate method to compute the input stride # input_stride = self.model.encoder.conv1.stride[0] * self.model.encoder.conv2.stride[0] model = DummyWhisperModel() + + # Adapted for stateful support from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/modeling_whisper.py#L1810 + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past_key_values=None, + use_cache=None, + encoder_outputs=None, + attention_mask=None, + decoder_attention_mask=None, + cache_position=None, + **kwargs, + ): + # Overwritten -- encoder-decoder whisper has custom logic, but it's close to the general 
function. Next time + # this function needs to be touched, let's try to sort out the commonalities between the two and remove the + # overwrite. + + decoder_position_ids = None + if decoder_attention_mask is not None: + decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0) + + past_length = 0 + if past_key_values is not None: + past_length = self.decoder._get_past_length(past_key_values) + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] + + if decoder_position_ids is not None: + decoder_position_ids = decoder_position_ids[:, remove_prefix_length:] + # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. + decoder_position_ids = decoder_position_ids.clone(memory_format=torch.contiguous_format) + + if cache_position is None: + cache_position = torch.arange( + past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device + ) + elif use_cache: + cache_position = cache_position[-decoder_input_ids.shape[1] :] + + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. 
Ref: https://github.com/huggingface/transformers/pull/29114 + decoder_input_ids = decoder_input_ids.contiguous() + + return { + "encoder_outputs": encoder_outputs, + "past_key_values": past_key_values, + "decoder_input_ids": decoder_input_ids, + "use_cache": use_cache, + "decoder_attention_mask": decoder_attention_mask, + "decoder_position_ids": decoder_position_ids, + "cache_position": cache_position, + } diff --git a/optimum/intel/openvino/quantization.py b/optimum/intel/openvino/quantization.py index 962738e0e1..38b96b209d 100644 --- a/optimum/intel/openvino/quantization.py +++ b/optimum/intel/openvino/quantization.py @@ -834,12 +834,14 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB decoder_model.request, decoder_calibration_data, apply_caching=True ) - decoder_w_p_calibration_data = [] - decoder_w_p_model = self.model.decoder_with_past - decoder_w_p_model._compile() - decoder_w_p_model.request = InferRequestWrapper( - decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True - ) + decoder_w_p_model = None + if self.model.decoder_with_past_model is not None: + decoder_w_p_calibration_data = [] + decoder_w_p_model = self.model.decoder_with_past + decoder_w_p_model._compile() + decoder_w_p_model.request = InferRequestWrapper( + decoder_w_p_model.request, decoder_w_p_calibration_data, apply_caching=True + ) dataset_metadata = PREDEFINED_SPEECH_TO_TEXT_DATASETS[config.dataset] @@ -873,13 +875,16 @@ def _prepare_speech_to_text_calibration_data(self, config: OVQuantizationConfigB finally: encoder_model.request = encoder_model.request.request decoder_model.request = decoder_model.request.request - decoder_w_p_model.request = decoder_w_p_model.request.request + if decoder_w_p_model is not None: + decoder_w_p_model.request = decoder_w_p_model.request.request - return ( + datasets = [ nncf.Dataset(encoder_calibration_data), nncf.Dataset(decoder_calibration_data), - nncf.Dataset(decoder_w_p_calibration_data), - ) + ] + 
if decoder_w_p_model is not None: + datasets.append(nncf.Dataset(decoder_w_p_calibration_data)) + return datasets def _prepare_text_generation_calibration_data( self, quantization_config: OVQuantizationConfigBase, calibration_dataloader: OVDataLoader @@ -992,15 +997,16 @@ def _quantize_whisper_model(self, quantization_config, calibration_dataset, **kw self.model.decoder.model = quantized_decoder_model self.model.decoder.request = None - # Quantize decoder with past model - config = copy.deepcopy(quantization_config) - config.num_samples = calibration_dataset[2].get_length() - quantized_decoder_w_p_model = _full_quantization( - self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs - ) - self.model.decoder_with_past_model = quantized_decoder_w_p_model - self.model.decoder_with_past.model = quantized_decoder_w_p_model - self.model.decoder_with_past.request = None + if self.model.decoder_with_past_model is not None: + # Quantize decoder with past model + config = copy.deepcopy(quantization_config) + config.num_samples = calibration_dataset[2].get_length() + quantized_decoder_w_p_model = _full_quantization( + self.model.decoder_with_past_model, config, calibration_dataset[2], **kwargs + ) + self.model.decoder_with_past_model = quantized_decoder_w_p_model + self.model.decoder_with_past.model = quantized_decoder_w_p_model + self.model.decoder_with_past.request = None def _weight_only_quantization( diff --git a/tests/openvino/test_exporters_cli.py b/tests/openvino/test_exporters_cli.py index ab6b935a16..6e715de2a2 100644 --- a/tests/openvino/test_exporters_cli.py +++ b/tests/openvino/test_exporters_cli.py @@ -340,7 +340,7 @@ def test_exporters_cli_int8(self, task: str, model_type: str): if task.startswith("text2text-generation"): models = [model.encoder, model.decoder] - if task.endswith("with-past"): + if task.endswith("with-past") and not model.decoder.stateful: models.append(model.decoder_with_past) elif model_type.startswith("stable-diffusion") 
or model_type.startswith("flux"): models = [model.unet or model.transformer, model.vae_encoder, model.vae_decoder] @@ -425,7 +425,11 @@ def test_exporters_cli_full_quantization( submodels = [] if task == "automatic-speech-recognition": - submodels = [model.encoder, model.decoder, model.decoder_with_past] + submodels = [model.encoder, model.decoder] + if model.decoder_with_past is not None: + submodels.append(model.decoder_with_past) + else: + expected_num_fq_nodes_per_model = expected_num_fq_nodes_per_model[:-1] self.assertEqual(len(expected_num_fq_nodes_per_model), len(submodels)) for i, model in enumerate(submodels): actual_num_fq_nodes, actual_num_weight_nodes = get_num_quantized_nodes(model) diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py index 5235fe5c87..efc84ee76a 100644 --- a/tests/openvino/test_modeling.py +++ b/tests/openvino/test_modeling.py @@ -66,6 +66,7 @@ from utils_tests import MODEL_NAMES, TEST_IMAGE_URL, mock_torch_cuda_is_available, patch_awq_for_inference from optimum.exporters.openvino.model_patcher import patch_update_causal_mask +from optimum.exporters.openvino.stateful import model_has_state from optimum.intel import ( OVDiffusionPipeline, OVFluxPipeline, @@ -607,8 +608,9 @@ def test_seq2seq_load_from_hub(self): with TemporaryDirectory() as tmpdirname: ov_exported_pipe.save_pretrained(tmpdirname) folder_contents = os.listdir(tmpdirname) - self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents) - self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents) + if not ov_exported_pipe.model.decoder.stateful: + self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents) + self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents) ov_exported_pipe = optimum_pipeline("text2text-generation", tmpdirname, accelerator="openvino") self.assertIsInstance(ov_exported_pipe.model, OVBaseModel) @@ -1624,16 +1626,23 @@ class 
OVModelForSeq2SeqLMIntegrationTest(unittest.TestCase): GENERATION_LENGTH = 100 SPEEDUP_CACHE = 1.1 + SUPPORT_STATEFUL = ("t5", "mt5") + @parameterized.expand(SUPPORTED_ARCHITECTURES) def test_compare_to_transformers(self, model_arch): model_id = MODEL_NAMES[model_arch] set_seed(SEED) ov_model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) - + expected_stateful = is_transformers_version(">", "4.43") and model_arch in self.SUPPORT_STATEFUL + self.assertEqual(ov_model.decoder.stateful, expected_stateful) + self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful) + check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone + check_with_past_available(ov_model.decoder_with_past) self.assertIsInstance(ov_model.encoder, OVEncoder) self.assertIsInstance(ov_model.decoder, OVDecoder) - self.assertIsInstance(ov_model.decoder_with_past, OVDecoder) - self.assertIsInstance(ov_model.config, PretrainedConfig) + if not ov_model.decoder.stateful: + self.assertIsInstance(ov_model.decoder_with_past, OVDecoder) + self.assertIsInstance(ov_model.config, PretrainedConfig) transformers_model = AutoModelForSeq2SeqLM.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -1718,7 +1727,7 @@ def test_generate_utils(self, model_arch): gc.collect() def test_compare_with_and_without_past_key_values(self): - model_id = MODEL_NAMES["t5"] + model_id = MODEL_NAMES["bart"] tokenizer = AutoTokenizer.from_pretrained(model_id) text = "This is a sample input" tokens = tokenizer(text, return_tensors="pt") @@ -2337,6 +2346,12 @@ def test_compare_to_transformers(self, model_arch): transformers_model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id) ov_model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True, ov_config=F32_CONFIG) self.assertIsInstance(ov_model.config, PretrainedConfig) + # whisper cache class support implemented in 4.43 + expected_stateful = 
is_transformers_version(">", "4.43") + self.assertEqual(ov_model.decoder.stateful, expected_stateful) + self.assertEqual(model_has_state(ov_model.decoder_model), expected_stateful) + check_with_past_available = self.assertIsNone if expected_stateful else self.assertIsNotNone + check_with_past_available(ov_model.decoder_with_past) processor = get_preprocessor(model_id) data = self._generate_random_audio_data() diff --git a/tests/openvino/test_quantization.py b/tests/openvino/test_quantization.py index 493124dd94..26ad44401a 100644 --- a/tests/openvino/test_quantization.py +++ b/tests/openvino/test_quantization.py @@ -213,8 +213,12 @@ def test_ov_model_static_quantization_with_auto_dataset( ov_model.save_pretrained(tmp_dir) if model_cls == OVModelForSpeechSeq2Seq: + models = [ov_model.encoder.model, ov_model.decoder.model] + + if ov_model.decoder_with_past is not None: + models.append(ov_model.decoder_with_past.model) for model, expected_fq, expected_i8 in zip( - (ov_model.encoder.model, ov_model.decoder.model, ov_model.decoder_with_past.model), + models, expected_fake_quantize, expected_int8, ): @@ -675,7 +679,9 @@ def test_ovmodel_load_with_compressed_weights(self, model_cls, model_type, trust self.assertEqual(model._openvino_config.dtype, "int8") if model.export_feature.startswith("text2text-generation"): - models = [model.encoder, model.decoder, model.decoder_with_past] + models = [model.encoder, model.decoder] + if model.decoder_with_past is not None: + models.append(model.decoder_with_past) elif model.export_feature == "text-to-image": models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) @@ -821,7 +827,9 @@ def test_ovmodel_load_with_uncompressed_weights(self, model_cls, model_type, tru MODEL_NAMES[model_type], export=True, load_in_8bit=False, trust_remote_code=trust_remote_code ) if model.export_feature.startswith("text2text-generation"): - models = 
[model.encoder, model.decoder, model.decoder_with_past] + models = [model.encoder, model.decoder] + if model.decoder_with_past is not None: + models.append(model.decoder_with_past) elif model.export_feature == "text-to-image": models = [model.unet, model.vae_encoder, model.vae_decoder] models.append(model.text_encoder if model_type == "stable-diffusion" else model.text_encoder_2) @@ -1256,9 +1264,14 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching): processor = AutoProcessor.from_pretrained(model_id) calibration_data = [] - ov_model.decoder_with_past.request = InferRequestWrapper( - ov_model.decoder_with_past.request, calibration_data, apply_caching=apply_caching - ) + if not ov_model.decoder.stateful: + ov_model.decoder_with_past.request = InferRequestWrapper( + ov_model.decoder_with_past.request, calibration_data, apply_caching=apply_caching + ) + else: + ov_model.decoder.request = InferRequestWrapper( + ov_model.decoder.request, calibration_data, apply_caching=apply_caching + ) for _ in range(2): input_features = self._generate_random_audio_data(processor) ov_model.generate(input_features, max_new_tokens=10, min_new_tokens=10) @@ -1268,7 +1281,7 @@ def test_calibration_data_uniqueness(self, model_name, apply_caching): for inputs_dict in calibration_data: for k, v in inputs_dict.items(): - if k == "input_ids": + if k in ["input_ids", "beam_idx"]: continue x = (v.numpy() if isinstance(v, torch.Tensor) else v).copy()