diff --git a/py/torch_tensorrt/dynamo/lowering/_decompositions.py b/py/torch_tensorrt/dynamo/lowering/_decompositions.py index 5fcccb5c77..2418a7dbe4 100644 --- a/py/torch_tensorrt/dynamo/lowering/_decompositions.py +++ b/py/torch_tensorrt/dynamo/lowering/_decompositions.py @@ -400,6 +400,18 @@ def log_softmax_decomposition( ) +@register_torch_trt_decomposition( + torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS +) +def full_like_decomposition(*args, **kwargs) -> torch.Tensor: + input = args[0] + shape = args[0].shape + fill_value = args[1] + kwargs["dtype"] = input.dtype + kwargs["device"] = to_torch_device(default_device()) + return torch.full(shape, fill_value, dtype=kwargs["dtype"], device=kwargs["device"]) + + def get_decompositions( enable_experimental_decompositions: bool = False, ) -> Dict[OpOverload, Callable[[Any], Any]]: diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index a472e52d03..2c7d6d1c02 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -14,7 +14,6 @@ from .remove_detach import remove_detach from .remove_input_alias_fixing_clones import remove_input_alias_fixing_clones from .repair_input_as_output import repair_input_as_output -from .replace_full_like_with_full import replace_full_like_with_full from .replace_max_pool_with_indices import replace_max_pool_with_indices from .view_to_reshape import view_to_reshape @@ -27,7 +26,6 @@ lower_linear, fuse_prims_broadcast, replace_max_pool_with_indices, - replace_full_like_with_full, view_to_reshape, remove_assert_scalar, accumulate_fp32_matmul, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py b/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py deleted file mode 100644 index 9303aafd60..0000000000 --- a/py/torch_tensorrt/dynamo/lowering/passes/replace_full_like_with_full.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging - -import torch -import torch.fx -from torch_tensorrt.dynamo._defaults import default_device -from torch_tensorrt.dynamo._settings import CompilationSettings -from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( - clean_up_graph_after_modifications, -) -from torch_tensorrt.dynamo.utils import to_torch_device - -logger = logging.getLogger(__name__) - - -def replace_full_like_with_full( - gm: torch.fx.GraphModule, settings: CompilationSettings -) -> torch.fx.GraphModule: - """Replace full_like nodes with equivalent full nodes""" - modified_graph = False - - for node in gm.graph.nodes: - if node.target == torch.ops.aten.full_like.default: - modified_graph = True - - # Extract arguments from full_like - input_tensor = node.args[0] - fill_value = node.args[1] - input_dtype = None - input_shape = None - input_device = to_torch_device(default_device()) - if "val" in input_tensor.meta: - input_dtype = input_tensor.meta["val"].dtype - input_device = input_tensor.meta["val"].device - input_shape = list(input_tensor.meta["val"].shape) - elif "tensor_meta" in input_tensor.meta: - input_dtype = input_tensor.meta["tensor_meta"].dtype - input_shape = list(input_tensor.meta["tensor_meta"].shape) - - # There's no memory format argument for torch.full. - # Set the input_device and dtype correspondingly. - new_kwargs = {} - for key, val in node.kwargs.items(): - if key != "memory_format": - new_kwargs[key] = val - new_kwargs["device"] = input_device - new_kwargs["dtype"] = input_dtype - # Replace full_like with full, using the shape as a list - input_nodes = (input_shape, fill_value) - with gm.graph.inserting_after(node): - full_node = gm.graph.call_function( - torch.ops.aten.full.default, - args=input_nodes, - kwargs=new_kwargs, - ) - full_node.meta = node.meta - - node.replace_all_uses_with(full_node) - gm.graph.erase_node(node) - - if modified_graph: - gm = clean_up_graph_after_modifications(gm) - - return gm