Skip to content

Commit

Permalink
Refactor Export for Model Conversion and Saving (Project-MONAI#7934)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#6375 .

### Description
Changes to be made based on the [previous discussion
Project-MONAI#7835](Project-MONAI#7835).

Modify the `_export` function to call the `saver` parameter for saving
different models. Rewrite the `onnx_export` function using the updated
`_export` to achieve consistency in model format conversion and saving.

* Rewrite `onnx_export` to call `_export` with `convert_to_onnx` and
appropriate `kwargs`.
* Add a `saver: Callable` parameter to `_export`, replacing
`save_net_with_metadata`.
* Pass `save_net_with_metadata` function wrapped with `partial` to set
parameters like `include_config_vals` and `append_timestamp`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Han123su <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <[email protected]>
Co-authored-by: YunLiu <[email protected]>
  • Loading branch information
4 people authored Aug 23, 2024
1 parent de2a819 commit a5fbe71
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import warnings
import zipfile
from collections.abc import Mapping, Sequence
from functools import partial
from pathlib import Path
from pydoc import locate
from shutil import copyfile
Expand Down Expand Up @@ -1254,6 +1255,7 @@ def verify_net_in_out(

def _export(
converter: Callable,
saver: Callable,
parser: ConfigParser,
net_id: str,
filepath: str,
Expand All @@ -1268,6 +1270,8 @@ def _export(
Args:
converter: a callable object that takes a torch.nn.module and kwargs as input and
converts the module to another type.
saver: a callable object that accepts the converted model to save, a filepath to save to, meta values
(extracted from the parser), and a dictionary of extra JSON files (name -> contents) as input.
parser: a ConfigParser of the bundle to be converted.
net_id: ID name of the network component in the parser, it must be `torch.nn.Module`.
filepath: filepath to export, if filename has no extension, it becomes `.ts`.
Expand Down Expand Up @@ -1307,14 +1311,9 @@ def _export(
# add .json extension to all extra files which are always encoded as JSON
extra_files = {k + ".json": v for k, v in extra_files.items()}

save_net_with_metadata(
jit_obj=net,
filename_prefix_or_stream=filepath,
include_config_vals=False,
append_timestamp=False,
meta_values=parser.get().pop("_meta_", None),
more_extra_files=extra_files,
)
meta_values = parser.get().pop("_meta_", None)
saver(net, filepath, meta_values=meta_values, more_extra_files=extra_files)

logger.info(f"exported to file: {filepath}.")


Expand Down Expand Up @@ -1413,17 +1412,23 @@ def onnx_export(
input_shape_ = _get_fake_input_shape(parser=parser)

inputs_ = [torch.rand(input_shape_)]
net = parser.get_parsed_content(net_id_)
if has_ignite:
# here we use ignite Checkpoint to support nested weights and be compatible with MONAI CheckpointSaver
Checkpoint.load_objects(to_load={key_in_ckpt_: net}, checkpoint=ckpt_file_)
else:
ckpt = torch.load(ckpt_file_)
copy_model_state(dst=net, src=ckpt if key_in_ckpt_ == "" else ckpt[key_in_ckpt_])

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
onnx_model = convert_to_onnx(model=net, **converter_kwargs_)
onnx.save(onnx_model, filepath_)

def save_onnx(onnx_obj: Any, filename_prefix_or_stream: str, **kwargs: Any) -> None:
onnx.save(onnx_obj, filename_prefix_or_stream)

_export(
convert_to_onnx,
save_onnx,
parser,
net_id=net_id_,
filepath=filepath_,
ckpt_file=ckpt_file_,
config_file=config_file_,
key_in_ckpt=key_in_ckpt_,
**converter_kwargs_,
)


def ckpt_export(
Expand Down Expand Up @@ -1544,8 +1549,12 @@ def ckpt_export(

converter_kwargs_.update({"inputs": inputs_, "use_trace": use_trace_})
# Use the given converter to convert a model and save with metadata, config content

save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)

_export(
convert_to_torchscript,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down Expand Up @@ -1715,8 +1724,11 @@ def trt_export(
}
converter_kwargs_.update(trt_api_parameters)

save_ts = partial(save_net_with_metadata, include_config_vals=False, append_timestamp=False)

_export(
convert_to_trt,
save_ts,
parser,
net_id=net_id_,
filepath=filepath_,
Expand Down

0 comments on commit a5fbe71

Please sign in to comment.