Commit
chore: Access user settings within the lowering system (#3245)
peri044 authored Oct 21, 2024
1 parent 4be64a8 commit aa37194
Showing 24 changed files with 194 additions and 110 deletions.
33 changes: 18 additions & 15 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -242,16 +242,6 @@ def compile(
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(exported_program)}"
         )
-    exported_program = pre_export_lowering(exported_program)
-    exported_program = exported_program.run_decompositions(
-        get_decompositions(enable_experimental_decompositions)
-    )
-    gm = exported_program.module()
-    logger.debug("Input graph: " + str(gm.graph))
-
-    # Apply lowering on the graph module
-    gm = post_lowering(gm, use_fp32_acc=use_fp32_acc)
-    logger.debug("Lowered Input graph: " + str(gm.graph))
-
     engine_cache = None
     if cache_built_engines or reuse_cached_engines:
@@ -305,6 +295,19 @@ def compile(
 
     settings = CompilationSettings(**compilation_options)
     logger.info("Compilation Settings: %s\n", settings)
+
+    exported_program = pre_export_lowering(exported_program, settings)
+    exported_program = exported_program.run_decompositions(
+        get_decompositions(enable_experimental_decompositions)
+    )
+
+    gm = exported_program.module()
+    logger.debug("Input graph: " + str(gm.graph))
+
+    # Apply lowering on the graph module
+    gm = post_lowering(gm, settings)
+    logger.debug("Lowered Input graph: " + str(gm.graph))
+
     trt_gm = compile_module(
         gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
     )
@@ -683,7 +686,10 @@ def convert_exported_program_to_serialized_trt_engine(
         "use_fp32_acc": use_fp32_acc,
     }
 
-    exported_program = pre_export_lowering(exported_program)
+    settings = CompilationSettings(**compilation_options)
+    logger.info("Compilation Settings: %s\n", settings)
+
+    exported_program = pre_export_lowering(exported_program, settings)
     # Decompose the exported program
     exported_program = exported_program.run_decompositions(
         get_decompositions(enable_experimental_decompositions)
@@ -692,12 +698,9 @@ def convert_exported_program_to_serialized_trt_engine(
     logger.debug("Input graph: " + str(gm.graph))
 
     # Apply lowering on the graph module
-    gm = post_lowering(gm)
+    gm = post_lowering(gm, settings)
     logger.debug("Lowered Input graph: " + str(gm.graph))
 
-    settings = CompilationSettings(**compilation_options)
-    logger.info("Compilation Settings: %s\n", settings)
-
     # Configure user compilation settings to converters.
     CONVERTERS.set_compilation_settings(settings)
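
Taken together, the compile() changes move settings construction ahead of all lowering, so user options flow into every pass instead of one option (use_fp32_acc) being forwarded as a loose keyword. A minimal sketch of the new calling convention; the toy MatMulModel, shapes, and dtypes are illustrative, not part of this commit:

    import torch
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.lowering import (
        get_decompositions,
        post_lowering,
        pre_export_lowering,
    )


    class MatMulModel(torch.nn.Module):
        def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return a @ b


    inputs = (
        torch.randn(4, 8, dtype=torch.float16),
        torch.randn(8, 4, dtype=torch.float16),
    )
    exported_program = torch.export.export(MatMulModel().eval(), inputs)

    # Build the settings object first, then thread it through each lowering
    # stage, mirroring the reordered body of compile()
    settings = CompilationSettings(use_fp32_acc=True)
    exported_program = pre_export_lowering(exported_program, settings)
    exported_program = exported_program.run_decompositions(
        get_decompositions(settings.enable_experimental_decompositions)
    )
    gm = exported_program.module()
    gm = post_lowering(gm, settings)  # settings-aware passes run inside this call
    print(gm.graph)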
6 changes: 3 additions & 3 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -292,15 +292,15 @@ def refit_module_weights(
         raise AssertionError(
             f"Input graph should be an ExportedProgram but got type {type(new_weight_module)}"
         )
-    new_weight_module = pre_export_lowering(new_weight_module)
+    new_weight_module = pre_export_lowering(new_weight_module, settings)
     new_weight_module = new_weight_module.run_decompositions(
         get_decompositions(settings.enable_experimental_decompositions)
     )
     new_gm = new_weight_module.module()
     logger.debug("Input graph: " + str(new_gm.graph))
     # Apply lowering on the graph module
 
-    new_gm = post_lowering(new_gm)
+    new_gm = post_lowering(new_gm, settings)
 
     logger.info("Compilation Settings: %s\n", settings)
 
@@ -397,7 +397,7 @@ def refit_module_weights(
         if isinstance(compiled_submodule, PythonTorchTensorRTModule):
             engine = compiled_submodule.engine
         elif isinstance(compiled_submodule, TorchTensorRTModule):
-            engine_info = compiled_submodule.engine.__getstate__()[0]  # type: ignore[index]
+            engine_info = compiled_submodule.engine.__getstate__()[0]
             engine = get_engine_from_encoded_engine(
                 engine_info[ENGINE_IDX], runtime
             )
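
The refit path now hands the settings recovered from the compiled module to both lowering entry points, so a refitted module is lowered the same way the original was. For reference, a hedged usage sketch of the public API this feeds (model2, inputs, and trt_gm are placeholders, and trt_gm is assumed to have been compiled with refitting enabled):

    import torch
    import torch_tensorrt

    # Re-export the model carrying the new weights, then refit in place of a
    # full recompile
    exp_program2 = torch.export.export(model2, inputs)
    new_trt_gm = torch_tensorrt.dynamo.refit_module_weights(
        compiled_module=trt_gm,
        new_weight_module=exp_program2,
        arg_inputs=inputs,
    )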
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -77,16 +77,16 @@ def _pretraced_backend(
         with unittest.mock.patch.object(
             fake_mode, "allow_non_fake_inputs", True
         ), fake_mode:
-            repair_input_aliasing(gm)
+            repair_input_aliasing(gm, settings)
 
             # Remove sym_int placeholders and inputs
-            remove_sym_nodes(gm)
+            remove_sym_nodes(gm, settings)
             torch_inputs = [
                 input for input in sample_inputs if isinstance(input, torch.Tensor)
             ]
 
             # Remove detach nodes
-            remove_detach(gm)
+            remove_detach(gm, settings)
 
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
@@ -100,7 +100,7 @@ def _pretraced_backend(
 
     logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
 
-    gm = post_lowering(gm, use_fp32_acc=settings.use_fp32_acc)
+    gm = post_lowering(gm, settings)
 
     logger.debug("Lowered Input graph:\n " + str(gm.graph))
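
The torch.compile backend gets the same treatment: _pretraced_backend already owns a CompilationSettings built from the backend options, and now passes it to every pre-AOT repair pass and to post_lowering. A short hedged sketch of driving this path (option keys are assumed to mirror CompilationSettings fields, as the compilation_options dict above suggests):

    import torch
    import torch_tensorrt  # noqa: F401  # registers the "torch_tensorrt" backend

    model = torch.nn.Linear(8, 8).half().eval().cuda()
    x = torch.randn(4, 8, dtype=torch.float16, device="cuda")

    compiled = torch.compile(
        model,
        backend="torch_tensorrt",
        options={"use_fp32_acc": True, "min_block_size": 1},
    )
    compiled(x)  # lowering runs on first call, with settings visible to each pass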
5 changes: 1 addition & 4 deletions py/torch_tensorrt/dynamo/lowering/__init__.py
@@ -3,7 +3,4 @@
     torch_enabled_decompositions,
 )
 from ._decompositions import get_decompositions  # noqa: F401
-from ._remove_sym_nodes import remove_sym_nodes
-from ._repair_input_aliasing import repair_input_aliasing
-from .passes import post_lowering, pre_export_lowering
-from .passes.remove_detach import remove_detach
+from .passes import *
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/passes/__init__.py
@@ -1 +1,3 @@
 from ._aten_lowering_pass import *
+from .remove_sym_nodes import remove_sym_nodes
+from .repair_input_aliasing import repair_input_aliasing
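
With this re-export, remove_sym_nodes and repair_input_aliasing (previously private modules under lowering/) resolve from the passes package alongside the pass-manager entry points:

    from torch_tensorrt.dynamo.lowering.passes import (
        remove_sym_nodes,
        repair_input_aliasing,
    )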
19 changes: 12 additions & 7 deletions py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py
@@ -1,7 +1,8 @@
 import logging
-from typing import Any, Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence, Union
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 
 from .accumulate_fp32_matmul import accumulate_fp32_matmul
 from .constant_folding import constant_fold
@@ -29,6 +30,7 @@
         replace_full_like_with_full,
         view_to_reshape,
         remove_assert_scalar,
+        accumulate_fp32_matmul,
     ]
 )
 
@@ -91,25 +93,28 @@ def _remove_lowering_pass(*, index: int) -> None:
     return
 
 
-def post_lowering(gm: torch.fx.GraphModule, **kwargs: Any) -> torch.fx.GraphModule:
+def post_lowering(
+    gm: torch.fx.GraphModule, settings: CompilationSettings = CompilationSettings()
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_POST_LOWERING_PASSES}"
     )
-    gm = ATEN_POST_LOWERING_PASSES(gm)
-    if kwargs.get("use_fp32_acc", False):
-        gm = accumulate_fp32_matmul(gm)
+    gm = ATEN_POST_LOWERING_PASSES(gm, settings)
 
     return gm
 
 
-def pre_export_lowering(ep: torch.export.ExportedProgram) -> torch.fx.GraphModule:
+def pre_export_lowering(
+    ep: torch.export.ExportedProgram,
+    settings: CompilationSettings = CompilationSettings(),
+) -> torch.fx.GraphModule:
     """Applies the lowering passes to a graph module after torch.export/ torch.compile and their decompositions, returns the modified GraphModule"""
     logging.debug(
         f"Invoking DynamoPassManager and applying lowering passes: {ATEN_PRE_LOWERING_PASSES}"
     )
     gm = ep.graph_module
-    gm = ATEN_PRE_LOWERING_PASSES(gm)
+    gm = ATEN_PRE_LOWERING_PASSES(gm, settings)
     return ep
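
Since ATEN_POST_LOWERING_PASSES and ATEN_PRE_LOWERING_PASSES are now invoked with the settings object, every registered pass takes the two-argument (gm, settings) signature seen throughout this commit. A hedged sketch of a custom pass written against the new API (log_mm_nodes is a made-up example; _aten_lowering_pass is the registration decorator defined in this module):

    import torch
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
        _aten_lowering_pass,
    )


    @_aten_lowering_pass
    def log_mm_nodes(
        gm: torch.fx.GraphModule, settings: CompilationSettings
    ) -> torch.fx.GraphModule:
        # Passes can read user settings directly instead of loose kwargs
        if settings.use_fp32_acc:
            count = sum(
                1 for n in gm.graph.nodes if n.target == torch.ops.aten.mm.default
            )
            print(f"{count} aten.mm nodes will accumulate in FP32")
        return gm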
84 changes: 50 additions & 34 deletions py/torch_tensorrt/dynamo/lowering/passes/accumulate_fp32_matmul.py
@@ -1,49 +1,65 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
 
 logger = logging.getLogger(__name__)
 
 
-def accumulate_fp32_matmul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def accumulate_fp32_matmul(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Replace a matmul layer with fp32 accumulation nodes"""
-    matmul_targets = [
-        torch.ops.aten.mm.default,
-        torch.ops.aten.bmm.default,
-        torch.ops.aten.addmm.default,
-    ]
-    matmul_nodes = [node for node in gm.graph.nodes if node.target in matmul_targets]
-    for matmul_node in matmul_nodes:
-        # Prior to the matmul node, insert a cast to the 32-bit float32 node
-        node_inputs = matmul_node.all_input_nodes
-
-        for node_input in node_inputs:
-            with gm.graph.inserting_before(matmul_node):
-                node_32bit = gm.graph.call_function(
+    if settings.use_fp32_acc:
+        matmul_targets = [
+            torch.ops.aten.mm.default,
+            torch.ops.aten.bmm.default,
+            torch.ops.aten.addmm.default,
+        ]
+
+        matmul_nodes = [
+            node for node in gm.graph.nodes if node.target in matmul_targets
+        ]
+        for matmul_node in matmul_nodes:
+            # Prior to the matmul node, insert a cast to the 32-bit float32 node
+            node_inputs = matmul_node.all_input_nodes
+
+            for node_input in node_inputs:
+                with gm.graph.inserting_before(matmul_node):
+                    node_32bit = gm.graph.call_function(
+                        torch.ops.aten._to_copy.default,
+                        args=(node_input,),
+                        kwargs={"dtype": torch.float32},
+                    )
+
+                    # Replace the input to matmul node with new 32-bit cast node
+                    matmul_node.replace_input_with(node_input, node_32bit)
+
+            # Add a cast back to original precision
+            with gm.graph.inserting_after(matmul_node):
+                node_orig_precision = gm.graph.call_function(
                     torch.ops.aten._to_copy.default,
-                    args=(node_input,),
-                    kwargs={"dtype": torch.float32},
+                    args=(matmul_node,),
+                    kwargs={"dtype": torch.float16},
                 )
+                matmul_node.replace_all_uses_with(
+                    node_orig_precision, propagate_meta=False
+                )
+                # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
+                node_orig_precision.replace_input_with(
+                    node_orig_precision.all_input_nodes[0], matmul_node
+                )
+
+        gm = clean_up_graph_after_modifications(gm)
+        logger.debug(
+            f"Graph after enabling matmul layers to use FP32 accumulation:\n{gm.graph}"
+        )
+    else:
+        logger.debug(
+            "Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings"
+        )
 
-                # Replace the input to matmul node with new 32-bit cast node
-                matmul_node.replace_input_with(node_input, node_32bit)
-
-        # Add a cast back to original precision
-        with gm.graph.inserting_after(matmul_node):
-            node_orig_precision = gm.graph.call_function(
-                torch.ops.aten._to_copy.default,
-                args=(matmul_node,),
-                kwargs={"dtype": torch.float16},
-            )
-            matmul_node.replace_all_uses_with(node_orig_precision, propagate_meta=False)
-            # This is a hack. replace_all_uses_with isn't working here. It complains node_orig_precision is already being used before created.
-            node_orig_precision.replace_input_with(
-                node_orig_precision.all_input_nodes[0], matmul_node
-            )
-
-    gm = clean_up_graph_after_modifications(gm)
-    logger.debug(f"Graph after changing matmuls to use FP32 accumulation:\n{gm.graph}")
     return gm
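
The pass is now self-gating: it checks settings.use_fp32_acc itself instead of relying on post_lowering to call it conditionally, which is what lets it sit in the regular ATEN_POST_LOWERING_PASSES list. A minimal sketch of invoking it directly on a toy exported module (illustrative only; the fp16 to fp32 to fp16 cast bracket around the matmul should be visible in the printed graph):

    import torch
    from torch_tensorrt.dynamo._settings import CompilationSettings
    from torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul import (
        accumulate_fp32_matmul,
    )


    class MM(torch.nn.Module):
        def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return torch.mm(a, b)


    inputs = (
        torch.randn(2, 3, dtype=torch.float16),
        torch.randn(3, 2, dtype=torch.float16),
    )
    gm = torch.export.export(MM(), inputs).run_decompositions({}).module()

    # With use_fp32_acc=False the graph is returned untouched; with True each
    # aten.mm input is cast to fp32 and the output is cast back to fp16
    gm = accumulate_fp32_matmul(gm, CompilationSettings(use_fp32_acc=True))
    print(gm.graph)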
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py
@@ -3,6 +3,7 @@
 
 import torch
 from torch_tensorrt._utils import sanitized_torch_version
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
@@ -19,7 +20,9 @@
 
 
 @torch.utils._python_dispatch._disable_current_modes()  # type: ignore
-def constant_fold(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def constant_fold(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Adapted from:
     https://github.com/pytorch/pytorch/blob/3a79621c9dce17f77fbddc06aab21f6bc477f313/torch/_inductor/freezing.py#L178-L197
py/torch_tensorrt/dynamo/lowering/passes/fuse_prims_broadcast.py
@@ -1,6 +1,7 @@
 import logging
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
@@ -9,7 +10,9 @@
 
 
 # TODO: Add relevant prims to this fusion
-def fuse_prims_broadcast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def fuse_prims_broadcast(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Fuses prim nodes which are effectively the ATen equivalents with keep_dim=True"""
     modified_graph = False
5 changes: 4 additions & 1 deletion py/torch_tensorrt/dynamo/lowering/passes/lower_linear.py
@@ -2,14 +2,17 @@
 from typing import Callable, Tuple
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
 )
 
 logger = logging.getLogger(__name__)
 
 
-def lower_linear(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+def lower_linear(
+    gm: torch.fx.GraphModule, settings: CompilationSettings
+) -> torch.fx.GraphModule:
     """Replace aten.linear with an equivalent implementation which can be easily converted to TRT"""
     orig, replacement = linear_replacement()
 
py/torch_tensorrt/dynamo/lowering/passes/lower_scaled_dot_product_attention.py
@@ -4,6 +4,7 @@
 from typing import Callable, Sequence, Tuple
 
 import torch
+from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion.aten_ops_converters import args_bounds_check
 from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
     clean_up_graph_after_modifications,
@@ -17,7 +18,7 @@
 
 
 def lower_scaled_dot_product_attention(
-    gm: torch.fx.GraphModule,
+    gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
     """Replace specific versions of scaled_dot_product_attention with an equivalent
     implementation which can be easily converted to TRT
[Diffs for the remaining changed files are not shown.]