Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge decoder and decoder with past to stateful for seq2seq #1078

Merged
merged 13 commits into from
Jan 15, 2025
6 changes: 5 additions & 1 deletion optimum/exporters/openvino/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,11 @@ def main_export(
f"Asked to export a {model_type} model for the task {task}{autodetected_message}, but the Optimum OpenVINO exporter only supports the tasks {', '.join(model_tasks.keys())} for {model_type}. Please use a supported task. Please open an issue at https://github.com/huggingface/optimum/issues if you would like the task {task} to be supported in the ONNX export for {model_type}."
)

if is_transformers_version(">=", "4.36") and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED:
if (
is_transformers_version(">=", "4.36")
and is_transformers_version("<=", "4.45.0")
and model_type in SDPA_ARCHS_ONNX_EXPORT_NOT_SUPPORTED
):
loading_kwargs["attn_implementation"] = "eager"

# some models force flash_attn attention by default that does not support load model on cpu
Expand Down
64 changes: 59 additions & 5 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,13 @@
from openvino.tools.ovc import convert_model
from optimum.exporters import TasksManager
from optimum.exporters.utils import (
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
DECODER_NAME,
ENCODER_NAME,
_get_submodels_for_export_encoder_decoder,
get_diffusion_models_for_export,
)
from optimum.exporters.utils import (
get_diffusion_models_for_export,
_get_submodels_and_export_configs as _default_get_submodels_and_export_configs,
)
from optimum.intel.utils.import_utils import (
_diffusers_version,
Expand Down Expand Up @@ -105,10 +108,13 @@ def _set_runtime_options(
"diffusers" in library_name
or "text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["ACTIVATIONS_SCALE_FACTOR"] = "8.0"
if not quantized_model and (
"text-generation" in task or ("image-text-to-text" in task and model_name == "language_model")
"text-generation" in task
or ("image-text-to-text" in task and model_name == "language_model")
or getattr(sub_export_config, "stateful", False)
):
sub_export_config.runtime_options["KV_CACHE_PRECISION"] = "f16"

Expand Down Expand Up @@ -635,10 +641,14 @@ def export_from_model(

logger.info(f"Automatic task detection to: {task}.")

is_encoder_decoder = getattr(getattr(model, "config", {}), "is_encoder_decoder", False)
model_type = getattr(getattr(model, "config", {}), "model_type", "")
stateful = stateful and (
ensure_export_task_support_stateful(task)
or ensure_model_type_support_stateful(getattr(getattr(model, "config", {}), "model_type", ""))
ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type)
eaidova marked this conversation as resolved.
Show resolved Hide resolved
)

if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False):
stateful = False
# TODO: support onnx_config.py in the model repo
if custom_architecture and custom_export_configs is None:
raise ValueError(
Expand Down Expand Up @@ -677,6 +687,11 @@ def export_from_model(
if library_name == "diffusers":
export_config, models_and_export_configs = get_diffusion_models_for_export_ext(model, exporter="openvino")
stateful_submodels = False
elif stateful and is_encoder_decoder and not custom_architecture:
export_config, models_and_export_configs = _get_encoder_decoder_stateful_models_for_export(
model=model, task=task, preprocessors=preprocessors, library_name=library_name, _variant="default"
)
stateful_submodels = [False, True]
else:
logging.disable(logging.INFO)
export_config, models_and_export_configs, stateful_submodels = _get_submodels_and_export_configs(
Expand Down Expand Up @@ -1211,3 +1226,42 @@ def get_flux_models_for_export(pipeline, exporter, int_dtype, float_dtype):
models_for_export["text_encoder_2"] = (text_encoder_2, export_config)

return models_for_export


def _get_encoder_decoder_stateful_models_for_export(
model: Union["PreTrainedModel", "TFPreTrainedModel"],
task: str,
_variant: str,
library_name: str,
int_dtype: str = "int64",
float_dtype: str = "fp32",
preprocessors: Optional[List[Any]] = None,
):
export_config_constructor = TasksManager.get_exporter_config_constructor(
model=model, exporter="openvino", task=task, library_name=library_name
)
export_config = export_config_constructor(
model.config,
int_dtype=int_dtype,
float_dtype=float_dtype,
preprocessors=preprocessors,
legacy=False,
)

export_config.variant = _variant
all_variants = "\n".join([f" - {name}: {description}" for name, description in export_config.VARIANTS.items()])
logger.info(f"Using the export variant {export_config.variant}. Available variants are:\n{all_variants}")

models_for_export = _get_submodels_for_export_encoder_decoder(model, use_past=False)

encoder_export_config = export_config.with_behavior("encoder")
models_for_export[ENCODER_NAME] = (models_for_export[ENCODER_NAME], encoder_export_config)

decoder_export_config_with_past = export_config.with_behavior("decoder", use_past=True, use_past_in_inputs=True)

decoder_export_config_with_past.stateful = True
models_for_export[DECODER_NAME] = (
models_for_export[DECODER_NAME],
decoder_export_config_with_past,
)
return None, models_for_export
70 changes: 70 additions & 0 deletions optimum/exporters/openvino/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
from transformers.utils import is_tf_available

from optimum.exporters.onnx.base import ConfigBehavior
from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
from optimum.exporters.onnx.model_configs import (
CLIPOnnxConfig,
Expand All @@ -38,8 +39,10 @@
MistralOnnxConfig,
MPTOnnxConfig,
PhiOnnxConfig,
T5OnnxConfig,
UNetOnnxConfig,
VisionOnnxConfig,
WhisperOnnxConfig,
)
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.exporters.tasks import TasksManager
Expand Down Expand Up @@ -102,6 +105,7 @@
Qwen2VLVisionEmbMergerPatcher,
QwenModelPatcher,
RotaryEmbPatcher,
StatefulSeq2SeqDecoderPatcher,
UpdateCausalMaskModelPatcher,
XverseModelPatcher,
)
Expand Down Expand Up @@ -2611,3 +2615,69 @@ def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
return GptBigCodeModelPatcher(self, model, model_kwargs=model_kwargs)


@register_in_tasks_manager(
"whisper",
*[
"feature-extraction",
"feature-extraction-with-past",
"audio-classification",
"automatic-speech-recognition",
"automatic-speech-recognition-with-past",
],
library_name="transformers",
)
class WhisperOpenVINOConfig(WhisperOnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
common_inputs = super().inputs
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
eaidova marked this conversation as resolved.
Show resolved Hide resolved
return common_inputs


@register_in_tasks_manager(
"t5",
*["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
library_name="transformers",
)
class T5OpenVINOConfig(T5OnnxConfig):
def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> ModelPatcher:
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
return StatefulSeq2SeqDecoderPatcher(self, model, model_kwargs)
return super().patch_model_for_export(model, model_kwargs)

@property
def inputs(self):
common_inputs = super().inputs
if getattr(self, "stateful", False) and self._behavior == ConfigBehavior.DECODER:
common_inputs["decoder_input_ids"] = {0: "batch_size", 1: "seq_length"}
eaidova marked this conversation as resolved.
Show resolved Hide resolved
return common_inputs


@register_in_tasks_manager(
"mt5",
*["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
library_name="transformers",
)
class MT5OpenVINOConfig(T5OpenVINOConfig):
pass


@register_in_tasks_manager(
"longt5",
*["feature-extraction", "feature-extraction-with-past", "text2text-generation", "text2text-generation-with-past"],
library_name="transformers",
)
class LongT5OpenVINOConfig(T5OpenVINOConfig):
pass
53 changes: 52 additions & 1 deletion optimum/exporters/openvino/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling
from transformers.utils import is_tf_available

from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments
from optimum.exporters.onnx.model_patcher import (
DecoderModelPatcher,
ModelPatcher,
Seq2SeqModelPatcher,
override_arguments,
)
from optimum.intel.utils.import_utils import (
_openvino_version,
_torch_version,
Expand Down Expand Up @@ -3740,3 +3745,49 @@ def __exit__(self, exc_type, exc_value, traceback):
if getattr(self._model.config, "_attn_implementation", "eager") == "sdpa":
for layer in self._model.transformer.h:
layer.attn._attn = layer.attn._orig_attn


class StatefulSeq2SeqDecoderPatcher(Seq2SeqModelPatcher):
def __init__(
self,
config: "OnnxConfig",
model: Union["PreTrainedModel", "TFPreTrainedModel"],
model_kwargs: Optional[Dict[str, Any]] = None,
):
model.__orig_forward = model.forward

@functools.wraps(model.__orig_forward)
def patched_forward(*args, **kwargs):
from transformers.cache_utils import EncoderDecoderCache

signature = inspect.signature(self.orig_forward)
args, kwargs = override_arguments(args, kwargs, signature, model_kwargs=self.model_kwargs)

return_legacy_cache = False
pkv_in_args = False
legacy_pkv = None
if "past_key_values" in kwargs:
legacy_pkv = kwargs.pop("past_key_values", None)
sign_names = list(signature.parameters.keys())
pkv_argument_index = sign_names.index("past_key_values")
if legacy_pkv is None and len(args) > pkv_argument_index:
legacy_pkv = args[pkv_argument_index]
pkv_in_args = True
if legacy_pkv is not None:
only_self_cache = [cache_item[:2] for cache_item in legacy_pkv]
pkv = EncoderDecoderCache.from_legacy_cache(only_self_cache)
return_legacy_cache = True
if not pkv_in_args:
kwargs["past_key_values"] = pkv
else:
args[pkv_argument_index] = pkv

outputs = model.__orig_forward(*args, **kwargs)
if return_legacy_cache:
outputs.past_key_values = outputs.past_key_values.to_legacy_cache()

return outputs

model.forward = patched_forward

super().__init__(config, model, model_kwargs)
95 changes: 93 additions & 2 deletions optimum/exporters/openvino/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,18 +190,95 @@ def ensure_stateful_is_available(warn=True):
return True


_ENCODER_DECODER_TASKS_WITH_PAST = (
"automatic-speech-recognition",
"text2text-generation",
)

_DECODER_TASKS_WITH_PAST = ("text-generation",)


def ensure_export_task_support_stateful(task: str):
from optimum.exporters import TasksManager

task = TasksManager.map_from_synonym(task)
return task in ["text-generation-with-past"]

is_stateful = (
task.endswith("-with-past")
and task.replace("-with-past", "") in _ENCODER_DECODER_TASKS_WITH_PAST + _DECODER_TASKS_WITH_PAST
)
return is_stateful


def ensure_model_type_support_stateful(model_type: str):
return model_type.replace("_", "-") in MULTI_MODAL_TEXT_GENERATION_MODELS


def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name: str = "input_ids"):
def remove_parameters_by_names(model: ov.Model, names: list):
parameters = [model.input(name).get_node() for name in names]
for p in parameters:
model.remove_parameter(p)


def get_input_nodes(node):
return [input.get_node() for input in node.input_values()]


def find_dependent_nodes(model: ov.Model, sources: list):
# Finds all nodes in `model` that are directly or indirectly dependent on at least one node from the list of nodes in `sources`, including `sources`
result = set(sources)
for node in model.get_ordered_ops():
input_nodes = set(get_input_nodes(node))
if input_nodes & result:
result.add(node)
return result


def get_read_value_ops(model: ov.Model):
return [op for op in model.get_ops() if op.get_type_name() == "ReadValue"]


def get_shape_of_ops(model: ov.Model):
return [op for op in model.get_ops() if op.get_type_name() == "ShapeOf"]


def get_consumer_nodes(node):
consumer_inputs = set().union(*[output.get_target_inputs() for output in node.outputs()])
return {input.get_node() for input in consumer_inputs}


def find_output_nodes_of_dependent_subgraph(model: ov.Model, sources: list):
# Search for nodes in the model graph that depend on nodes in `starts` list but independent of other model Parameter's/ReadValue's
other_inputs = set(model.get_parameters() + get_read_value_ops(model) + get_shape_of_ops(model)) - set(sources)
other_nodes = find_dependent_nodes(model, other_inputs)
source_dependent_nodes = find_dependent_nodes(model, sources)
# TODO: Use symbols on dimensions to filter out ShapeOf subexpressions that do not bring new symbols in the subgraph
nodes = source_dependent_nodes - other_nodes
edge_nodes = [node for node in nodes if get_consumer_nodes(node) & other_nodes]
return edge_nodes


def insert_state_for_nodes(model: ov.Model, nodes):
# For each output in a given list `nodes` of ov.Node's, insert ReadValue-Assign pair and use the node output as initialization sub-expression
outputs = sum((node.outputs() for node in nodes), [])
for output in outputs:
consumers = output.get_target_inputs()
# FIXME: get_any_name is not reliable as tensor may not have any names
variable_id = output.get_any_name()
read_value = ov.runtime.opset13.read_value(output, variable_id)
for consumer in consumers:
consumer.replace_source_output(read_value.output(0))
assign = ov.runtime.opset13.assign(read_value, variable_id)
model.add_sinks([assign])


def patch_stateful(config: PretrainedConfig, ov_model: ov.Model):
if config.is_encoder_decoder and model_has_input_output_name(ov_model, "encoder_hidden_states"):
return patch_stateful_encoder_decoder(config, ov_model)
return patch_stateful_decoder(config, ov_model)


def patch_stateful_decoder(config: PretrainedConfig, ov_model: ov.Model):
"""
Apply stateful transformation to model to hide key values inputs inside model.
Select transformation parameters based on model architecture
Expand Down Expand Up @@ -236,3 +313,17 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model, main_input_name
make_stateful(
ov_model, not_kv_inputs, key_value_input_names, key_value_output_names, batch_dim, num_attention_heads, None
)


def patch_stateful_encoder_decoder(config, ov_model):
encoder_key_value_input_names = [
key.get_any_name()
for key in ov_model.inputs
if any("key_values" in key_name and "encoder" in key_name for key_name in key.get_names())
]
remove_parameters_by_names(ov_model, encoder_key_value_input_names)
patch_stateful_decoder(config, ov_model)
insert_state_for_nodes(
ov_model,
find_output_nodes_of_dependent_subgraph(ov_model, [ov_model.input("encoder_hidden_states").get_node()]),
)
Loading
Loading