-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Add TensorRT support for GNNs #4016
base: main
Are you sure you want to change the base?
Changes from all commits
bd45373
6d2b0a6
f5a819a
fadbfd3
963d14f
09ce4b2
45ffd7b
b983ba9
574aa9a
d70d26f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
// This file is part of the ACTS project. | ||
// | ||
// Copyright (C) 2016 CERN for the benefit of the ACTS project | ||
// | ||
// This Source Code Form is subject to the terms of the Mozilla Public | ||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||
// file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||
|
||
#pragma once | ||
|
||
#include "Acts/Plugins/ExaTrkX/Stages.hpp" | ||
#include "Acts/Utilities/Logger.hpp" | ||
|
||
#include <memory> | ||
|
||
#include <torch/torch.h> | ||
|
||
// Forward declarations of the TensorRT interfaces stored below, so this
// header does not need to include the TensorRT headers itself.
namespace nvinfer1 {
class ICudaEngine;
class IExecutionContext;
class ILogger;
class IRuntime;
}  // namespace nvinfer1
|
||
namespace Acts { | ||
|
||
class TensorRTEdgeClassifier final : public Acts::EdgeClassificationBase { | ||
public: | ||
struct Config { | ||
std::string modelPath; | ||
std::vector<int> selectedFeatures = {}; | ||
float cut = 0.21; | ||
int deviceID = 0; | ||
bool useEdgeFeatures = false; | ||
bool doSigmoid = true; | ||
}; | ||
|
||
TensorRTEdgeClassifier(const Config &cfg, | ||
std::unique_ptr<const Logger> logger); | ||
~TensorRTEdgeClassifier(); | ||
|
||
std::tuple<std::any, std::any, std::any, std::any> operator()( | ||
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, | ||
const ExecutionContext &execContext = {}) override; | ||
|
||
Config config() const { return m_cfg; } | ||
torch::Device device() const override { return torch::kCUDA; }; | ||
|
||
private: | ||
std::unique_ptr<const Acts::Logger> m_logger; | ||
const auto &logger() const { return *m_logger; } | ||
|
||
Config m_cfg; | ||
|
||
std::unique_ptr<nvinfer1::IRuntime> m_runtime; | ||
std::unique_ptr<nvinfer1::ICudaEngine> m_engine; | ||
std::unique_ptr<nvinfer1::ILogger> m_trtLogger; | ||
std::unique_ptr<nvinfer1::IExecutionContext> m_context; | ||
}; | ||
|
||
} // namespace Acts |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,190 @@ | ||||||||||||||
// This file is part of the ACTS project. | ||||||||||||||
// | ||||||||||||||
// Copyright (C) 2016 CERN for the benefit of the ACTS project | ||||||||||||||
// | ||||||||||||||
// This Source Code Form is subject to the terms of the Mozilla Public | ||||||||||||||
// License, v. 2.0. If a copy of the MPL was not distributed with this | ||||||||||||||
// file, You can obtain one at https://mozilla.org/MPL/2.0/. | ||||||||||||||
|
||||||||||||||
#include "Acts/Plugins/ExaTrkX/TensorRTEdgeClassifier.hpp"

#include "Acts/Plugins/ExaTrkX/detail/Utils.hpp"

#include <chrono>
#include <filesystem>
#include <fstream>
#include <stdexcept>
#include <string>

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntimeBase.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>

