Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grid_sample #3340

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/conversion/impl/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
from torch_tensorrt.fx.types import TRTTensor
HolyWu marked this conversation as resolved.
Show resolved Hide resolved

# nearest, linear, cubic
# bilinear, nearest, bicubic
GridSamplerInterpolationMode = {
0: trt.InterpolationMode.NEAREST,
1: trt.InterpolationMode.LINEAR,
0: trt.InterpolationMode.LINEAR,
1: trt.InterpolationMode.NEAREST,
2: trt.InterpolationMode.CUBIC,
}

Expand Down
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,15 @@ def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"])


@register_torch_trt_decomposition(
aten.cudnn_grid_sampler, registry=TORCH_TRT_DECOMPOSITIONS
)
def cudnn_grid_sampler_decomposition(
x: torch.Tensor, grid: torch.Tensor
) -> torch.Tensor:
return torch.grid_sampler_2d(x, grid, 0, 0, True)


def get_decompositions(
enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
Expand Down
54 changes: 53 additions & 1 deletion tests/py/dynamo/lowering/test_decompositions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch
import torch_tensorrt
from parameterized import parameterized
from testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing
from torch.testing._internal.common_utils import TestCase, run_tests
from torch_tensorrt.dynamo.utils import ATOL, RTOL

from ..testing_utilities import DECIMALS_OF_AGREEMENT, lower_graph_testing


class TestLowering(TestCase):
Expand Down Expand Up @@ -1720,6 +1722,56 @@ def forward(self, input, weight, bias, running_mean=None, running_var=None):
"Instance_norm TRT outputs don't match with the original model.",
)

def test_lowering_cudnn_grid_sampler(self):
class TestModule(torch.nn.Module):
def forward(self, x, grid):
return torch.ops.aten.cudnn_grid_sampler.default(x, grid)

# Operations expected to be removed in the traced graph after decompositions
expected_ops = {torch.ops.aten.grid_sampler_2d.default}
unexpected_ops = {torch.ops.aten.cudnn_grid_sampler.default}

inputs = [
torch.randn(1, 3, 5, 7, device="cuda"),
torch.randn(1, 5, 7, 2, device="cuda"),
]

exported_program = torch.export.export(TestModule(), tuple(inputs))
fx_graph = exported_program.module()
unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)

self.assertEqual(
len(unexpected_ops_seen),
0,
f"The following unexpected ops were encountered: {unexpected_ops_seen}",
)

self.assertEqual(
len(expected_ops_unseen),
0,
f"The following expected ops were not encountered: {expected_ops_unseen}",
)

torch._dynamo.reset()

# Validate that the results between Torch and Torch-TRT are similar
trt_model = torch_tensorrt.dynamo.compile(
exported_program, inputs, min_block_size=1
)
torch.testing.assert_close(
trt_model(*inputs),
fx_graph(*inputs),
rtol=RTOL,
atol=ATOL,
msg="Cudnn_grid_sampler TRT outputs don't match with the original model.",
)


if __name__ == "__main__":
run_tests()
Loading