From PyTorch to ONNX

ONNX is an open format built to represent machine learning models. Existing vs-mlrt runtimes support only this format for inference.

Given a PyTorch model loaded in Python as a torch.nn.Module (different models may have to be loaded in different ways), the conversion to the ONNX format can be as simple as the following:

import torch

# https://github.com/onnx/onnx/issues/654
# mark the batch size and the spatial dimensions (NCHW layout) as dynamic
dynamic_axes = {'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                'output': {0: 'batch_size', 2: 'height', 3: 'width'}}

channels = 3
dummy_input = torch.ones(1, channels, 64, 64)  # dummy input used for tracing

torch.onnx.export(model, dummy_input, "output.onnx",
                  input_names=["input"], output_names=["output"],
                  dynamic_axes=dynamic_axes, opset_version=14)
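
If the export succeeds, a quick way to confirm that the dynamic axes actually took effect is to run the exported file once with a spatial size other than 64x64. This is only a minimal sketch; it assumes the optional onnxruntime and numpy packages are installed and that the file was saved as output.onnx with the input/output names used above.

import numpy as np
import onnxruntime as ort

# load the exported model on CPU and feed an input whose spatial size
# differs from the 64x64 dummy used during export
session = ort.InferenceSession("output.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 96, 128).astype(np.float32)
outputs = session.run(None, {"input": dummy})
print(outputs[0].shape)  # should follow the 96x128 input if the dynamic axes work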

However, the export sometimes raises errors, and the code defining the network structure has to be modified.

Common fixes include the following (a short sketch illustrating both appears after this list):

  • torch.nn.functional.pad(x, (-4, -4, -4, -4)) => x[..., 4:-4, 4:-4] (replace negative padding with an equivalent slice)
  • remove runtime uses of x.shape / x.size() from the forward pass
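
As a rough illustration of both rewrites, here is a minimal sketch built around a hypothetical module (the class and its forward pass are invented for this example; the commented-out lines show the original, export-unfriendly code):

import torch

class Example(torch.nn.Module):
    # hypothetical module, only meant to illustrate the two rewrites above
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # before: x = torch.nn.functional.pad(x, (-4, -4, -4, -4))  # negative padding may fail to export
        x = x[..., 4:-4, 4:-4]  # after: the same crop expressed with slicing

        # before: n, c, h, w = x.shape; x = x.view(n, c, h * w)  # shape access can bake in constants
        x = torch.flatten(x, start_dim=2)  # after: the same reshape without reading x.shape
        return x

Both versions compute the same result; the rewritten forward simply avoids operators and shape queries that the ONNX exporter handles poorly.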