#include "printCudaMemInfo.hpp"
|
||||||||||||||
using namespace torch::indexing; | ||||||||||||||
|
||||||||||||||
namespace { | ||||||||||||||
|
||||||||||||||
class TensorRTLogger : public nvinfer1::ILogger { | ||||||||||||||
std::unique_ptr<const Acts::Logger> m_logger; | ||||||||||||||
|
||||||||||||||
public: | ||||||||||||||
TensorRTLogger(Acts::Logging::Level lvl) | ||||||||||||||
: m_logger(Acts::getDefaultLogger("TensorRT", lvl)) {} | ||||||||||||||
|
||||||||||||||
void log(Severity severity, const char *msg) noexcept override { | ||||||||||||||
const auto &logger = *m_logger; | ||||||||||||||
switch (severity) { | ||||||||||||||
case Severity::kVERBOSE: | ||||||||||||||
ACTS_DEBUG(msg); | ||||||||||||||
break; | ||||||||||||||
case Severity::kINFO: | ||||||||||||||
ACTS_INFO(msg); | ||||||||||||||
break; | ||||||||||||||
case Severity::kWARNING: | ||||||||||||||
ACTS_WARNING(msg); | ||||||||||||||
break; | ||||||||||||||
case Severity::kERROR: | ||||||||||||||
ACTS_ERROR(msg); | ||||||||||||||
break; | ||||||||||||||
case Severity::kINTERNAL_ERROR: | ||||||||||||||
ACTS_FATAL(msg); | ||||||||||||||
break; | ||||||||||||||
} | ||||||||||||||
} | ||||||||||||||
}; | ||||||||||||||
|
||||||||||||||
} // namespace | ||||||||||||||
|
||||||||||||||
namespace Acts { | ||||||||||||||
|
||||||||||||||
/// Construct the classifier: register the TensorRT plugins, read the
/// serialized engine from disk and build runtime, engine and execution
/// context. Throws std::runtime_error on any failure instead of relying
/// on assert(), which is compiled out in release builds.
TensorRTEdgeClassifier::TensorRTEdgeClassifier(
    const Config &cfg, std::unique_ptr<const Logger> _logger)
    : m_logger(std::move(_logger)),
      m_cfg(cfg),
      m_trtLogger(std::make_unique<TensorRTLogger>(m_logger->level())) {
  // Plugins must be registered before deserializing an engine that
  // references plugin layers.
  if (!initLibNvInferPlugins(m_trtLogger.get(), "")) {
    throw std::runtime_error("Failed to initialize TensorRT plugins");
  }

  std::size_t fsize =
      std::filesystem::file_size(std::filesystem::path(m_cfg.modelPath));
  std::vector<char> engineData(fsize);

  ACTS_DEBUG("Load '" << m_cfg.modelPath << "' with size " << fsize);

  // The engine file is a binary blob; open in binary mode so no byte
  // translation occurs on any platform.
  std::ifstream engineFile(m_cfg.modelPath, std::ios::binary);
  if (!engineFile) {
    throw std::runtime_error("Failed to open engine file: " + m_cfg.modelPath);
  }
  if (!engineFile.read(engineData.data(),
                       static_cast<std::streamsize>(fsize))) {
    throw std::runtime_error("Failed to read engine file: " + m_cfg.modelPath);
  }

  m_runtime.reset(nvinfer1::createInferRuntime(*m_trtLogger));
  if (!m_runtime) {
    throw std::runtime_error("Failed to create TensorRT runtime");
  }

  m_engine.reset(m_runtime->deserializeCudaEngine(engineData.data(), fsize));
  if (!m_engine) {
    throw std::runtime_error("Failed to deserialize CUDA engine from " +
                             m_cfg.modelPath);
  }

  m_context.reset(m_engine->createExecutionContext());
  if (!m_context) {
    throw std::runtime_error("Failed to create TensorRT execution context");
  }
}
|
||||||||||||||
TensorRTEdgeClassifier::~TensorRTEdgeClassifier() {} | ||||||||||||||
|
||||||||||||||
/// Convert the difference of two time points to fractional milliseconds.
static auto milliseconds = [](const auto &a, const auto &b) {
  return std::chrono::duration<double, std::milli>(b - a).count();
};

/// Scoped timing helper: records a start stamp on construction and prints
/// the elapsed time to the second stamp (set via TIME_END) on destruction.
struct TimePrinter {
  const char *name;
  decltype(std::chrono::high_resolution_clock::now()) t0, t1;
  TimePrinter(const char *n) : name(n) {
    t0 = std::chrono::high_resolution_clock::now();
    // Initialize t1 as well so the destructor never reads an unset stamp
    // if TIME_END was not reached.
    t1 = t0;
  }
  ~TimePrinter() {
    std::cout << name << ": " << milliseconds(t0, t1) << std::endl;
  }
};

// Compile-time switch for the scoped timing printouts above.
#if 0
#define TIME_BEGIN(name) TimePrinter printer##name(#name);
#define TIME_END(name) \
  printer##name.t1 = std::chrono::high_resolution_clock::now();
#else
#define TIME_BEGIN(name) /*nothing*/
#define TIME_END(name) /*nothing*/
#endif
|
||||||||||||||
std::tuple<std::any, std::any, std::any, std::any> | ||||||||||||||
TensorRTEdgeClassifier::operator()(std::any inNodeFeatures, | ||||||||||||||
std::any inEdgeIndex, | ||||||||||||||
std::any inEdgeFeatures, | ||||||||||||||
const ExecutionContext &execContext) { | ||||||||||||||
decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5; | ||||||||||||||
t0 = std::chrono::high_resolution_clock::now(); | ||||||||||||||
|
||||||||||||||
c10::cuda::CUDAStreamGuard(execContext.stream.value()); | ||||||||||||||
|
||||||||||||||
auto nodeFeatures = | ||||||||||||||
std::any_cast<torch::Tensor>(inNodeFeatures).to(torch::kCUDA); | ||||||||||||||
|
||||||||||||||
auto edgeIndex = std::any_cast<torch::Tensor>(inEdgeIndex).to(torch::kCUDA); | ||||||||||||||
ACTS_DEBUG("edgeIndex: " << detail::TensorDetails{edgeIndex}); | ||||||||||||||
|
||||||||||||||
auto edgeFeatures = | ||||||||||||||
std::any_cast<torch::Tensor>(inEdgeFeatures).to(torch::kCUDA); | ||||||||||||||
ACTS_DEBUG("edgeFeatures: " << detail::TensorDetails{edgeFeatures}); | ||||||||||||||
|
||||||||||||||
t1 = std::chrono::high_resolution_clock::now(); | ||||||||||||||
|
||||||||||||||
m_context->setInputShape( | ||||||||||||||
"x", nvinfer1::Dims2{nodeFeatures.size(0), nodeFeatures.size(1)}); | ||||||||||||||
m_context->setTensorAddress("x", nodeFeatures.data_ptr()); | ||||||||||||||
|
||||||||||||||
m_context->setInputShape( | ||||||||||||||
"edge_index", nvinfer1::Dims2{edgeIndex.size(0), edgeIndex.size(1)}); | ||||||||||||||
m_context->setTensorAddress("edge_index", edgeIndex.data_ptr()); | ||||||||||||||
|
||||||||||||||
m_context->setInputShape( | ||||||||||||||
"edge_attr", nvinfer1::Dims2{edgeFeatures.size(0), edgeFeatures.size(1)}); | ||||||||||||||
m_context->setTensorAddress("edge_attr", edgeFeatures.data_ptr()); | ||||||||||||||
|
||||||||||||||
void *outputMem{nullptr}; | ||||||||||||||
std::size_t outputSize = edgeIndex.size(1) * sizeof(float); | ||||||||||||||
cudaMalloc(&outputMem, outputSize); | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Check return value of Ensure that memory allocation on the GPU is successful before proceeding. Apply this diff to check -cudaMalloc(&outputMem, outputSize);
+cudaError_t err = cudaMalloc(&outputMem, outputSize);
+if (err != cudaSuccess) {
+ ACTS_ERROR("cudaMalloc failed: " << cudaGetErrorString(err));
+ // Handle the error appropriately.
+} 📝 Committable suggestion
Suggested change
|
||||||||||||||
m_context->setTensorAddress("output", outputMem); | ||||||||||||||
|
||||||||||||||
t2 = std::chrono::high_resolution_clock::now(); | ||||||||||||||
|
||||||||||||||
{ | ||||||||||||||
auto stream = execContext.stream.value().stream(); | ||||||||||||||
auto status = m_context->enqueueV3(stream); | ||||||||||||||
cudaStreamSynchronize(stream); | ||||||||||||||
Comment on lines
+155
to
+156
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Inference execution status, verify you should. Check the return status of Apply this diff to handle inference errors: auto status = m_context->enqueueV3(stream);
+if (!status) {
+ ACTS_ERROR("Inference execution failed.");
+ // Handle the error appropriately.
+}
cudaStreamSynchronize(stream);
|
||||||||||||||
ACTS_VERBOSE("TensorRT output status: " << std::boolalpha << status); | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
t3 = std::chrono::high_resolution_clock::now(); | ||||||||||||||
|
||||||||||||||
auto scores = torch::from_blob( | ||||||||||||||
outputMem, edgeIndex.size(1), 1, [](void *ptr) { cudaFree(ptr); }, | ||||||||||||||
torch::TensorOptions().device(torch::kCUDA).dtype(torch::kFloat32)); | ||||||||||||||
|
||||||||||||||
scores.sigmoid_(); | ||||||||||||||
|
||||||||||||||
ACTS_VERBOSE("Size after classifier: " << scores.size(0)); | ||||||||||||||
ACTS_VERBOSE("Slice of classified output:\n" | ||||||||||||||
<< scores.slice(/*dim=*/0, /*start=*/0, /*end=*/9)); | ||||||||||||||
printCudaMemInfo(logger()); | ||||||||||||||
|
||||||||||||||
torch::Tensor mask = scores > m_cfg.cut; | ||||||||||||||
torch::Tensor edgesAfterCut = edgeIndex.index({Slice(), mask}); | ||||||||||||||
|
||||||||||||||
scores = scores.masked_select(mask); | ||||||||||||||
ACTS_VERBOSE("Size after score cut: " << edgesAfterCut.size(1)); | ||||||||||||||
printCudaMemInfo(logger()); | ||||||||||||||
|
||||||||||||||
t4 = std::chrono::high_resolution_clock::now(); | ||||||||||||||
|
||||||||||||||
ACTS_DEBUG("Time anycast: " << milliseconds(t0, t1)); | ||||||||||||||
ACTS_DEBUG("Time alloc, set shape " << milliseconds(t1, t2)); | ||||||||||||||
ACTS_DEBUG("Time inference: " << milliseconds(t2, t3)); | ||||||||||||||
ACTS_DEBUG("Time sigmoid and cut: " << milliseconds(t3, t4)); | ||||||||||||||
|
||||||||||||||
return {nodeFeatures, edgesAfterCut, edgeFeatures, scores}; | ||||||||||||||
} | ||||||||||||||
|
||||||||||||||
} // namespace Acts |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Return a proper `torch::Device` object, you must.

The current implementation returns the device type constant `torch::kCUDA` and relies on its implicit conversion to `torch::Device`. Construct the device explicitly, you should, so the intent is clear:

Apply this diff to return the correct device:
torch::Device device() const override { return torch::kCUDA; }; + torch::Device device() const override { return torch::Device(torch::kCUDA); };