Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add toggle for fallback to Inductor #2301

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
REQUIRE_FULL_COMPILATION = False
FALLBACK_TO_INDUCTOR = True


def default_device() -> Device:
Expand Down
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_tensorrt.dynamo._defaults import (
DEBUG,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
FALLBACK_TO_INDUCTOR,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
Expand Down Expand Up @@ -42,6 +43,8 @@ class CompilationSettings:
truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
fallback_to_inductor (bool): Whether to fallback to inductor on Torch-TRT Compilation Errors.
Is overridden by pass_through_build_failures.
"""

precision: torch.dtype = PRECISION
Expand All @@ -59,3 +62,4 @@ class CompilationSettings:
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
device: Device = field(default_factory=default_device)
require_full_compilation: bool = REQUIRE_FULL_COMPILATION
fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR
22 changes: 20 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,17 @@ def _pretraced_backend(
settings=settings,
)
return trt_compiled
except AssertionError:
except (AssertionError, RuntimeError):
if not settings.pass_through_build_failures:
logger.warning(
"TRT conversion failed on the subgraph. See trace above. "
+ "Returning GraphModule forward instead.",
exc_info=True,
)
return gm.forward
if settings.fallback_to_inductor:
pass
else:
return gm
else:
logger.critical(
"Halting compilation on build failure since "
Expand All @@ -100,3 +103,18 @@ def _pretraced_backend(
+ "specify pass_through_build_failures=False."
)
raise

# If Inductor fallback is desired, attempt model compilation with inductor
try:
inductor_compiled = torch._inductor.compile(
gm,
sample_inputs,
)
return inductor_compiled
except (AssertionError, RuntimeError):
logger.warning(
"Inductor compilation failed on the subgraph. See trace above. "
+ "Returning GraphModule forward instead.",
exc_info=True,
)
return gm
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
DEBUG,
DEVICE,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
FALLBACK_TO_INDUCTOR,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
Expand Down Expand Up @@ -69,6 +70,7 @@ def compile(
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
fallback_to_inductor: bool = FALLBACK_TO_INDUCTOR,
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
Expand All @@ -84,7 +86,7 @@ def compile(
"max_aux_streams, version_compatible, optimization_level, "
"torch_executed_ops, pass_through_build_failures, "
"use_fast_partitioner, enable_experimental_decompositions, "
"require_full_compilation}"
"require_full_compilation, fallback_to_inductor}"
)

if not isinstance(inputs, collections.abc.Sequence):
Expand Down Expand Up @@ -130,6 +132,7 @@ def compile(
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
"require_full_compilation": require_full_compilation,
"fallback_to_inductor": fallback_to_inductor,
}

settings = CompilationSettings(**compilation_options)
Expand Down