Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grid_sample #3340

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open

Fix grid_sample #3340

wants to merge 2 commits into from

Conversation

HolyWu
Copy link
Contributor

@HolyWu HolyWu commented Dec 31, 2024

Description

This PR fixes two issues in grid_sample.

1. PyTorch defines interpolation mode enum as "bilinear"=0 and "nearest"=1. But the converter impl has 0 and 1 reversed, causing discrepancy between Torch and Torch-TRT.

import os

import torch
import torch.nn.functional as F
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        return F.grid_sample(x, grid, mode="bilinear", align_corners=False)


with torch.inference_mode():
    model = MyModule().eval().cuda()

    inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")]

    trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1)

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
    print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return (grid_sampler_2d,)
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.grid_sampler_2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Detected support for 1 operators out of 1 in subgraph.
INFO:torch_tensorrt.dynamo._compiler:Partitioning the graph via the fast partitioner
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Number of TensorRT-Accelerated Engines Generated: 1
DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
Supported Nodes:
- torch.ops.aten.grid_sampler_2d.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._adjacency_partitioner:
All Nodes Supported

DEBUG:torch_tensorrt.dynamo._compiler:Updated metadata for node: _run_on_acc_0 with its corresponding submodule outputs
DEBUG:torch_tensorrt.dynamo._compiler:Converting submodule: _run_on_acc_0
 Input shapes: [(1, 3, 224, 224), (1, 224, 224, 2)]
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler_2d : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler_2d.default](args = (%x, %grid, 0, 0, False), kwargs = {})
    return grid_sampler_2d
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node x (kind: x, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: x [shape=[1, 3, 224, 224], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node x [x] (Inputs: () | Outputs: (x: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node grid (kind: grid, args: ())
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Adding input to in-progress INetwork: grid [shape=[1, 224, 224, 2], dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node grid [grid] (Inputs: () | Outputs: (grid: (1, 224, 224, 2)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /grid_sampler_2d (kind: aten.grid_sampler_2d.default, args: ('x <Node>', 'grid <Node>', '0 <int>', '0 <int>', 'False <bool>'))
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Converter options for aten.grid_sampler_2d.default: 1
DEBUG:torch_tensorrt.dynamo.conversion._ConverterRegistry:Selecting converter option 0 for converting aten.grid_sampler_2d.default
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node /grid_sampler_2d [aten.grid_sampler_2d.default] (Inputs: (x: (1, 3, 224, 224)@torch.float32, grid: (1, 224, 224, 2)@torch.float32, 0, 0, False) | Outputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node output (kind: output, args: ('grid_sampler_2d <Node>',))
DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Marking output output0 [shape=(1, 3, 224, 224), dtype=DataType.FLOAT]
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converted node output [output] (Inputs: (grid_sampler_2d: (1, 3, 224, 224)@torch.float32) | Outputs: (output: ))
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT INetwork construction elapsed time: 0:00:00.001957
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Not found cached TRT engines. Start building engine.
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:00.120683
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 12180 bytes of Memory
DEBUG: [Torch-TensorRT] - Deserializing Device Info: 0%8%9%0%NVIDIA GeForce RTX 4060 Ti
DEBUG: [Torch-TensorRT] - Deserialized Device Info: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Target Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
DEBUG: [Torch-TensorRT] - Setting Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU) as active device
INFO: [Torch-TensorRT] - Loaded engine size: 0 MiB
DEBUG: [Torch-TensorRT] - Deserialization required 3953 microseconds.
DEBUG: [Torch-TensorRT] - Total per-runner device persistent memory is 0
DEBUG: [Torch-TensorRT] - Total per-runner host persistent memory is 80
DEBUG: [Torch-TensorRT] - Allocated device scratch memory of size 0
DEBUG: [Torch-TensorRT] - - Runner scratch: 0 bytes
INFO: [Torch-TensorRT] - [MemUsageChange] TensorRT-managed allocation in IExecutionContext creation: CPU +0, GPU +0, now: CPU 0, GPU 0 (MiB)
DEBUG: [Torch-TensorRT] - CUDA lazy loading is enabled.
DEBUG: [Torch-TensorRT] - Input binding name: x has TensorRT binding index: 0, Torch binding index: 0
DEBUG: [Torch-TensorRT] - Input binding name: grid has TensorRT binding index: 1, Torch binding index: 1
DEBUG: [Torch-TensorRT] - Output binding name: output0 has TensorRT binding index: 2, Torch binding index: 2
DEBUG: [Torch-TensorRT] - Torch-TensorRT TensorRT Engine:
  Name: _run_on_acc_0_engine
  Inputs: [
    id: 0
      name: x
      shape: [1, 3, 224, 224]
      dtype: Float
    id: 1
      name: grid
      shape: [1, 224, 224, 2]
      dtype: Float
  ]
  Outputs: [
    id: 0
      name: output0
      shape: [1, 3, 224, 224]
      dtype: Float
  ]
  Device: Device(ID: 0, Name: NVIDIA GeForce RTX 4060 Ti, SM Capability: 8.9, Type: GPU)
  Hardware Compatibility: Disabled
  Target Platform: windows_x86_64

DEBUG:torch_tensorrt.dynamo._DryRunTracker:
++++++++++++++++++++++++++++++++++++++++++++++++++ Dry-Run Results for Graph ++++++++++++++++++++++++++++++++++++++++++++++++++

The graph consists of 1 Total Operators, of which 1 operators are supported, 100.0% coverage

Compiled with: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=True, workspace_size=0, min_block_size=1, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='C:\\Users\\HolyWu\\AppData\\Local\\Temp\\torch_tensorrt_engine_cache\\timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=False, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False)

  Graph Structure:

   Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32]
    ...
    TRT Engine #1 - Submodule name: _run_on_acc_0
     Engine Inputs: List[Tensor: (1, 3, 224, 224)@float32, Tensor: (1, 224, 224, 2)@float32]
     Number of Operators in Engine: 1
     Engine Outputs: List[Tensor: (1, 3, 224, 224)@float32]
    ...
   Outputs: List[Tensor: (1, 3, 224, 224)@float32]

  ------------------------- Aggregate Stats -------------------------

   Average Number of Operators per TRT Engine: 1.0
   Most Operators in a TRT Engine: 1

  ********** Recommendations **********

   - For minimal graph segmentation, select min_block_size=1 which would generate 1 TRT engine(s)
   - The current level of graph segmentation is equivalent to selecting min_block_size=1 which generates 1 TRT engine(s)
DEBUG: [Torch-TensorRT] - Attempting to run engine (ID: _run_on_acc_0_engine); Hardware Compatible: 0
DEBUG: [Torch-TensorRT] - Input shape changed None -> (1,3,224,224)(1,224,224,2)
DEBUG: [Torch-TensorRT] - Input Name: x Shape: [1, 3, 224, 224]
DEBUG: [Torch-TensorRT] - Input Name: grid Shape: [1, 224, 224, 2]
DEBUG: [Torch-TensorRT] - Output Name: output0 Shape: [1, 3, 224, 224]
Traceback (most recent call last):
  File "C:\Users\HolyWu\Downloads\test.py", line 25, in <module>
    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
  File "C:\Python312\Lib\site-packages\torch\testing\_comparison.py", line 1530, in assert_close
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 69917 / 150528 (46.4%)
Greatest absolute difference: 3.082557439804077 at index (0, 2, 213, 105) (up to 0.005 allowed)
Greatest relative difference: 280988.21875 at index (0, 1, 184, 207) (up to 0.005 allowed)

2. PyTorch dispatches grid_sampler to cudnn_grid_sampler when mode="bilinear", padding_mode="zeros", align_corners=True, causing graph breaks due to unsupported node.

import os

import torch
import torch.nn.functional as F
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
        return F.grid_sample(x, grid, mode="bilinear", align_corners=True)


with torch.inference_mode():
    model = MyModule().eval().cuda()

    inputs = [torch.randn(1, 3, 224, 224, device="cuda"), torch.randn(1, 224, 224, 2, device="cuda")]

    trt_model = torch_tensorrt.compile(model, "dynamo", inputs, debug=True, min_block_size=1)

    torch.testing.assert_close(trt_model(*inputs), model(*inputs), rtol=5e-3, atol=5e-3)
    print("assert_close passed")
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.grid_sampler.default](args = (%x, %grid, 0, 0, True), kwargs = {})
    return (grid_sampler,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %grid : [num_users=1] = placeholder[target=grid]
    %cudnn_grid_sampler : [num_users=1] = call_function[target=torch.ops.aten.cudnn_grid_sampler.default](args = (%x, %grid), kwargs = {})
    return (cudnn_grid_sampler,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Unsupported or Excluded Nodes:
- torch.ops.aten.cudnn_grid_sampler.default + Operator Count: 1

WARNING:torch_tensorrt.dynamo._compiler:0 supported operations detected in subgraph containing 1 computational nodes. Skipping this subgraph, since min_block_size was detected to be 1
assert_close passed

Type of change

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Dec 31, 2024
@github-actions github-actions bot requested a review from gs-olive December 31, 2024 18:04
Copy link
Collaborator

@peri044 peri044 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

py/torch_tensorrt/dynamo/conversion/impl/grid.py Outdated Show resolved Hide resolved
@peri044 peri044 requested review from apbose and removed request for gs-olive January 2, 2025 22:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla signed component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants