From 45c2726822d70c6d4a4ef27ba6f828669ce482b3 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Thu, 7 Sep 2023 14:57:46 -0700 Subject: [PATCH] feat: Add toggle for fallback to Inductor --- py/torch_tensorrt/dynamo/_defaults.py | 1 + py/torch_tensorrt/dynamo/_settings.py | 4 ++++ py/torch_tensorrt/dynamo/backend/backends.py | 22 ++++++++++++++++++-- py/torch_tensorrt/dynamo/compile.py | 5 ++++- 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 103b5f7792..7503d6ee6e 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -15,6 +15,7 @@ USE_FAST_PARTITIONER = True ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False REQUIRE_FULL_COMPILATION = False +FALLBACK_TO_INDUCTOR = True def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index bae819eab5..fac00ad548 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -6,6 +6,7 @@ from torch_tensorrt.dynamo._defaults import ( DEBUG, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + FALLBACK_TO_INDUCTOR, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, OPTIMIZATION_LEVEL, @@ -42,6 +43,8 @@ class CompilationSettings: truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32 enable_experimental_decompositions (bool): Whether to enable all core aten decompositions or only a selected subset of them + fallback_to_inductor (bool): Whether to fallback to inductor on Torch-TRT Compilation Errors. + Is overridden by pass_through_build_failures. """ precision: torch.dtype = PRECISION @@ -59,3 +62,4 @@ class CompilationSettings: enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS device: Device = field(default_factory=default_device) require_full_compilation: bool = REQUIRE_FULL_COMPILATION + fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 2ba9f4d754..5f7dfcc235 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -83,14 +83,17 @@ def _pretraced_backend( settings=settings, ) return trt_compiled - except AssertionError: + except (AssertionError, RuntimeError): if not settings.pass_through_build_failures: logger.warning( "TRT conversion failed on the subgraph. See trace above. " + "Returning GraphModule forward instead.", exc_info=True, ) - return gm.forward + if settings.fallback_to_inductor: + pass + else: + return gm else: logger.critical( "Halting compilation on build failure since " @@ -100,3 +103,18 @@ def _pretraced_backend( + "specify pass_through_build_failures=False." ) raise + + # If Inductor fallback is desired, attempt model compilation with inductor + try: + inductor_compiled = torch._inductor.compile( + gm, + sample_inputs, + ) + return inductor_compiled + except (AssertionError, RuntimeError): + logger.warning( + "Inductor compilation failed on the subgraph. See trace above. " + + "Returning GraphModule forward instead.", + exc_info=True, + ) + return gm diff --git a/py/torch_tensorrt/dynamo/compile.py b/py/torch_tensorrt/dynamo/compile.py index 2a407552d8..12ae912997 100644 --- a/py/torch_tensorrt/dynamo/compile.py +++ b/py/torch_tensorrt/dynamo/compile.py @@ -15,6 +15,7 @@ DEBUG, DEVICE, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + FALLBACK_TO_INDUCTOR, MAX_AUX_STREAMS, MIN_BLOCK_SIZE, OPTIMIZATION_LEVEL, @@ -69,6 +70,7 @@ def compile( use_python_runtime: bool = USE_PYTHON_RUNTIME, use_fast_partitioner: bool = USE_FAST_PARTITIONER, enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS, + fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR, **kwargs: Any, ) -> torch.fx.GraphModule: if debug: @@ -84,7 +86,7 @@ def compile( "max_aux_streams, version_compatible, optimization_level, " "torch_executed_ops, pass_through_build_failures, " "use_fast_partitioner, enable_experimental_decompositions, " - "require_full_compilation}" + "require_full_compilation, fallback_to_inductor}" ) if not isinstance(inputs, collections.abc.Sequence): @@ -130,6 +132,7 @@ def compile( "use_fast_partitioner": use_fast_partitioner, "enable_experimental_decompositions": enable_experimental_decompositions, "require_full_compilation": require_full_compilation, + "fallback_to_inductor": fallback_to_inductor, } settings = CompilationSettings(**compilation_options)