fix: remove legacy conv converter #3343

Open · wants to merge 3 commits into base: main
9 changes: 2 additions & 7 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,6 +7,7 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
+
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -2451,15 +2452,8 @@ def aten_ops_le(
 )
 
 
-def conv_param_validator(
-    conv_node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
-    capability_validator=conv_param_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2505,6 +2499,7 @@ def aten_ops_convolution(
         stride=args[3],
         padding=args[4],
         dilation=args[5],
+        output_padding=args[7],
         groups=args[8],
     )
 
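For orientation on the indices above: in the `aten.convolution.default` schema, `args[6]` is the `transposed` flag and `args[7]` is `output_padding`, which only affects transposed convolutions. The deleted `conv_param_validator` had rejected any call with non-zero `output_padding`, forcing a fallback to PyTorch; the converter now forwards the value to the deconvolution implementation instead. A minimal sketch of what `output_padding` does, using the documented PyTorch output-size formula (illustrative only, not code from this PR):

```python
import torch

# Output size of ConvTranspose1d (per the PyTorch docs):
#   out = (in - 1)*stride - 2*padding + dilation*(kernel - 1) + output_padding + 1
x = torch.randn(1, 3, 10)
for output_padding in (0, 1):
    deconv = torch.nn.ConvTranspose1d(
        3, 3, kernel_size=3, stride=2, padding=1, output_padding=output_padding
    )
    expected = (10 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + output_padding + 1
    assert deconv(x).shape[-1] == expected  # 19 without, 20 with output_padding
```

`output_padding` exists because, for stride > 1, several input lengths map to the same convolution output length; it disambiguates which one the transposed op should reconstruct.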
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -6,6 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -105,6 +106,9 @@ def deconvNd(
     padding = (padding,) if isinstance(padding, int) else padding
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
+    output_padding = (
+        (output_padding,) if isinstance(output_padding, int) else output_padding
+    )
 
     # Expand parameters manually for Conv1D computations
     if is_deconv1d:
@@ -113,6 +117,11 @@
         dilation = (
            extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
         )
+        output_padding = (
+            (tuple(output_padding) + (0,))
+            if output_padding is not None
+            else output_padding
+        )
 
     set_layer_name(deconv_layer, target, name, source_ir)
 
@@ -126,6 +135,20 @@
     if groups is not None:
         deconv_layer.num_groups = groups
 
+    ndims = len(padding)
+    pre_padding_values = []
+    post_padding_values = []
+
+    for dim in range(ndims):
+        pre_padding = padding[dim]
+        post_padding = padding[dim] - output_padding[dim]
+
+        pre_padding_values.append(pre_padding)
+        post_padding_values.append(post_padding)
+
+    deconv_layer.pre_padding = tuple(pre_padding_values)
+    deconv_layer.post_padding = tuple(post_padding_values)
+
     # Handle quantization cases
     if scale is not None and zero_point is not None:
         # Assume the dtype of activation is torch.quint8
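Two details above are worth spelling out. First, the 1D path is computed as a 2D deconvolution, so `tuple(output_padding) + (0,)` appends a zero for the dummy spatial dimension. Second, the pre/post padding split is how `output_padding` reaches TensorRT: deconvolution padding trims the full cross-correlation output, so keeping `pre_padding = padding` while shrinking `post_padding` to `padding - output_padding` leaves exactly `output_padding` extra elements on the trailing edge, matching PyTorch semantics. A rough arithmetic check under that assumed trimming semantic (not code from this diff):

```python
# Assumed semantic: TensorRT trims `pre` elements from the leading edge and
# `post` elements from the trailing edge of the untrimmed deconvolution output.
def trt_deconv_out(in_size, kernel, stride, dilation, pre, post):
    full = (in_size - 1) * stride + dilation * (kernel - 1) + 1  # untrimmed size
    return full - pre - post

# PyTorch reference: (in - 1)*s - 2*p + d*(k - 1) + output_padding + 1
in_size, kernel, stride, dilation, padding, output_padding = 10, 3, 2, 1, 1, 1
torch_out = (
    (in_size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1
)
trt_out = trt_deconv_out(
    in_size, kernel, stride, dilation, padding, padding - output_padding
)
assert torch_out == trt_out == 20
```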
3 changes: 2 additions & 1 deletion py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -10,9 +10,10 @@
 # @manual=//deeplearning/trt/python:py_tensorrt
 import tensorrt as trt
 import torch
-import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch.fx.immutable_collections import immutable_list
 from torch.fx.node import Argument, Target
+
+import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.converters.impl import activation, convolution
 
61 changes: 56 additions & 5 deletions tests/py/dynamo/conversion/test_deconvolution_aten.py
@@ -1,6 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv1d(
@@ -26,6 +42,7 @@ def test_deconv1d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -36,9 +53,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -101,6 +119,22 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_4", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_7", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv2d(
@@ -112,6 +146,7 @@ def test_deconv2d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -122,9 +157,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -172,6 +208,19 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv3d(
@@ -183,6 +232,7 @@ def test_deconv3d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
    ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -193,9 +243,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -209,8 +260,8 @@ def forward(self, x):
             enable_passes=True,
         )
 
-    # Testing with (-1, -1, -1, -1, -1) results into Error:
-    # AssertionError: Channel dim can't be dynamic for deconvolution.
+    # # Testing with (-1, -1, -1, -1, -1) results into Error:
+    # # AssertionError: Channel dim can't be dynamic for deconvolution.
 
     def test_deconv3d_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
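The chosen parameterizations respect PyTorch's requirement that `output_padding` be smaller than either `stride` or `dilation`: `output_padding=2` only appears alongside `stride=3`, and the `combined_params` cases additionally carry `dilation=2`. A quick standalone check of that constraint (assumed reference snippet, not part of the test file):

```python
import torch

# ConvTranspose raises at call time if output_padding is not smaller than
# both stride and dilation.
try:
    torch.nn.ConvTranspose2d(3, 3, kernel_size=3, stride=1, output_padding=1)(
        torch.randn(1, 3, 8, 8)
    )
except RuntimeError as e:
    print(e)  # output padding must be smaller than either stride or dilation

# Valid: output_padding=2 < stride=3, as in the "output_padding_7" params above.
y = torch.nn.ConvTranspose2d(3, 3, kernel_size=3, stride=3, output_padding=2)(
    torch.randn(1, 3, 8, 8)
)
print(y.shape)  # torch.Size([1, 3, 26, 26])
```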