diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index cd38ce56e6..6d4be2a65a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -46,7 +46,9 @@ def infer_module_output_dtypes( kwarg_inputs = {} torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): + if isinstance(module_outputs, dict): + module_outputs = list(module_outputs.values()) + elif not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] # Int64 outputs can sometimes be generated from within other operators