From a5fbe716378948630783deef8ee435e7e3bdc918 Mon Sep 17 00:00:00 2001 From: Han123su <107395380+Han123su@users.noreply.github.com> Date: Fri, 23 Aug 2024 12:09:23 +0800 Subject: [PATCH] Refactor Export for Model Conversion and Saving (#7934) Fixes #6375 . ### Description Changes to be made based on the [previous discussion #7835](https://github.com/Project-MONAI/MONAI/pull/7835). Modify the `_export` function to call the `saver` parameter for saving different models. Rewrite the `onnx_export` function using the updated `_export` to achieve consistency in model format conversion and saving. * Rewrite `onnx_export` to call `_export` with `convert_to_onnx` and appropriate `kwargs`. * Add a `saver: Callable` parameter to `_export`, replacing `save_net_with_metadata`. * Pass `save_net_with_metadata` function wrapped with `partial` to set parameters like `include_config_vals` and `append_timestamp`. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Han123su Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/bundle/scripts.py | 46 ++++++++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 17 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 6dd83c1f81..142a366669 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -18,6 +18,7 @@ import warnings import zipfile from collections.abc import Mapping, Sequence +from functools import partial from pathlib import Path from pydoc import locate from shutil import copyfile @@ -1254,6 +1255,7 @@ def verify_net_in_out( def _export( converter: Callable, + saver: Callable, parser: ConfigParser, net_id: str, filepath: str, @@ -1268,6 +1270,8 @@ def _export( Args: converter: a callable object that takes a torch.nn.module and kwargs as input and converts the module to another type. + saver: a callable object that accepts the converted model to save, a filepath to save to, meta values + (extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input. parser: a ConfigParser of the bundle to be converted. net_id: ID name of the network component in the parser, it must be `torch.nn.Module`. filepath: filepath to export, if filename has no extension, it becomes `.ts`. @@ -1307,14 +1311,9 @@ def _export( # add .json extension to all extra files which are always encoded as JSON extra_files = {k + ".json": v for k, v in extra_files.items()} - save_net_with_metadata( - jit_obj=net, - filename_prefix_or_stream=filepath, - include_config_vals=False, - append_timestamp=False, - meta_values=parser.get().pop("_meta_", None), - more_extra_files=extra_files, - ) + meta_values = parser.get().pop("_meta_", None) + saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files) + logger.info(f"exported to file: {filepath}.") @@ -1413,17 +1412,23 @@ def onnx_export( input_shape_ = _get_fake_input_shape(parser=parser) inputs_ = [torch.rand(input_shape_)] - net = parser.get_parsed_content(net_id_) - if has_ignite: - # here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver - Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_) - else: - ckpt = torch.load(ckpt_file_) - copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_]) converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) - onnx_model = convert_to_onnx(model=net, **converter_kwargs_) - onnx.save(onnx_model, filepath_) + + def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None: + onnx.save(onnx_obj, filename_prefix_or_stream) + + _export( + convert_to_onnx, + save_onnx, + parser, + net_id=net_id_, + filepath=filepath_, + ckpt_file=ckpt_file_, + config_file=config_file_, + key_in_ckpt=key_in_ckpt_, + **converter_kwargs_, + ) def ckpt_export( @@ -1544,8 +1549,12 @@ def ckpt_export( converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_}) # Use the given converter to convert a model and save with metadata, config content + + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_torchscript, + save_ts, parser, net_id=net_id_, filepath=filepath_, @@ -1715,8 +1724,11 @@ def trt_export( } converter_kwargs_.update(trt_api_parameters) + save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False) + _export( convert_to_trt, + save_ts, parser, net_id=net_id_, filepath=filepath_,