Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor parallel documentation #3359

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ Tutorials
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`pre_allocated_output_example`
* :ref:`tensor_parallel_llama`

.. toctree::
:caption: Tutorials
Expand All @@ -87,6 +88,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/pre_allocated_output_example
tutorials/_rendered_examples/dynamo/tensor_parallel_llama

Dynamo Frontend
----------------
Expand Down
69 changes: 66 additions & 3 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
# Taken and modified pytorch lightening
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
"""
.. _tensor_parallel_llama:

Torch distributed example for llama3-7B model
======================================================

As model sizes are increasing, large models with billions of parameters are trained with many GPUs, where regular data parallel training is no longer possible. In this example, we illustrate the Llama3-7B model inference using Torch-TensorRT backend, split across multiple GPUs using a form of model parallelism called Tensor Parallelism. We make use of Pytorch Distributed Tensor Parallelism Module. Please refer to these tutorials- https://pytorch.org/tutorials/intermediate/TP_tutorial.html and https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning?section=featured"""

# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import logging
import os
import time

import torch

# %%
# Pytorch Tensor Parallel APIs offer set of module level primitives(ParallelStyle) to configure the sharding of tensors in each layer of the model
# ParallelTransformer creates the parallelize_plan for the FeedForward layer of the model
from llama3_model import ModelArgs, ParallelTransformer

from tensor_parallel_initialize_dist import initialize_distributed_env
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
Expand All @@ -14,11 +30,24 @@
checkpoint_wrapper,
)

# %%
# Initialize the distributed environment
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Depending on the inputs/outputs sharded DTensors layout specified above, proper communication operations are required to transform DTensor layouts
# eg operations: allreduce, allgather, reduce_gather
# NCCL operations enable these operations.
# The below API does the following
# Initialize the communicators and the distributed environment
# Sets the path for the TRT-LLM plugin .so path which is required for the NCCL operations in Torch-TRT backend. Please note that if you are in python3.10 environment, `import tensorrt_llm` should be enough
# Initialize the logger. eg: In case of 2 GPUs, the log files are `./tensor_parallel_llama3_0.log` and `./tensor_parallel_llama3_1.log`
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"./tensor_parallel_llama3"
)
# Import should be after initialization of the TRT-LLM plugin .so path
import tensorrt_llm

# %%
# Model initialization with torch distributed parallel plan
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
Expand All @@ -36,7 +65,39 @@
)

with torch.no_grad():
# The plan is
#plan = {
# "attention": PrepareModuleInput(
# input_layouts=(Shard(1), None),
# desired_input_layouts=(Replicate(), None),
# ),
# "attention.wq": ColwiseParallel(),
# "attention.wk": ColwiseParallel(),
# "attention.wv": ColwiseParallel(),
# "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
# "attention_norm": SequenceParallel(),
# "feed_forward": PrepareModuleInput(
# input_layouts=(Shard(1),),
# desired_input_layouts=(Replicate(),),
# ),
# "feed_forward.w1": ColwiseParallel(),
# "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
# "feed_forward.w3": ColwiseParallel(),
# "ffn_norm": SequenceParallel(),
#}

model = ParallelTransformer(model_args, device_mesh)

# %%
# Model inference with Torch-TensorRT backend
# -------------------------------------------
# When we compile the distributed model using Torch-TensorRT backend, pytorch distributed libraries create the sharded model
# on multiple GPUs and the communicator operations are used for proper communication. In the above,
# `ColwiseParallel` and `RowwiseParallel` shard the attention layers in the column or row fashion.
# `SequenceParallel` performs sharded computations of the normalization layer
# `PrepareModuleInput` configures the model input with proper communication operations
# The NCCL operations used in the distributed backend is handled by the TensorRT-LLM NCCL plugins, which causes no graph breaks now

torch.manual_seed(0)
inp = torch.randint(32000, (8, 256), device="cuda")
python_result = model(inp)
Expand All @@ -62,9 +123,11 @@
output = model(inp)
end = time.time()
if i == 0:
# Logging the Compilation time
logger.info(f"Compilation time is {end-start}")
assert (
python_result - output
).std() < 0.01, "Compilation result is not correct."
elif _rank == 0:
# Logging the inference time
logger.info(f"Inference time is {end-start}")
Loading