Skip to content

Commit

Permalink
try to remove onnx fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jan 16, 2025
1 parent fe10aaa commit f6d8365
Showing 1 changed file with 34 additions and 71 deletions.
105 changes: 34 additions & 71 deletions optimum/exporters/openvino/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,80 +415,43 @@ def export_pytorch(
dummy_inputs = config.rename_ambiguous_inputs(dummy_inputs)
dummy_inputs, dict_inputs = remove_none_from_dummy_inputs(dummy_inputs)

try:
# TorchScript is used behind the OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model input/output.
# To handle this, an additional wrapper is applied on top of the patcher's forward.
# model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
for i in range(len(dict_inputs)):
input_name, keys = dict_inputs[i]
tuple_input = kwargs[input_name]
input_dict = dict(zip(keys, tuple_input))
kwargs[input_name] = input_dict
outputs = patched_forward(*args, **kwargs)
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

patcher.patched_forward = ts_patched_forward

ts_decoder_kwargs = {}
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

with patcher:
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

__make_16bit_traceable(model)
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
)

except Exception as ex:
logger.warning(f"Export model to OpenVINO directly failed with: \n{ex}.\nModel will be exported to ONNX")

if stateful:
# cannot raise because stateful is enabled by default and it would break backward compatibility for models that couldn't convert to OV directly
# TODO: Implement stateful for ONNX path as well, not doing it right now because of lack of validation
logger.warning(
"[ WARNING ] Making stateful models is not supported when exporting to ONNX as an intermediate step. "
"A stateless model will be exported instead. It may result in sub-optimal inference performance."
"Provide a model that can be converted to OpenVINO without fallback to ONNX conversion path."
)

# TorchScript is used behind the OpenVINO conversion. Optimum supports only return_dict=True models for patching,
# while TorchScript does not support dictionaries with values of mixed types (e.g. Tensor and None) in model input/output.
# To handle this, an additional wrapper is applied on top of the patcher's forward.
# model.config.torchscript = True cannot be used for patching, because it overrides return_dict to False.
patcher = config.patch_model_for_export(model, model_kwargs=model_kwargs)
patched_forward = patcher.patched_forward

@functools.wraps(patched_forward)
def ts_patched_forward(*args, **kwargs):
for i in range(len(dict_inputs)):
input_name, keys = dict_inputs[i]
tuple_input = kwargs[input_name]
input_dict = dict(zip(keys, tuple_input))
kwargs[input_name] = input_dict
outputs = patched_forward(*args, **kwargs)
return tuple([value if not isinstance(value, list) else tuple(value) for value in outputs.values()])

patcher.patched_forward = ts_patched_forward

ts_decoder_kwargs = {}
if library_name == "diffusers" and is_openvino_version(">=", "2025.0"):
ts_decoder_kwargs["trace_kwargs"] = {"check_trace": False}

with patcher:
if patch_16bit_model:
from openvino.frontend.pytorch.patch_model import unpatch_model

unpatch_model(model, "_openvino_module_extension_patch_orig_forward")
for m in model.modules():
if any(p.dtype in [torch.float16, torch.bfloat16] for p in m.parameters(False)) or any(
b.dtype in [torch.float16, torch.bfloat16] for b in m.buffers(False)
):
m.float()

return export_pytorch_via_onnx(
model,
config,
opset,
output,
device,
input_shapes,
model_kwargs,
ov_config=ov_config,
library_name=library_name,
from openvino.frontend.pytorch.patch_model import __make_16bit_traceable

__make_16bit_traceable(model)
check_dummy_inputs_are_allowed(model, dummy_inputs)
input_info = _get_input_info(model, config, dummy_inputs)
ts_decoder = TorchScriptPythonDecoder(model, example_input=dummy_inputs, **ts_decoder_kwargs)
ov_model = convert_model(
ts_decoder,
example_input=dummy_inputs,
input=[(item.shape, item.type) for item in input_info],
)

ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation?

output_names = list(config.outputs.keys())
for idx, out_tensor in enumerate(ov_model.outputs):
if idx < len(output_names):
Expand Down

0 comments on commit f6d8365

Please sign in to comment.