Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT] Support Multiple EP Context #23294

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
42 changes: 25 additions & 17 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
return false;
}

int FindCtxNodeInGraph(const GraphViewer& graph_viewer) {
// Assumes there's only 1 context node in this subgraph (graph_viewer)
// Returns index of node
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
LOGS_DEFAULT(VERBOSE) << "*#* context node found at index=" << i;
return i;
}
}
return -1;
}

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
Expand Down Expand Up @@ -64,7 +77,8 @@ void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
Expand Down Expand Up @@ -123,17 +137,10 @@ ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
node_attributes->emplace(ONNX_MODEL_FILENAME, *attr_3);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
graph_build.AddNode(fused_subgraph_name, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto model = new_graph_viewer->CreateModel(*logger);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
return model_build;
}

/*
Expand Down Expand Up @@ -266,11 +273,11 @@ bool IsWeightStrippedEngineCache(std::filesystem::path& engine_cache_path) {
return engine_cache_path.stem().extension().string() == ".stripped";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx) {
if (!ValidateEPCtxNode(graph_viewer, ctx_node_idx)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
Expand Down Expand Up @@ -380,14 +387,14 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph
/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
assert(graph_viewer.GetNode(ctx_node_idx)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(ctx_node_idx);
auto& attrs = node->GetAttributes();

// Show the warning if compute capability is not matched
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
if (attrs.find(COMPUTE_CAPABILITY)!=attrs.end() && attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
// Verify if engine was compiled with ampere+ hardware compatibility enabled
if (model_compute_capability == "80+") {
Expand All @@ -414,4 +421,5 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe

return true;
}

} // namespace onnxruntime
9 changes: 6 additions & 3 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ static const std::string EPCONTEXT_WARNING =
for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
int FindCtxNodeInGraph(const GraphViewer& graph_viewer);

const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path);
ONNX_NAMESPACE::ModelProto* CreateCtxModel(const GraphViewer& graph_viewer,
std::unique_ptr<Model> CreateCtxModel(const GraphViewer& graph_viewer,
const std::string fused_subgraph_name,
const std::string engine_cache_path,
char* engine_data,
size_t size,
Expand Down Expand Up @@ -67,9 +70,9 @@ class TensorRTCacheModelHandler {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);
bool ValidateEPCtxNode(const GraphViewer& graph_viewer, const int ctx_node_idx);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);
Status GetEpContextFromGraph(const GraphViewer& graph_viewer, const int ctx_node_idx);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
Expand Down
120 changes: 89 additions & 31 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2467,15 +2467,48 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
strncpy(model_path_, path_string.c_str(), sizeof(model_path_) - 1);
#endif
model_path_[sizeof(model_path_) - 1] = '\0';

// If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and
// load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation.
// So, simply return the ComputeCapability here.
if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) {
jingyanwangms marked this conversation as resolved.
Show resolved Hide resolved
SubGraph_t supported_node_vector = {{0}, true};
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
return result;
if (GraphHasCtxNode(graph)) {
if (graph.NumberOfNodes() == 1) {
SubGraph_t supported_node_vector = {{0}, true};
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
return result;
} else {
const size_t number_of_ort_nodes = graph.NumberOfNodes();
SubGraphCollection_t supported_node_vectors;
std::vector<long unsigned int> subgraph_indices;
const std::vector<NodeIndex>& node_index = graph.GetNodesInTopologicalOrder(1 /*priority-based topological sort*/);
for (size_t i = 0; i < number_of_ort_nodes; i++) {
const auto& node = graph.GetNode(node_index[i]);
const bool is_context_node = node && !node->OpType().empty() && node->OpType() == "EPContext";
if (is_context_node) {
// Add previous nonempty subgraph
if (subgraph_indices.size() > 0) {
supported_node_vectors.emplace_back(subgraph_indices, true);
}
// Add epcontext node, which is always just 1 node
supported_node_vectors.emplace_back(std::vector<long unsigned int>{i}, true);
subgraph_indices = {};
} else {
subgraph_indices.emplace_back(i);
}
if (i == number_of_ort_nodes - 1 && !is_context_node) {
supported_node_vectors.emplace_back(subgraph_indices, true);
}
}

for (auto supported_node_vector: supported_node_vectors) {
auto subgraph_idx = supported_node_vector.first[0];
std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), subgraph_idx);
result.push_back(ComputeCapability::Create(std::move(sub_graph)));
}
return result;
}

}

// Generate unique kernel name for TRT graph
Expand Down Expand Up @@ -2541,7 +2574,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
if (exclude_set.find(node->OpType()) != exclude_set.end()) {
supported_node = false;
}

