Merge decoder and decoder with past to stateful for seq2seq (#1078)
* merge decoder and decoder with past to stateful for seq2seq

* fix quantization

* fix loading decoder_with_past

* fix quant tests

* fix tests

* fix more tests

* make input dynamic and enable sdpa

* review comments and kv cache compression disable in fp

* fix task recognition

* fix quantization tests

* respect from_onnx

* update test to check that stateful expected

* Apply suggestions from code review

Co-authored-by: Ilyas Moutawwakil <[email protected]>

---------

Co-authored-by: Ilyas Moutawwakil <[email protected]>
eaidova and IlyasMoutawwakil authored Jan 15, 2025
1 parent fe55db5 commit 74ee7eb
Showing 11 changed files with 552 additions and 91 deletions.
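
For context before the per-file diffs, a minimal usage sketch of what this change enables (the checkpoint and prompt are illustrative, not part of the diff): seq2seq export now produces an encoder plus a single stateful decoder whose KV cache is held as internal model state, instead of separate decoder and decoder-with-past models.

# Minimal sketch, assuming an optimum-intel build that includes this PR.
from optimum.intel import OVModelForSeq2SeqLM
from transformers import AutoTokenizer

model_id = "google/flan-t5-small"  # illustrative; T5/MT5/LongT5/Whisper are covered below
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True)  # exports one stateful decoder

inputs = tokenizer("translate English to German: Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=16)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))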
6 changes: 5 additions & 1 deletion optimum/exporters/openvino/__main__.py
@@ -274,7 +274,11 @@ def main_export(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
if (
is_transformers_version(">=", "4.36")
and is_transformers_version("<=", "4.45.0")
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
):
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that does not support load model on cpu
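For reference, the loading_kwargs assembled here are forwarded to from_pretrained when the exporter loads the PyTorch model, so within the affected transformers window the fallback amounts to the following (checkpoint name illustrative):

# Illustrative sketch of the downstream effect of the "eager" fallback;
# attn_implementation is a standard from_pretrained kwarg in transformers >= 4.36.
from transformers import AutoModelForSpeechSeq2Seq

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", attn_implementation="eager")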
64 changes: 59 additions & 5 deletions optimum/exporters/openvino/convert.py
@@ -28,10 +28,13 @@
 from openvino.tools.ovc import convert_model
 from optimum.exporters import TasksManager
 from optimum.exporters.utils import (
-    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
+    DECODER_NAME,
+    ENCODER_NAME,
+    _get_submodels_for_export_encoder_decoder,
+    get_diffusion_models_for_export,
 )
 from optimum.exporters.utils import (
-    get_diffusion_models_for_export,
+    _get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
 )
 from optimum.intel.utils.import_utils import (
     _diffusers_version,
@@ -106,10 +109,13 @@ def _set_runtime_options(
"diffusers" in library_name
or "text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
if not quantized_model and (
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
"text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"

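These runtime_options travel with the exported model and are applied when it is compiled. A sketch of roughly what the KV-cache option amounts to (an assumption about the plumbing; the property name mirrors the string above, and the file path is illustrative):

import openvino as ov

core = ov.Core()
ov_model = core.read_model("openvino_decoder_model.xml")  # illustrative path
# Stateful decoders keep the KV cache as internal state; this hint stores it in f16.
compiled = core.compile_model(ov_model, "CPU", {"KV_CACHE_PRECISION": "f16"})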
@@ -642,10 +648,14 @@ def export_from_model(

     logger.info(f"Automatic task detection to: {task}.")

+    is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
+    model_type = getattr(getattr(model, "config", {}), "model_type", "")
     stateful = stateful and (
-        ensure_export_task_support_stateful(task)
-        or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
+        ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
     )
+
+    if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
+        stateful = False
     # TODO: support onnx_config.py in the model repo
     if custom_architecture and custom_export_configs is None:
         raise ValueError(
@@ -687,6 +697,11 @@
     if library_name == "diffusers":
         export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
         stateful_submodels = False
+    elif stateful and is_encoder_decoder and not custom_architecture:
+        export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
+            model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
+        )
+        stateful_submodels = [False, True]
     else:
         logging.disable(logging.INFO)
         export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
@@ -1221,3 +1236,42 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)

return models_for_export


def _get_encoder_decoder_stateful_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
task: str,
_variant: str,
library_name: str,
int_dtype: str = "int64",
float_dtype: str = "fp32",
preprocessors: Optional[List[Any]] = None,
):
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="openvino", task=task, library_name=library_name
)
export_config = export_config_constructor(
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
legacy=False,
)

export_config.variant = _variant
all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)

encoder_export_config = export_config.with_behavior("encoder")
models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

decoder_export_config_with_past.stateful = True
models_for_export[DECODER_NAME] = (
models_for_export[DECODER_NAME],
decoder_export_config_with_past,
)
return None, models_for_export
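For orientation, a sketch of the mapping this helper returns (assuming ENCODER_NAME/DECODER_NAME resolve to "encoder" and "decoder", as in optimum.exporters.utils; the leading None mirrors the return statement above):

_, models_for_export = _get_encoder_decoder_stateful_models_for_export(
    model=model,
    task="text2text-generation-with-past",
    _variant="default",
    library_name="transformers",
    preprocessors=None,
)
encoder_model, encoder_config = models_for_export["encoder"]
decoder_model, decoder_config = models_for_export["decoder"]
assert decoder_config.stateful  # set explicitly above
assert decoder_config.use_past_in_inputs  # from with_behavior(..., use_past_in_inputs=True)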
70 changes: 70 additions & 0 deletions optimum/exporters/openvino/model_configs.py
@@ -20,6 +20,7 @@
 from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available

+from optimum.exporters.onnx.base import ConfigBehavior
 from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
 from optimum.exporters.onnx.model_configs import (
     CLIPOnnxConfig,
@@ -38,8 +39,10 @@
     MistralOnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
+    T5OnnxConfig,
     UNetOnnxConfig,
     VisionOnnxConfig,
+    WhisperOnnxConfig,
 )
 from optimum.exporters.onnx.model_patcher import ModelPatcher
 from optimum.exporters.tasks import TasksManager
@@ -102,6 +105,7 @@
     Qwen2VLVisionEmbMergerPatcher,
     QwenModelPatcher,
     RotaryEmbPatcher,
+    StatefulSeq2SeqDecoderPatcher,
     UpdateCausalMaskModelPatcher,
     XverseModelPatcher,
 )
@@ -2611,3 +2615,69 @@ def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
         return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "whisper",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "audio-classification",
+        "automatic-speech-recognition",
+        "automatic-speech-recognition-with-past",
+    ],
+    library_name="transformers",
+)
+class WhisperOpenVINOConfig(WhisperOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "t5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class T5OpenVINOConfig(T5OnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> ModelPatcher:
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
+        return super().patch_model_for_export(model, model_kwargs)
+
+    @property
+    def inputs(self):
+        common_inputs = super().inputs
+        if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
+            common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "decoder_sequence_length"}
+        return common_inputs
+
+
+@register_in_tasks_manager(
+    "mt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class MT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
+
+
+@register_in_tasks_manager(
+    "longt5",
+    *["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
+    library_name="transformers",
+)
+class LongT5OpenVINOConfig(T5OpenVINOConfig):
+    pass
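With these registrations in place, the standard OpenVINO export entry point routes whisper/t5/mt5/longt5 through the configs above; a sketch (checkpoint and output directory are illustrative):

from optimum.exporters.openvino import main_export

main_export(
    model_name_or_path="openai/whisper-tiny",  # illustrative checkpoint
    output="whisper_ov",
    task="automatic-speech-recognition-with-past",  # -with-past tasks allow the stateful decoder
)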
53 changes: 52 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
@@ -24,7 +24,12 @@
 from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
 from transformers.utils import is_tf_available

-from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
+from optimum.exporters.onnx.model_patcher import (
+    DecoderModelPatcher,
+    ModelPatcher,
+    Seq2SeqModelPatcher,
+    override_arguments,
+)
 from optimum.intel.utils.import_utils import (
     _openvino_version,
     _torch_version,
@@ -3740,3 +3745,49 @@ def __exit__(self, exc_type, exc_value, traceback):
         if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
             for layer in self._model.transformer.h:
                 layer.attn._attn = layer.attn._orig_attn
+
+
+class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Optional[Dict[str, Any]] = None,
+    ):
+        model.__orig_forward = model.forward
+
+        @functools.wraps(model.__orig_forward)
+        def patched_forward(*args, **kwargs):
+            from transformers.cache_utils import EncoderDecoderCache
+
+            signature = inspect.signature(self.orig_forward)
+            args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)
+
+            return_legacy_cache = False
+            pkv_in_args = False
+            legacy_pkv = None
+            if "past_key_values" in kwargs:
+                legacy_pkv = kwargs.pop("past_key_values", None)
+            sign_names = list(signature.parameters.keys())
+            pkv_argument_index = sign_names.index("past_key_values")
+            if legacy_pkv is None and len(args) > pkv_argument_index:
+                legacy_pkv = args[pkv_argument_index]
+                pkv_in_args = True
+            if legacy_pkv is not None:
+                only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
+                pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
+                return_legacy_cache = True
+                if not pkv_in_args:
+                    kwargs["past_key_values"] = pkv
+                else:
+                    args[pkv_argument_index] = pkv
+
+            outputs = model.__orig_forward(*args, **kwargs)
+            if return_legacy_cache:
+                outputs.past_key_values = outputs.past_key_values.to_legacy_cache()
+
+            return outputs
+
+        model.forward = patched_forward
+
+        super().__init__(config, model, model_kwargs)
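The heart of the patcher is the cache conversion: the exported decoder receives past key/values as legacy per-layer tuples holding only self-attention entries, while recent transformers forwards expect an EncoderDecoderCache. A standalone sketch of that round trip (shapes and layer count are illustrative; assumes a transformers version that provides EncoderDecoderCache):

import torch
from transformers.cache_utils import EncoderDecoderCache

num_layers, batch, heads, seq, dim = 2, 1, 4, 3, 8
legacy_pkv = tuple(
    (torch.zeros(batch, heads, seq, dim), torch.zeros(batch, heads, seq, dim))
    for _ in range(num_layers)
)

only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]  # keep self-attention k/v only
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
roundtrip = pkv.to_legacy_cache()  # converted back after the forward pass
assert len(roundtrip) == num_layers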
95 changes: 93 additions & 2 deletions optimum/exporters/openvino/stateful.py
@@ -190,18 +190,95 @@ def ensure_stateful_is_available(warn=True):
     return True


+_ENCODER_DECODER_TASKS_WITH_PAST = (
+    "automatic-speech-recognition",
+    "text2text-generation",
+)
+
+_DECODER_TASKS_WITH_PAST = ("text-generation",)
+
+
 def ensure_export_task_support_stateful(task: str):
     from optimum.exporters import TasksManager

     task = TasksManager.map_from_synonym(task)
-    return task in ["text-generation-with-past"]
+
+    is_stateful = (
+        task.endswith("-with-past")
+        and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST
+    )
+    return is_stateful


 def ensure_model_type_support_stateful(model_type: str):
     return model_type.replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS


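The broadened gate now accepts -with-past variants of both decoder-only and encoder-decoder tasks; for example (follows directly from the tuples above):

assert ensure_export_task_support_stateful("text-generation-with-past")
assert ensure_export_task_support_stateful("text2text-generation-with-past")
assert ensure_export_task_support_stateful("automatic-speech-recognition-with-past")
assert not ensure_export_task_support_stateful("text2text-generation")  # no -with-past suffix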
-def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"):
+def remove_parameters_by_names(model: ov.Model, names: list):
+    parameters = [model.input(name).get_node() for name in names]
+    for p in parameters:
+        model.remove_parameter(p)
+
+
+def get_input_nodes(node):
+    return [input.get_node() for input in node.input_values()]
+
+
+def find_dependent_nodes(model: ov.Model, sources: list):
+    # Finds all nodes in `model` that are directly or indirectly dependent on at least one node from `sources`, including `sources` themselves
+    result = set(sources)
+    for node in model.get_ordered_ops():
+        input_nodes = set(get_input_nodes(node))
+        if input_nodes & result:
+            result.add(node)
+    return result
+
+
+def get_read_value_ops(model: ov.Model):
+    return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"]
+
+
+def get_shape_of_ops(model: ov.Model):
+    return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"]
+
+
+def get_consumer_nodes(node):
+    consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
+    return {input.get_node() for input in consumer_inputs}
+
+
+def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
+    # Search for nodes in the model graph that depend on nodes in the `sources` list but are independent of other model Parameters/ReadValues
+    other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
+    other_nodes = find_dependent_nodes(model, other_inputs)
+    source_dependent_nodes = find_dependent_nodes(model, sources)
+    # TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
+    nodes = source_dependent_nodes - other_nodes
+    edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
+    return edge_nodes
+
+
+def insert_state_for_nodes(model: ov.Model, nodes):
+    # For each output of the given list `nodes` of ov.Node's, insert a ReadValue-Assign pair and use the node output as the initialization sub-expression
+    outputs = sum((node.outputs() for node in nodes), [])
+    for output in outputs:
+        consumers = output.get_target_inputs()
+        # FIXME: get_any_name is not reliable as a tensor may not have any names
+        variable_id = output.get_any_name()
+        read_value = ov.runtime.opset13.read_value(output, variable_id)
+        for consumer in consumers:
+            consumer.replace_source_output(read_value.output(0))
+        assign = ov.runtime.opset13.assign(read_value, variable_id)
+        model.add_sinks([assign])
+
+
+def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
+    if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
+        return patch_stateful_encoder_decoder(config, ov_model)
+    return patch_stateful_decoder(config, ov_model)
+
+
+def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
     """
     Apply stateful transformation to model to hide key values inputs inside model.
     Select transformation parameters based on model architecture
@@ -236,3 +313,17 @@
     make_stateful(
         ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
     )
+
+
+def patch_stateful_encoder_decoder(config, ov_model):
+    encoder_key_value_input_names = [
+        key.get_any_name()
+        for key in ov_model.inputs
+        if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
+    ]
+    remove_parameters_by_names(ov_model, encoder_key_value_input_names)
+    patch_stateful_decoder(config, ov_model)
+    insert_state_for_nodes(
+        ov_model,
+        find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
+    )
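After the transformation, a converted decoder exposes neither decoder KV inputs nor encoder KV inputs; both live in ReadValue/Assign state, with the cross-attention cache initialized from encoder_hidden_states on the first run. A sketch for verifying this on a converted model (file path illustrative):

import openvino as ov

core = ov.Core()
ov_model = core.read_model("openvino_decoder_model.xml")  # illustrative path

input_names = {name for inp in ov_model.inputs for name in inp.get_names()}
assert not any("past_key_values" in name for name in input_names)

# KV cache and cross-attention projections are now internal state.
read_values = [op for op in ov_model.get_ops() if op.get_type_name() == "ReadValue"]
print(f"{len(read_values)} state variables")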