Skip to content

Commit

Permalink
Add test case for ITensor weight in convolution and fix related bug (#…
Browse files Browse the repository at this point in the history
…3327)

Co-authored-by: Hoonkyung Cho <[email protected]>
  • Loading branch information
chohk88 and Hoonkyung Cho authored Dec 17, 2024
1 parent 544c545 commit 7767594
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tensorrt as trt
import torch
from torch.fx.node import Target

from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion.converter_utils import (
Expand Down Expand Up @@ -68,10 +69,9 @@ def convNd(
weight = get_trt_tensor(ctx, weight, f"{name}_weight")
# Append new dimension (unsqueeze) if the convolution is 1d
if is_conv1d:
input = impl.unsqueeze.unsqueeze(
ctx, target, source_ir, name + "_unsqueeze_weight", weight, -1
weight = impl.unsqueeze.unsqueeze(
ctx, target, source_ir, weight.name + "_unsqueeze_conv1d", weight, -1
)

elif isinstance(weight, (torch.Tensor, np.ndarray)):
# Transform the weight constant into a Numpy array
weight = to_numpy(weight, dtype=input.dtype)
Expand Down
49 changes: 49 additions & 0 deletions tests/py/dynamo/conversion/test_convolution_aten.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
from parameterized import param, parameterized
from torch.testing._internal.common_utils import run_tests

from torch_tensorrt import Input

from .harness import DispatchTestCase
Expand Down Expand Up @@ -45,6 +46,54 @@ def forward(self, x):
enable_passes=True,
)

@parameterized.expand(
[
("default", 1),
param("no_bias", 1, bias=False),
("tuple_parameters", 1, (1), (1)),
param("non_zero_padding", 1, padding=1),
param("dilation", 1, dilation=2),
]
)
def test_conv1d_TRTTensor_weight(
self,
_,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, w):
return torch.ops.aten.convolution.default(
x,
w,
None,
(stride,) if isinstance(stride, int) else stride,
(padding,) if isinstance(padding, int) else padding,
(dilation,) if isinstance(dilation, int) else dilation,
False,
(0,),
groups,
)

inputs = [
torch.randn(1, 3, 32),
torch.randn(
6, 3, 1
), # Conv1d weight shape: (out_channels, in_channels, kernel_size)
]
self.run_test(
TestModule(),
inputs,
use_dynamo_tracer=True,
)

def test_conv1d_with_dynamic_shape(
self,
kernel_size=1,
Expand Down

0 comments on commit 7767594

Please sign in to comment.