if (supported_node) {
if (new_subgraph) {
parser_nodes_vector.emplace_back();
Expand Down Expand Up @@ -2766,7 +2799,7 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil
}

common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
std::vector<NodeComputeInfo>& node_compute_funcs) {
std::vector<NodeComputeInfo>& node_compute_funcs) {
for (auto& fused_node_graph : fused_nodes_and_graphs) {
const GraphViewer& graph_body_viewer = fused_node_graph.filtered_graph;
const Node& fused_node = fused_node_graph.fused_node;
Expand All @@ -2787,19 +2820,24 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

Status status;
if (GraphHasCtxNode(graph_body_viewer)) {
int ctx_node_idx = FindCtxNodeInGraph(graph_body_viewer);
jingyanwangms marked this conversation as resolved.
Show resolved Hide resolved
if (ctx_node_idx >= 0) {
status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer,
fused_node,
ctx_node_idx,
input_map,
output_map,
node_compute_funcs);
node_compute_funcs
);
} else {
status = CreateNodeComputeInfoFromGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);

}
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
}

return Status::OK();
}

Expand Down Expand Up @@ -3328,15 +3366,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (engine_cache_enable_ && engine_hw_compatible_) {
compute_capability_hw_compat = "80+";
}
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto{CreateCtxModel(graph_body_viewer,
ep_cache_context_attr_,
reinterpret_cast<char*>(serialized_engine->data()),
serialized_engine->size(),
ep_context_embed_mode_,
compute_capability_hw_compat,
model_path_,
GetLogger())};
DumpCtxModel(model_proto.get(), ctx_model_path_);
auto trt_ep_context_model_ptr = CreateCtxModel(graph_body_viewer,
fused_node.Name(),
ep_cache_context_attr_,
reinterpret_cast<char*>(serialized_engine->data()),
serialized_engine->size(),
ep_context_embed_mode_,
compute_capability_hw_compat,
model_path_,
GetLogger());
auto& graph = trt_ep_context_model_ptr->MainGraph();
trt_ep_context_models.emplace_back(std::move(trt_ep_context_model_ptr));
}
}
}
Expand Down Expand Up @@ -3434,17 +3474,17 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (engine_cache_enable_ && engine_hw_compatible_) {
compute_capability_hw_compat = "80+";
}
model_proto_.reset(CreateCtxModel(graph_body_viewer,
ep_cache_context_attr_,
nullptr,
0,
ep_context_embed_mode_,
compute_capability_hw_compat,
model_path_,
GetLogger()));
if (ep_context_embed_mode_ == 0) {
DumpCtxModel(model_proto_.get(), ctx_model_path_);
}
auto trt_ep_context_model_ptr = CreateCtxModel(graph_body_viewer,
fused_node.Name(),
ep_cache_context_attr_,
nullptr,
0,
ep_context_embed_mode_,
compute_capability_hw_compat,
model_path_,
GetLogger());

trt_ep_context_models.emplace_back(std::move(trt_ep_context_model_ptr));
}

// Create function state
Expand Down Expand Up @@ -4037,9 +4077,13 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView

Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
const Node& fused_node,
const int ctx_node_idx,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs) {
auto model = graph_body_viewer.CreateModel(*GetLogger());
auto model_proto = model->ToProto();

std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
std::unordered_map<std::string, size_t> input_indexes; // TRT engine input name -> ORT kernel context input index
Expand All @@ -4056,7 +4100,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
onnx_model_bytestream_,
onnx_model_bytestream_size_,
detailed_build_log_);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer, ctx_node_idx);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
Expand Down Expand Up @@ -4360,6 +4404,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con
return Status::OK();
}

const InlinedVector<const Node*> TensorrtExecutionProvider::GetEpContextNodes() const {
InlinedVector<const Node*> ep_context_nodes;
if (!trt_ep_context_models.empty()) {
for (const auto& context_model: trt_ep_context_models) {
const auto& graph = context_model->MainGraph();
for (const auto& node: graph.Nodes()) {
// if (node.IsEpContextNode()) { // Check if it's an EP context node
ep_context_nodes.push_back(node);
}
}
}
return ep_context_nodes;
}

void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const {
auto allocator = allocators[GetOrtDeviceByMemType(OrtMemTypeCPU)];
RegisterCudaStreamHandles(stream_handle_registry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured(int graph_annotation_id) const override;
Status ReplayGraph(int graph_annotation_id) override;
const InlinedVector<const Node*> GetEpContextNodes() const override;

static common::Status RefitEngine(std::string onnx_model_filename,
std::string& onnx_model_folder_path,
Expand Down Expand Up @@ -331,6 +332,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
std::string cache_prefix_;
bool engine_hw_compatible_ = false;
std::string op_types_to_exclude_;
std::vector<std::unique_ptr<Model>> trt_ep_context_models;

// The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH
int32_t trt_version_;
Expand Down Expand Up @@ -567,6 +569,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
*/
Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
const Node& fused_node,
const int ctx_node_idx,
std::unordered_map<std::string, size_t>& input_map,
std::unordered_map<std::string, size_t>& output_map,
std::vector<NodeComputeInfo>& node_compute_funcs);
Expand Down
Loading