Skip to content

Commit

Permalink
Cherrypick: nccl ops multi gpu (#3342)
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Co-authored-by: Naren Dasan <[email protected]>
  • Loading branch information
apbose and narendasan authored Jan 9, 2025
1 parent 91cf21e commit 37cb8a7
Show file tree
Hide file tree
Showing 24 changed files with 1,845 additions and 64 deletions.
34 changes: 34 additions & 0 deletions examples/distributed_inference/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,37 @@ See the examples started with `data_parallel` for more details.
Here we use torch.distributed as an example, but compilation with tensor parallelism is agnostic to the implementation framework as long as the module is properly sharded.

torchrun --nproc_per_node=2 tensor_parallel_llama2.py

3. Tensor parallel distributed inference using nccl ops plugin

apt install libmpich-dev

apt install libopenmpi-dev

#For python3.10

pip install tensorrt-llm

For other Python versions, you need to load `libnvinfer_plugin_tensorrt_llm.so` yourself by pointing the `TRTLLM_PLUGINS_PATH` environment variable at it, e.g. `export TRTLLM_PLUGINS_PATH={lib_path}`. The examples set this variable for you inside `initialize_distributed_env()`; to use a different library location, replace the path there with your own (or remove that assignment and export `TRTLLM_PLUGINS_PATH` yourself).

#then pip install the tensorrt and torch version compatible with installed torchTRT

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py

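#For other python versions, set TRTLLM_PLUGINS_PATH to the location of libnvinfer_plugin_tensorrt_llm.so before running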

4. Tensor parallel distributed llama3 inference using nccl ops plugin

apt install libmpich-dev

apt install libopenmpi-dev

#For python3.10

pip install tensorrt-llm

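For other Python versions, you need to load `libnvinfer_plugin_tensorrt_llm.so` by setting the `TRTLLM_PLUGINS_PATH` environment variable to its location (`export TRTLLM_PLUGINS_PATH={lib_path}`).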

#then pip install the tensorrt and torch version compatible with installed torchTRT

mpirun -n 2 --allow-run-as-root python tensor_parallel_llama3.py
3 changes: 2 additions & 1 deletion examples/distributed_inference/requirement.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
transformers
diffusers
diffusers
tensorrt-llm
67 changes: 67 additions & 0 deletions examples/distributed_inference/tensor_parallel_initialize_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import tensorrt as trt
import torch
import torch.distributed as dist
from torch.distributed._tensor.device_mesh import init_device_mesh


def find_repo_root(max_depth=10):
    """Locate the repository root by walking up from this file's directory.

    A directory counts as the repository root when its listing contains an
    entry named ``MODULE.bazel``.

    Args:
        max_depth: Maximum number of directories to inspect, starting with
            the directory that holds this file and moving one parent up on
            each step.

    Returns:
        The path of the first inspected directory containing ``MODULE.bazel``.

    Raises:
        RuntimeError: If no inspected directory contains ``MODULE.bazel``.
    """
    candidate = os.path.dirname(os.path.realpath(__file__))
    for _ in range(max_depth):
        if "MODULE.bazel" in os.listdir(candidate):
            return candidate
        # Not here: step one level up and try again.
        candidate = os.path.dirname(candidate)

    raise RuntimeError("Could not find repo root")


def initialize_logger(rank, logger_file_name):
    """Attach a per-rank file handler to the root logger and return the logger.

    Args:
        rank: Rank identifier appended to the log file name.
        logger_file_name: Path prefix of the log file; the file written is
            ``<logger_file_name>_<rank>.log`` and is truncated on open.

    Returns:
        The root logger, with its level set to ``INFO`` and the new file
        handler (also at ``INFO``) added.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)

    # One log file per rank so concurrent processes do not share a file.
    log_path = logger_file_name + f"_{rank}.log"
    file_handler = logging.FileHandler(log_path, mode="w")
    file_handler.setLevel(logging.INFO)

    root_logger.addHandler(file_handler)
    return root_logger


# This is required for env initialization since we use mpirun
def initialize_distributed_env(logger_file_name, rank=0, world_size=1, port=29500):
    """Set up a ``torch.distributed`` NCCL process group for an ``mpirun`` launch.

    ``mpirun`` does not export the ``RANK`` / ``WORLD_SIZE`` / ``MASTER_*``
    variables that ``torch.distributed`` reads, so this helper derives them
    from the Open MPI variables (falling back to the arguments) and writes
    them into ``os.environ`` before creating the process group.

    ``RANK`` is taken from the *local* rank and ``MASTER_ADDR`` is fixed to
    ``127.0.0.1``, so the setup is only meaningful for a single node.

    Args:
        logger_file_name: Path prefix passed to ``initialize_logger``.
        rank: Fallback rank used when ``OMPI_COMM_WORLD_LOCAL_RANK`` is unset.
        world_size: Fallback used when ``OMPI_COMM_WORLD_SIZE`` is unset.
        port: Value written to ``MASTER_PORT``.

    Returns:
        A tuple ``(device_mesh, world_size, rank, logger)``.

    Raises:
        RuntimeError: From ``find_repo_root`` when ``TRTLLM_PLUGINS_PATH`` is
            not already set and the repository root cannot be located.
    """
    # Only query the device count when Open MPI did not provide the local
    # rank; the modulo would otherwise fail when no CUDA device is visible
    # even though its result is never used.
    ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
    if ompi_local_rank is not None:
        local_rank = int(ompi_local_rank)
    else:
        local_rank = int(rank % torch.cuda.device_count())
    world_size = int(os.environ.get("OMPI_COMM_WORLD_SIZE", world_size))

    # Set up environment variable to run with mpirun
    os.environ["RANK"] = str(local_rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(port)

    # Respect a plugin path the user already exported; only fall back to the
    # in-repo default location when nothing was provided.
    if not os.environ.get("TRTLLM_PLUGINS_PATH"):
        os.environ["TRTLLM_PLUGINS_PATH"] = os.path.join(
            find_repo_root(), "lib", "libnvinfer_plugin_tensorrt_llm.so"
        )

    # Necessary to assign a device to each rank.
    torch.cuda.set_device(local_rank)

    # We use nccl backend
    dist.init_process_group("nccl")

    # set a manual seed for reproducibility
    torch.manual_seed(1111)

    device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(world_size,))
    rank = device_mesh.get_rank()
    assert (
        rank == local_rank
    ), f"device mesh rank {rank} does not match local rank {local_rank}"
    logger = initialize_logger(rank, logger_file_name)

    # Ensure each rank gets a unique device
    device_id = rank % torch.cuda.device_count()
    torch.cuda.set_device(device_id)

    return device_mesh, world_size, rank, logger
26 changes: 12 additions & 14 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,27 +5,25 @@
import time

import torch
import torch_tensorrt
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

_rank = int(os.environ["RANK"])
_world_size = int(os.environ["WORLD_SIZE"])
tp_size = 2

logger = logging.getLogger()
logger.setLevel(logging.INFO)
fh = logging.FileHandler(f"./tensor_parallel_log_{_rank}.log", mode="w")
fh.setLevel(logging.INFO)
logger.addHandler(fh)
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
import tensorrt_llm

tp_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"

model_args = ModelArgs(
vocab_size=32000,
Expand All @@ -38,7 +36,7 @@
)

with torch.no_grad():
model = ParallelTransformer(model_args, tp_mesh)
model = ParallelTransformer(model_args, device_mesh)
torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
Expand All @@ -53,7 +51,7 @@
"use_python_runtime": True,
"workspace_size": 1 << 33,
"debug": False,
"timing_cache_path": "/opt/file/cache/timing_cache_llama.bin",
"use_aot_joint_export": False,
},
dynamic=False,
)
Expand Down
24 changes: 11 additions & 13 deletions examples/distributed_inference/tensor_parallel_simple_example.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
import os
import sys
import time

import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt
from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)

device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_simple_example"
)
import tensorrt_llm

"""
This example copies some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
Expand All @@ -36,14 +40,7 @@ def forward(self, x):
return x


# create a device mesh based on the given world_size.
_world_size = int(os.environ["WORLD_SIZE"])

device_mesh = init_device_mesh(device_type="cuda", mesh_shape=(_world_size,))
_rank = device_mesh.get_rank()


print(f"Starting PyTorch TP example on rank {_rank}.")
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
Expand Down Expand Up @@ -78,6 +75,7 @@ def forward(self, x):
"enabled_precisions": {torch.float32, torch.float16},
"use_python_runtime": True,
"min_block_size": 1,
"use_aot_joint_export": False,
},
dynamic=False,
)
Expand All @@ -91,9 +89,9 @@ def forward(self, x):
output = tp_model(inp)
end = time.time()
if i == 0:
print(f"Compilation time is {end-start}")
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
print(f"Inference time is {end-start}")
logger.info(f"Inference time is {end-start}")
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
IMMUTABLE_WEIGHTS = True
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
USE_AOT_JOINT_EXPORT = True


def default_device() -> Device:
Expand Down
3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
STRIP_ENGINE_WEIGHTS,
TIMING_CACHE_PATH,
TRUNCATE_DOUBLE,
USE_AOT_JOINT_EXPORT,
USE_EXPLICIT_TYPING,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
Expand Down Expand Up @@ -91,6 +92,7 @@ class CompilationSettings:
enable_weight_streaming (bool): Enable weight streaming.
enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
Expand Down Expand Up @@ -131,6 +133,7 @@ class CompilationSettings:
immutable_weights: bool = IMMUTABLE_WEIGHTS
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
Expand Down
56 changes: 47 additions & 9 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from __future__ import annotations

import functools
import logging
import unittest
from typing import Any, Callable, Sequence

import torch
import torch._dynamo as td
from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.utils import detect_fake_mode
from torch._functorch.aot_autograd import aot_export_joint_simple
from torch_tensorrt.dynamo import CompilationSettings
from torch_tensorrt.dynamo._compiler import compile_module
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
modify_reshape_complex_nodes,
post_lowering,
remove_detach,
remove_sym_nodes,
def aot_torch_tensorrt_aten_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
    """Compile ``gm`` with the TensorRT backend, choosing how AOT tracing is applied.

    The compilation settings and engine cache are parsed from ``kwargs``.
    When ``settings.use_aot_joint_export`` is true, ``_pretraced_backend`` is
    called directly. Otherwise ``_pretraced_backend`` is wrapped with
    ``aot_autograd`` as the forward compiler, using the decomposition table
    from ``get_decompositions`` with the ``aten.detach`` entries removed.

    Args:
        gm: Graph module handed over by the compiler front end.
        sample_inputs: Example inputs for ``gm``.
        **kwargs: Options forwarded to ``parse_dynamo_kwargs``.

    Returns:
        The result of ``_pretraced_backend`` or of the ``aot_autograd``-wrapped
        backend applied to ``gm`` and ``sample_inputs``.
    """
    settings, engine_cache = parse_dynamo_kwargs(kwargs)

    if settings.use_aot_joint_export:
        return _pretraced_backend(gm, sample_inputs, settings, engine_cache)

    logger.debug("Wrapping the backend with aot_autograd\n")
    _pretraced_backend_autograd = functools.partial(
        _pretraced_backend, settings=settings, engine_cache=engine_cache
    )

    # Work on a copy so removing entries never alters a table that
    # get_decompositions may share with other callers.
    decompositions = dict(
        get_decompositions(settings.enable_experimental_decompositions)
    )

    # This is added since detach lowering leads to alias nodes
    # Error - View operation returned a tensor that is the same as the input base tensor
    # torch nop_decompositions in torch/_decomp/decompositions.py
    # The table may be keyed by the op packet or by its individual overloads,
    # so drop both forms.
    detach_packet = torch.ops.aten.detach
    detach_ops = [detach_packet]
    detach_ops.extend(
        getattr(detach_packet, overload_name)
        for overload_name in detach_packet.overloads()
    )
    for detach_op in detach_ops:
        decompositions.pop(detach_op, None)

    # Pass the filtered table; building a fresh one here would bring the
    # detach decomposition back.
    return aot_autograd(
        fw_compiler=_pretraced_backend_autograd,
        decompositions=decompositions,
    )(gm, sample_inputs)


def _pretraced_backend(
Expand Down Expand Up @@ -89,22 +110,39 @@ def _pretraced_backend(
# Remove detach nodes
remove_detach(gm, settings)

complexInputIndices = []
for i, torch_input in enumerate(torch_inputs):
if torch_inputs[i].dtype == torch.complex64:
complexInputIndices.append(i)
torch_input_real = torch_inputs[i].real
torch_input_imaginary = torch_inputs[i].imag
torch_inputs[i] = torch.stack(
(torch_input_real, torch_input_imaginary), dim=-1
)

# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
)
if settings.use_aot_joint_export:
gm = aot_export_joint_simple(
gm,
sample_inputs,
trace_joint=False,
decompositions=get_decompositions(
settings.enable_experimental_decompositions
),
)

logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))

gm = post_lowering(gm, settings)

logger.debug("Lowered Input graph:\n " + str(gm.graph))

if complexInputIndices:
modify_reshape_complex_nodes(gm, complexInputIndices)
logger.debug(
"Input graph after modifying complex nodes:\n " + str(gm.graph)
)

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
Expand Down
8 changes: 7 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from . import aten_ops_converters, ops_evaluators, plugins, prims_ops_converters
from . import (
aten_ops_converters,
custom_ops_converters,
ops_evaluators,
plugins,
prims_ops_converters,
)
from ._conversion import convert_module, interpret_module_to_result
from ._ConversionContext import ConversionContext
from ._ConverterRegistry import * # noqa: F403
Expand Down
Loading

0 comments on commit 37cb8a7

Please sign in to comment.