feat: Add TensorRT support for GNNs #4016

Open · wants to merge 10 commits into base: main
37 changes: 37 additions & 0 deletions .gitlab-ci.yml
@@ -186,6 +186,43 @@ test_exatrkx_python:
- pytest -rFsv -k torch --collect-only
- pytest -rFsv -k gpu-torch # For now only test torch GPU pipeline

build_gnn_tensorrt:
stage: build
image: ghcr.io/acts-project/ubuntu2404_tensorrt:sha-b4f481f@sha256:8887aa00ad4394a53b4ca54968121d8893d537e5daf50805f1dd2030caef78ce
variables:
DEPENDENCY_URL: https://acts.web.cern.ch/ACTS/ci/ubuntu-24.04/deps.$DEPENDENCY_TAG.tar.zst

cache:
key: ccache-${CI_JOB_NAME}-${CI_COMMIT_REF_SLUG}-${CCACHE_KEY_SUFFIX}
fallback_keys:
- ccache-${CI_JOB_NAME}-${CI_DEFAULT_BRANCH}-${CCACHE_KEY_SUFFIX}
when: always
paths:
- ${CCACHE_DIR}

tags:
- docker-gpu-nvidia

script:
- git clone $CLONE_URL src
- cd src
- git checkout $HEAD_SHA
- source CI/dependencies.sh
- cd ..
- mkdir build
- >
cmake -B build -S src
-DACTS_BUILD_PLUGIN_EXATRKX=ON
-DACTS_EXATRKX_ENABLE_TORCH=OFF
-DACTS_EXATRKX_ENABLE_CUDA=ON
-DACTS_EXATRKX_ENABLE_TENSORRT=ON
-DPython_EXECUTABLE=$(which python3)
-DCMAKE_CUDA_ARCHITECTURES="75;86"
- ccache -z
- cmake --build build -- -j6
- ccache -s


build_linux_ubuntu:
stage: build
image: ghcr.io/acts-project/ubuntu2404:63
7 changes: 0 additions & 7 deletions CMakeLists.txt
@@ -425,13 +425,6 @@ if(ACTS_BUILD_PLUGIN_EXATRKX)
else()
message(STATUS "Build Exa.TrkX plugin for CPU only")
endif()
if(NOT (ACTS_EXATRKX_ENABLE_ONNX OR ACTS_EXATRKX_ENABLE_TORCH))
message(
FATAL_ERROR
"When building the Exa.TrkX plugin, at least one of ACTS_EXATRKX_ENABLE_ONNX \
and ACTS_EXATRKX_ENABLE_TORCHSCRIPT must be enabled."
)
endif()
if(ACTS_EXATRKX_ENABLE_TORCH)
find_package(TorchScatter REQUIRED)
endif()
27 changes: 27 additions & 0 deletions Examples/Python/src/ExaTrkXTrackFinding.cpp
@@ -10,6 +10,7 @@
#include "Acts/Plugins/ExaTrkX/ExaTrkXPipeline.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp"
#include "Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp"
#include "Acts/Plugins/ExaTrkX/TorchTruthGraphMetricsHook.hpp"
@@ -112,6 +113,32 @@ void addExaTrkXTrackFinding(Context &ctx) {
}
#endif

#ifdef ACTS_EXATRKX_WITH_TENSORRT
{
using Alg = Acts::TensorRTEdgeClassifier;
using Config = Alg::Config;

auto alg =
py::class_<Alg, Acts::EdgeClassificationBase, std::shared_ptr<Alg>>(
mex, "TensorRTEdgeClassifier")
.def(py::init([](const Config &c, Logging::Level lvl) {
return std::make_shared<Alg>(
c, getDefaultLogger("EdgeClassifier", lvl));
}),
py::arg("config"), py::arg("level"))
.def_property_readonly("config", &Alg::config);

auto c = py::class_<Config>(alg, "Config").def(py::init<>());
ACTS_PYTHON_STRUCT_BEGIN(c, Config);
ACTS_PYTHON_MEMBER(modelPath);
ACTS_PYTHON_MEMBER(selectedFeatures);
ACTS_PYTHON_MEMBER(cut);
ACTS_PYTHON_MEMBER(deviceID);
ACTS_PYTHON_MEMBER(doSigmoid);
ACTS_PYTHON_STRUCT_END();
}
#endif

#ifdef ACTS_EXATRKX_ONNX_BACKEND
{
using Alg = Acts::OnnxMetricLearning;
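For orientation, a minimal sketch of how the new binding could be used from Python, assuming it lands in acts.examples alongside the other ExaTrkX stages (paths and values below are placeholders, not part of this diff):

import acts
import acts.examples

# Configure the TensorRT edge classifier; modelPath points to a
# serialized TensorRT engine file (placeholder path).
cfg = acts.examples.TensorRTEdgeClassifier.Config()
cfg.modelPath = "path/to/gnn.engine"
cfg.selectedFeatures = [0, 1, 2]
cfg.cut = 0.5
cfg.deviceID = 0
cfg.doSigmoid = True

classifier = acts.examples.TensorRTEdgeClassifier(
    config=cfg, level=acts.logging.INFO
)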
14 changes: 14 additions & 0 deletions Plugins/ExaTrkX/CMakeLists.txt
@@ -23,6 +23,20 @@ if(ACTS_EXATRKX_ENABLE_TORCH)
)
endif()

if(ACTS_EXATRKX_ENABLE_TENSORRT)
find_package(TensorRT REQUIRED)
message(STATUS "Found TensorRT ${TensorRT_VERSION}")
target_link_libraries(
ActsPluginExaTrkX
PUBLIC trt::nvinfer trt::nvinfer_plugin
)
target_sources(ActsPluginExaTrkX PRIVATE src/TensorRTEdgeClassifier.cpp)
target_compile_definitions(
ActsPluginExaTrkX
PUBLIC ACTS_EXATRKX_WITH_TENSORRT
)
endif()

target_include_directories(
ActsPluginExaTrkX
PUBLIC
61 changes: 61 additions & 0 deletions Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp
@@ -0,0 +1,61 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#pragma once

#include "Acts/Plugins/ExaTrkX/Stages.hpp"
#include "Acts/Utilities/Logger.hpp"

#include <memory>

#include <torch/torch.h>

namespace nvinfer1 {
class IRuntime;
class ICudaEngine;
class ILogger;
class IExecutionContext;
} // namespace nvinfer1

namespace Acts {

class TensorRTEdgeClassifier final : public Acts::EdgeClassificationBase {
public:
struct Config {
std::string modelPath;
std::vector<int> selectedFeatures = {};
float cut = 0.21;
int deviceID = 0;
bool useEdgeFeatures = false;
bool doSigmoid = true;
};

TensorRTEdgeClassifier(const Config &cfg,
std::unique_ptr<const Logger> logger);
~TensorRTEdgeClassifier();

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return torch::kCUDA; };
⚠️ Potential issue

Return a proper torch::Device object.

The current implementation returns the device type enumerator torch::kCUDA rather than an explicitly constructed torch::Device instance. Apply this diff:

-  torch::Device device() const override { return torch::kCUDA; };
+  torch::Device device() const override { return torch::Device(torch::kCUDA); }


private:
std::unique_ptr<const Acts::Logger> m_logger;
const auto &logger() const { return *m_logger; }

Config m_cfg;

std::unique_ptr<nvinfer1::IRuntime> m_runtime;
std::unique_ptr<nvinfer1::ICudaEngine> m_engine;
std::unique_ptr<nvinfer1::ILogger> m_trtLogger;
std::unique_ptr<nvinfer1::IExecutionContext> m_context;
};

} // namespace Acts
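Given the operator() signature above, driving the stage directly looks roughly like the following sketch. The tensors and execution context are placeholders; in the examples framework the classifier is wired into an Acts pipeline instead:

// Sketch only: direct invocation with placeholder CUDA tensors.
Acts::TensorRTEdgeClassifier::Config cfg;
cfg.modelPath = "model.engine";  // hypothetical engine file

Acts::TensorRTEdgeClassifier classifier(
    cfg, Acts::getDefaultLogger("TensorRT", Acts::Logging::INFO));

// nodeFeatures: [nNodes, nFeatures], edgeIndex: [2, nEdges], both on CUDA.
auto [nodes, edges, edgeFeats, scores] = classifier(
    std::any{nodeFeatures}, std::any{edgeIndex}, std::any{edgeFeatures},
    execContext);  // execContext must carry a valid CUDA stream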
190 changes: 190 additions & 0 deletions Plugins/ExaTrkX/src/TensorRTEdgeClassifier.cpp
@@ -0,0 +1,190 @@
// This file is part of the ACTS project.
//
// Copyright (C) 2016 CERN for the benefit of the ACTS project
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at https://mozilla.org/MPL/2.0/.

#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>
#include <filesystem>
#include <fstream>

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntimeBase.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include "printCudaMemInfo.hpp"

using namespace torch::indexing;

namespace {

class TensorRTLogger : public nvinfer1::ILogger {
std::unique_ptr<const Acts::Logger> m_logger;

public:
TensorRTLogger(Acts::Logging::Level lvl)
: m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {}

void log(Severity severity, const char *msg) noexcept override {
const auto &logger = *m_logger;
switch (severity) {
case Severity::kVERBOSE:
ACTS_DEBUG(msg);
break;
case Severity::kINFO:
ACTS_INFO(msg);
break;
case Severity::kWARNING:
ACTS_WARNING(msg);
break;
case Severity::kERROR:
ACTS_ERROR(msg);
break;
case Severity::kINTERNAL_ERROR:
ACTS_FATAL(msg);
break;
}
}
};

} // namespace

namespace Acts {

TensorRTEdgeClassifier::TensorRTEdgeClassifier(
const Config &cfg, std::unique_ptr<const Logger> _logger)
: m_logger(std::move(_logger)),
m_cfg(cfg),
m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
assert(status);
Comment on lines +67 to +68

⚠️ Potential issue

Implement proper error handling for plugin initialization.

Relying on assert(status) is insufficient, especially in release builds where assertions compile away. Check the return status and handle failure gracefully:

 auto status = initLibNvInferPlugins(m_trtLogger.get(), "");
-assert(status);
+if (!status) {
+  ACTS_ERROR("Failed to initialize TensorRT plugins.");
+  // Handle the error appropriately, e.g. by throwing an exception.
+}


std::size_t fsize =
std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
std::vector<char> engineData(fsize);

ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);

std::ifstream engineFile(m_cfg.modelPath);
engineFile.read(engineData.data(), fsize);

Comment on lines +76 to +78

⚠️ Potential issue

Check for file opening and reading errors.

Ensure the model file opens successfully before reading, and handle I/O failures to prevent undefined behavior:

 std::ifstream engineFile(m_cfg.modelPath);
+if (!engineFile.is_open()) {
+  ACTS_ERROR("Failed to open model file: " << m_cfg.modelPath);
+  // Handle the error appropriately.
+}
 engineFile.read(engineData.data(), fsize);
+if (!engineFile) {
+  ACTS_ERROR("Failed to read model data from: " << m_cfg.modelPath);
+  // Handle the error appropriately.
+}

m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));

m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));

m_context.reset(m_engine->createExecutionContext());
Comment on lines +81 to +83

⚠️ Potential issue

Handle deserialization errors.

Check that m_engine was successfully created after deserialization to avoid a null pointer dereference in createExecutionContext:

 m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
+if (!m_engine) {
+  ACTS_ERROR("Failed to deserialize CUDA engine.");
+  // Handle the error appropriately.
+}
}
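Taken together, the three constructor suggestions above could be folded into one validated sequence. A sketch only, assuming throwing std::runtime_error is acceptable at this point and opening the file in binary mode (which the current code omits); the PR does not adopt this as-is:

// Sketch: constructor body with the review suggestions consolidated.
// Needs <stdexcept> in addition to the existing includes.
if (!initLibNvInferPlugins(m_trtLogger.get(), "")) {
  throw std::runtime_error("Failed to initialize TensorRT plugins");
}

std::ifstream engineFile(m_cfg.modelPath, std::ios::binary);
if (!engineFile) {
  throw std::runtime_error("Failed to open engine file: " + m_cfg.modelPath);
}
std::vector<char> engineData(std::filesystem::file_size(m_cfg.modelPath));
if (!engineFile.read(engineData.data(),
                     static_cast<std::streamsize>(engineData.size()))) {
  throw std::runtime_error("Failed to read engine file: " + m_cfg.modelPath);
}

m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
m_engine.reset(
    m_runtime->deserializeCudaEngine(engineData.data(), engineData.size()));
if (!m_engine) {
  throw std::runtime_error("Failed to deserialize CUDA engine");
}
m_context.reset(m_engine->createExecutionContext());
if (!m_context) {
  throw std::runtime_error("Failed to create execution context");
}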

TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {}

auto milliseconds = [](const auto &a, const auto &b) {
return std::chrono::duration<double, std::milli>(b - a).count();
};

struct TimePrinter {
const char *name;
decltype(std::chrono::high_resolution_clock::now()) t0, t1;
TimePrinter(const char *n) : name(n) {
t0 = std::chrono::high_resolution_clock::now();
}
~TimePrinter() {
std::cout << name << ": " << milliseconds(t0, t1) << std::endl;
}
};

#if 0
#define TIME_BEGIN(name) TimePrinter printer##name(#name);
#define TIME_END(name) \
printer##name.t1 = std::chrono::high_resolution_clock::now();
#else
#define TIME_BEGIN(name) /*nothing*/
#define TIME_END(name) /*nothing*/
#endif

std::tuple<std::any, std::any, std::any, std::any>
TensorRTEdgeClassifier::operator()(std::any inNodeFeatures,
std::any inEdgeIndex,
std::any inEdgeFeatures,
const ExecutionContext &execContext) {
decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5;
t0 = std::chrono::high_resolution_clock::now();

c10::cuda::CUDAStreamGuard(execContext.stream.value());

auto nodeFeatures =
std::any_cast<torch::Tensor>(inNodeFeatures).to(torch::kCUDA);

auto edgeIndex = std::any_cast<torch::Tensor>(inEdgeIndex).to(torch::kCUDA);
ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex});

auto edgeFeatures =
std::any_cast<torch::Tensor>(inEdgeFeatures).to(torch::kCUDA);
ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures});

t1 = std::chrono::high_resolution_clock::now();

m_context->setInputShape(
"x", nvinfer1::Dims2{nodeFeatures.size(0), nodeFeatures.size(1)});
m_context->setTensorAddress("x", nodeFeatures.data_ptr());

m_context->setInputShape(
"edge_index", nvinfer1::Dims2{edgeIndex.size(0), edgeIndex.size(1)});
m_context->setTensorAddress("edge_index", edgeIndex.data_ptr());

m_context->setInputShape(
"edge_attr", nvinfer1::Dims2{edgeFeatures.size(0), edgeFeatures.size(1)});
m_context->setTensorAddress("edge_attr", edgeFeatures.data_ptr());

void *outputMem{nullptr};
std::size_t outputSize = edgeIndex.size(1) * sizeof(float);
cudaMalloc(&outputMem, outputSize);
⚠️ Potential issue

Check the return value of cudaMalloc.

Ensure the GPU memory allocation succeeded before proceeding:

-cudaMalloc(&outputMem, outputSize);
+cudaError_t err = cudaMalloc(&outputMem, outputSize);
+if (err != cudaSuccess) {
+  ACTS_ERROR("cudaMalloc failed: " << cudaGetErrorString(err));
+  // Handle the error appropriately.
+}

m_context->setTensorAddress("output", outputMem);

t2 = std::chrono::high_resolution_clock::now();

{
auto stream = execContext.stream.value().stream();
auto status = m_context->enqueueV3(stream);
cudaStreamSynchronize(stream);
Comment on lines +155 to +156

⚠️ Potential issue

Verify the inference execution status.

Check the return value of enqueueV3 to confirm that the inference was enqueued successfully:

 auto status = m_context->enqueueV3(stream);
+if (!status) {
+  ACTS_ERROR("Inference execution failed.");
+  // Handle the error appropriately.
+}
 cudaStreamSynchronize(stream);

ACTS_VERBOSE("TensorRT output status: " << std::boolalpha << status);
}

t3 = std::chrono::high_resolution_clock::now();

auto scores = torch::from_blob(
outputMem, edgeIndex.size(1), 1, [](void *ptr) { cudaFree(ptr); },
torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32));

scores.sigmoid_();

ACTS_VERBOSE("Size after classifier: " << scores.size(0));
ACTS_VERBOSE("Slice of classified output:\n"
<< scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9));
printCudaMemInfo(logger());

torch::Tensor mask = scores > m_cfg.cut;
torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask});

scores = scores.masked_select(mask);
ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1));
printCudaMemInfo(logger());

t4 = std::chrono::high_resolution_clock::now();

ACTS_DEBUG("Time anycast: " << milliseconds(t0, t1));
ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2));
ACTS_DEBUG("Time inference: " << milliseconds(t2, t3));
ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4));

return {nodeFeatures, edgesAfterCut, edgeFeatures, scores};
}

} // namespace Acts