feat: Support weight streaming (#3111)

keehyuna authored Oct 23, 2024
1 parent dad195b commit 92bf700
Showing 20 changed files with 680 additions and 6 deletions.
38 changes: 38 additions & 0 deletions core/runtime/TRTEngine.cpp
@@ -89,6 +89,12 @@ TRTEngine::TRTEngine(
cuda_engine = make_trt(rt->deserializeCudaEngine(serialized_engine.c_str(), serialized_engine.size()));
TORCHTRT_CHECK((cuda_engine.get() != nullptr), "Unable to deserialize the TensorRT engine");

if (get_streamable_device_memory_budget() > 0) {
int64_t budget_bytes = get_automatic_device_memory_budget();
LOG_DEBUG("Weight streaming budget set to " << budget_bytes << "B");
cuda_engine->setWeightStreamingBudgetV2(budget_bytes);
}

exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK((exec_ctx.get() != nullptr), "Unable to create TensorRT execution context");

@@ -258,6 +264,38 @@ void TRTEngine::set_profiling_paths() {
cuda_graph_debug_path = std::filesystem::path{profile_path_prefix + "/" + name + "_cudagraph.dot"}.string();
}

int64_t TRTEngine::get_device_memory_budget() {
return cuda_engine->getWeightStreamingBudgetV2();
}

bool TRTEngine::set_device_memory_budget(int64_t budget) {
  // Recreate the execution context because the weight streaming budget cannot be modified while a context is active.
if (exec_ctx.get() != nullptr) {
exec_ctx.reset();
}
if (profile_execution) {
trt_engine_profiler.reset();
}
bool result = cuda_engine->setWeightStreamingBudgetV2(budget);
exec_ctx = make_trt(cuda_engine->createExecutionContext());
TORCHTRT_CHECK(
(exec_ctx.get() != nullptr),
"Unable to recreate TensorRT execution context after setting new device memory budget");
if (profile_execution) {
enable_profiling();
}
return result;
}

// Returns 0 if BuilderFlag::kWEIGHT_STREAMING is unset during engine building.
int64_t TRTEngine::get_streamable_device_memory_budget() {
return cuda_engine->getStreamableWeightsSize();
}

int64_t TRTEngine::get_automatic_device_memory_budget() {
return cuda_engine->getWeightStreamingAutomaticBudget();
}

std::string TRTEngine::to_str() const {
// clang-format off
std::stringstream ss;
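The new set_device_memory_budget() drops and recreates the execution context because TensorRT does not allow the weight streaming budget to change while a context is alive. From Python, this path is reached through the runtime weight streaming context manager introduced in this change. A minimal sketch of that round trip, assuming a module trt_model that was already compiled with enable_weight_streaming=True (not constructed here):

import torch_tensorrt

# Adjust the weight streaming budget of an already-compiled module (hypothetical `trt_model`).
with torch_tensorrt.runtime.weight_streaming(trt_model) as ws_ctx:
    # Total bytes of streamable weights reported by the engine
    total = ws_ctx.total_device_budget
    # Assigning the property calls into TRTEngine::set_device_memory_budget, which resets
    # the execution context, applies the new budget, and recreates the context.
    ws_ctx.device_budget = total // 2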
4 changes: 4 additions & 0 deletions core/runtime/TRTEngine.h
@@ -71,6 +71,10 @@ struct TRTEngine : torch::CustomClassHolder {
std::string get_engine_layer_info();
void dump_engine_layer_info_to_file(const std::string& path);
void dump_engine_layer_info();
int64_t get_device_memory_budget();
bool set_device_memory_budget(int64_t budget);
int64_t get_streamable_device_memory_budget();
int64_t get_automatic_device_memory_budget();
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
static const char BINDING_DELIM = '%';

6 changes: 6 additions & 0 deletions core/runtime/register_jit_hooks.cpp
@@ -86,6 +86,12 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
.def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
.def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
.def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
.def_property(
"device_memory_budget",
&TRTEngine::get_device_memory_budget,
&TRTEngine::set_device_memory_budget)
.def_property("streamable_device_memory_budget", &TRTEngine::get_streamable_device_memory_budget)
.def_property("automatic_device_memory_budget", &TRTEngine::get_automatic_device_memory_budget)
.def_pickle(
[](const c10::intrusive_ptr<TRTEngine>& self) -> std::vector<std::string> {
// Serialize TensorRT engine
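These def_property bindings expose the budget accessors on the TorchScript-registered engine class; the Python runtime API reads and writes them under the hood. As a rough illustration only (the engine handle below is hypothetical; the supported entry point is torch_tensorrt.runtime.weight_streaming):

# Hedged sketch of the registered properties, given some engine handle `engine`.
total_streamable = engine.streamable_device_memory_budget  # 0 if WEIGHT_STREAMING was not set at build time
auto_budget = engine.automatic_device_memory_budget        # budget TensorRT would pick automatically
engine.device_memory_budget = auto_budget                  # setter recreates the execution context
current_budget = engine.device_memory_budget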
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -66,6 +66,7 @@ Tutorials
* :ref:`converter_overloading`
* :ref:`custom_kernel_plugins`
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`

.. toctree::
:caption: Tutorials
@@ -82,6 +83,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/converter_overloading
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example

Dynamo Frontend
----------------
174 changes: 174 additions & 0 deletions examples/dynamo/weight_streaming_example.py
@@ -0,0 +1,174 @@
"""
.. _weight_streaming_example:
Weight Streaming
=======================
Weight streaming in TensorRT is a powerful feature designed to overcome GPU memory limitations
when working with large models. It enables running models larger than available GPU memory
by streaming weight data from host (CPU) memory to GPU memory during inference.
Streaming larger amounts of memory will likely result in lower performance. But if
streaming weights allows the user to run larger batch sizes and it can lead to higher throughput.
This increased throughput can sometimes outweigh the slowdown caused by streaming weights.
The optimal amount of memory to stream varies depending on the specific model and hardware.
Experimenting with different memory limits can help find the best balance between streaming
overhead and batch size benefits.
This example uses a pre-trained Llama-2 model and show how to use weight streaming feature with
Torch-TensorRT.
1. compile option - build trt engine with weight streaming feature
2. runtime api - weight streaming budget control by context manager
"""

# %%
# Imports and Model Definition
# ----------------------------------

import copy
import timeit

import numpy as np
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
from utils import export_llm


def time_generate(model, inputs, output_seq_length, iterations=10):
"""
Measure the time for generating a sentence over certain number of iterations
"""
# We only support single input (B x seq_len) for LLMs now
input_seq = inputs[0]
with torch.no_grad():
timings = []
for _ in range(iterations):
start_time = timeit.default_timer()
inputs_copy = copy.copy(input_seq)
# Greedy decoding of the model. This generates up to output_seq_length tokens.
while inputs_copy.shape[1] <= output_seq_length:
outputs = model(inputs_copy)
logits = outputs.logits
next_token_logits = logits[:, -1, :]
next_tokens = torch.argmax(next_token_logits, dim=-1)
inputs_copy = torch.cat([inputs_copy, next_tokens[:, None]], dim=-1)
torch.cuda.synchronize()
end_time = timeit.default_timer()
timings.append(end_time - start_time)

times = np.array(timings)
time_mean_ms = np.mean(times) * 1000

return time_mean_ms


# Load the LLaMA-2 model
DEVICE = torch.device("cuda:0")
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
model = AutoModelForCausalLM.from_pretrained(
llama_path, use_cache=False, attn_implementation="eager"
).eval()

# Set input and output sequence lengths
isl = 128
osl = 256

# Create random input tensors
input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
# Convert the model to half precision (FP16)
model = model.half()
# Exports the LLM model into an ExportedProgram with dynamic shapes.
llama2_ep = export_llm(model, input_tensors[0], max_seq_len=osl)

# %%
# Compiler option
# ----------------------------------
#
# The enable_weight_streaming=True and use_explicit_typing=True options are required to build
# the engine with the weight streaming feature. use_explicit_typing=True creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_, and only float32 precision is allowed in the enabled_precisions option.
#

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.dynamo.compile(
llama2_ep,
inputs=input_tensors,
enabled_precisions={torch.float32},
truncate_double=True,
device=DEVICE,
use_explicit_typing=True,
enable_weight_streaming=True,
)

# Warm up for 3 iterations
_ = time_generate(trt_model, input_tensors, osl, 3)

# %%
# Running with automatic budget size
# ----------------------------------
#
# Once the enable_weight_streaming compile option is specified, an automatic budget size is configured.
# This automatic size may not always provide the optimal solution because the automatically determined
# budget lacks insight into the user's specific memory constraints and usage patterns.

# Weight streaming context to get current weight budget information
weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
# Measure the mean latency of the model with weight streaming
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of current weight budget used
weight_budget_pct = (
weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
)
print(
f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)

# %%
# Running with weight streaming context manager
# ----------------------------------------------
#
# The weight streaming budget can be limited by using the weight streaming context manager.
# The permissible range for the budget size is from 0 to ctx.total_device_budget.
# 0 means maximum memory savings occur by using the minimum amount of device memory; a value
# equal to ctx.total_device_budget disables weight streaming.
# If multiple TRT engines are created, budgets are distributed proportionally.

# Use a context manager for weight streaming
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
# Get the total size of streamable weights in the engine
streamable_budget = weight_streaming_ctx.total_device_budget

# Scenario 1: Automatic weight streaming budget
# Get the automatically determined weight streaming budget
requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
# Set the device budget to the automatically determined value
weight_streaming_ctx.device_budget = requested_budget
# Measure the mean latency with automatic budget
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of the weight budget used
weight_budget_pct = (
weight_streaming_ctx.device_budget
/ weight_streaming_ctx.total_device_budget
* 100
)
print(
f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)

# Scenario 2: Manual 10% weight streaming budget
# Set the budget to 10% of the total streamable weights
requested_budget = int(streamable_budget * 0.1)
weight_streaming_ctx.device_budget = requested_budget
# Measure the mean latency with 10% budget
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of the weight budget used
weight_budget_pct = (
weight_streaming_ctx.device_budget
/ weight_streaming_ctx.total_device_budget
* 100
)
print(
f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)
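The example exercises the automatic budget and a manual 10% budget. As the comments above note, weight streaming can also be disabled at runtime by granting the full streamable size; a short sketch under the same assumptions as the example (trt_model, input_tensors, osl, and time_generate as defined above):

# Sketch: disable streaming by setting the budget to the total streamable size
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    weight_streaming_ctx.device_budget = weight_streaming_ctx.total_device_budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    print(f"Weight streaming disabled. mean latency = {mean_latency} ms")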
10 changes: 10 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -90,6 +90,7 @@ def compile(
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -162,6 +163,7 @@ def compile(
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool): Enable weight streaming.
**kwargs: Any,
Returns:
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -215,6 +217,10 @@ def compile(
This flag inserts casts around matmul layers and ensures TensorRT executes the matmul layers in FP16 with FP32 accumulation."
)

if enable_weight_streaming and not use_explicit_typing:
raise AssertionError(
"When enable_weight_streaming is enabled, it requires use_explicit_typing to be set to True"
)
# Aliasing inputs to arg_inputs for better understanding
if not arg_inputs and not inputs:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
@@ -291,6 +297,7 @@ def compile(
"reuse_cached_engines": reuse_cached_engines,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
"enable_weight_streaming": enable_weight_streaming,
}

settings = CompilationSettings(**compilation_options)
@@ -549,6 +556,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
) -> bytes:
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -609,6 +617,7 @@ def convert_exported_program_to_serialized_trt_engine(
timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool): Enable weight streaming.
Returns:
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
"""
@@ -684,6 +693,7 @@ def convert_exported_program_to_serialized_trt_engine(
"timing_cache_path": timing_cache_path,
"use_explicit_typing": use_explicit_typing,
"use_fp32_acc": use_fp32_acc,
"enable_weight_streaming": enable_weight_streaming,
}

settings = CompilationSettings(**compilation_options)
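The flag is also plumbed into convert_exported_program_to_serialized_trt_engine, so a standalone serialized engine can be built with streamable weights. A hedged sketch (the toy model and input shapes are placeholders; as enforced above, use_explicit_typing must accompany enable_weight_streaming):

import torch
import torch_tensorrt


class ToyModel(torch.nn.Module):  # placeholder model for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.linear(x)


model = ToyModel().eval().cuda()
example_input = torch.randn(1, 16).cuda()
exported = torch.export.export(model, (example_input,))

serialized_engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported,
    inputs=[example_input],
    enabled_precisions={torch.float32},
    use_explicit_typing=True,      # required: weight streaming needs a strongly typed network
    enable_weight_streaming=True,  # sets BuilderFlag.WEIGHT_STREAMING during the build
)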
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -42,6 +42,7 @@
CUSTOM_ENGINE_CACHE = None
USE_EXPLICIT_TYPING = False
USE_FP32_ACC = False
ENABLE_WEIGHT_STREAMING = False


def default_device() -> Device:
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -14,6 +14,7 @@
DLA_SRAM_SIZE,
DRYRUN,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
@@ -82,6 +83,7 @@ class CompilationSettings:
reuse_cached_engines (bool): Whether to load the compiled TRT engines from storage
use_strong_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
enable_weight_streaming (bool): Enable weight streaming.
"""

enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -118,6 +120,7 @@ class CompilationSettings:
reuse_cached_engines: bool = REUSE_CACHED_ENGINES
use_explicit_typing: bool = USE_EXPLICIT_TYPING
use_fp32_acc: bool = USE_FP32_ACC
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -130,6 +133,7 @@ class CompilationSettings:
"make_refittable",
"engine_capability",
"hardware_compatible",
"enable_weight_streaming",
)


3 changes: 3 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -305,6 +309,9 @@ def _populate_trt_builder_config(
if tactic_sources is not None:
builder_config.set_tactic_sources(tactic_sources=tactic_sources)

if self.compilation_settings.enable_weight_streaming:
builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)

return builder_config

def _create_timing_cache(
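For reference, the interpreter change corresponds to the following raw TensorRT builder-config usage, a sketch assuming TensorRT 10+, where weight streaming is paired with a strongly typed network:

import tensorrt as trt

# Equivalent raw TensorRT setup for a weight-streamable engine (sketch only).
logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
)
config = builder.create_builder_config()
config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)  # mirrors the line added to _populate_trt_builder_config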
Loading
