
🐛 [Bug] torch_tensorrt.compile does not work with nn.ConvTranspose2d and output_padding #3352

Open
xavierjimenezp opened this issue Jan 10, 2025 · 1 comment
Labels
bug Something isn't working


@xavierjimenezp

Bug Description

Compiling a model that uses nn.ConvTranspose2d with output_padding = 1 via torch_tensorrt.compile raises the following error:

RuntimeError: Target aten.convolution.default does not support `transposed=True` 

While executing %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %deconv_weight, %deconv_bias, [2, 2], [0, 0], [1, 1], True, [1, 1], 1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x72fa0452d6b0>: ((4, 256, 296, 296), torch.float32, False, (22429696, 87616, 296, 1), torch.contiguous_format, False, {})}})

Note: using the default value output_padding = 0 works fine. I have not tried other values.
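For reference, output_padding only changes the output size of the transposed convolution. The sketch below applies the output-size formula from the PyTorch nn.ConvTranspose2d documentation to the shapes used in the reproduction. The helper name conv_transpose_out_size is hypothetical and is not part of PyTorch or Torch-TensorRT.

```python
def conv_transpose_out_size(h_in, kernel_size, stride=1, padding=0,
                            output_padding=0, dilation=1):
    """Output length along one spatial dim of a transposed convolution.

    Hypothetical helper that mirrors the shape formula in the PyTorch
    nn.ConvTranspose2d docs.
    """
    # PyTorch rejects output_padding that is not smaller than stride or dilation.
    if output_padding >= max(stride, dilation):
        raise ValueError("output_padding must be smaller than stride or dilation")
    return ((h_in - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)


# Shapes from the reproduction: 296x296 input, kernel_size=3, stride=2, padding=0.
print(conv_transpose_out_size(296, 3, stride=2, output_padding=0))  # 593
print(conv_transpose_out_size(296, 3, stride=2, output_padding=1))  # 594
```

So with output_padding = 1 the expected output shape is (4, 128, 594, 594): one extra row and one extra column compared with the working output_padding = 0 case, (4, 128, 593, 593).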

To Reproduce

The following code reproduces the error:

import torch
import torch.nn as nn
import torch_tensorrt


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.deconv = nn.ConvTranspose2d(
            256,
            128,
            kernel_size=3,
            output_padding=1,
            stride=2,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.deconv(x)
        return x


def main():
    # Create a toy model instance
    model = ToyModel().eval().cuda()

    # Create dummy input
    input_tensor = torch.randn(4, 256, 296, 296).cuda()

    # Compile the model with torch_tensorrt
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[input_tensor],
        enabled_precisions={torch.float32},
        min_block_size=1,
    )


if __name__ == "__main__":
    main()
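A possible workaround, untested with Torch-TensorRT and offered only as an assumption, relies on the fact that with padding = 0 the extra row and column added by output_padding receive no contribution from the input, only the bias. Under that assumption, the layer could be run with output_padding = 0 and no bias, padded with one zero on the bottom and right (for example with torch.nn.functional.pad(y, (0, 1, 0, 1))), and the bias added afterwards. The pure-Python 1-D sketch below checks that equivalence on a small single-channel example. conv_transpose1d and padded_workaround are hypothetical helpers, not library functions, and the argument does not hold in general when padding > 0.

```python
def conv_transpose1d(x, w, bias, stride, padding=0, output_padding=0, dilation=1):
    """Reference single-channel 1-D transposed convolution (scatter form)."""
    k = len(w)
    out_len = ((len(x) - 1) * stride - 2 * padding
               + dilation * (k - 1) + output_padding + 1)
    y = [bias] * out_len
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            # Input element i times kernel tap j lands at this output position.
            pos = i * stride + j * dilation - padding
            if 0 <= pos < out_len:
                y[pos] += xi * wj
    return y


def padded_workaround(x, w, bias, stride):
    """Hypothetical workaround for padding=0, output_padding=1."""
    y = conv_transpose1d(x, w, 0.0, stride)  # output_padding=0, no bias
    y = y + [0.0]                            # one trailing zero, like F.pad
    return [v + bias for v in y]             # bias also fills the padded slot


x, w, b = [1.0, 2.0, 3.0], [0.5, -1.0, 2.0], 0.25
# Both print [0.75, -0.75, 3.25, -1.75, 5.75, -2.75, 6.25, 0.25]
print(conv_transpose1d(x, w, b, stride=2, output_padding=1))
print(padded_workaround(x, w, b, stride=2))
```

The last output element equals the bias alone, which is why a zero pad followed by the bias add reproduces output_padding = 1 here.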

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 2.5.0
  • PyTorch Version (e.g. 1.0): 2.5.1+cu124
  • CPU Architecture: x86_64
  • OS (e.g., Linux): Ubuntu 22.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: No
  • Python version: 3.11.11
  • CUDA version: 12.6
  • GPU models and configuration: NVIDIA GeForce RTX 4090
  • Any other relevant information:
@xavierjimenezp xavierjimenezp added the bug Something isn't working label Jan 10, 2025
@zewenli98
Collaborator

Seems to be related to #3343.
