Skip to content

Commit

Permalink
feat: Add support for CUDA streams in GNN plugin (#4012)
Browse files Browse the repository at this point in the history
Uses the torch framework (as it is currently required anyway) as a source of CUDA streams. Extends the interface of some components to use the streams.
  • Loading branch information
benjaminhuth authored Jan 15, 2025
1 parent 0932329 commit c0e65bc
Show file tree
Hide file tree
Showing 11 changed files with 49 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BoostTrackBuilding final : public Acts::TrackBuildingBase {
std::vector<std::vector<int>> operator()(
std::any nodes, std::any edges, std::any edge_weights,
std::vector<int> &spacepointIDs,
torch::Device device = torch::Device(torch::kCPU)) override;
const ExecutionContext &execContext = {}) override;
torch::Device device() const override { return m_device; };

private:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase {

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
torch::Device device = torch::Device(torch::kCPU)) override;
const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase {
std::tuple<std::any, std::any, std::any> operator()(
std::vector<float>& inputValues, std::size_t numNodes,
const std::vector<std::uint64_t>& moduleIds,
torch::Device device = torch::Device(torch::kCPU)) override;
const ExecutionContext& execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
Expand Down
20 changes: 13 additions & 7 deletions Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,20 @@
#include <exception>
#include <vector>

#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

namespace Acts {

/// Exception type raised when no edges are found.
/// It carries no payload of its own; deriving from std::exception lets
/// generic handlers catch it as well.
struct NoEdgesError : public std::exception {
};

/// Capture the context of the execution
struct ExecutionContext {
torch::Device device{torch::kCPU};
std::optional<c10::cuda::CUDAStream> stream;
};

// TODO maybe replace std::any with some kind of variant<unique_ptr<torch>,
// unique_ptr<onnx>>?
// TODO maybe replace input for GraphConstructionBase with some kind of
Expand All @@ -34,13 +41,12 @@ class GraphConstructionBase {
/// then gives the number of features
/// @param moduleIds Module IDs of the features (used for module-map-like
/// graph construction)
/// @param device Which GPU device to pick. Not relevant for CPU-only builds
///
/// @param execContext Device & stream information
/// @return (node_features, edge_features, edge_index)
virtual std::tuple<std::any, std::any, std::any> operator()(
std::vector<float> &inputValues, std::size_t numNodes,
const std::vector<std::uint64_t> &moduleIds,
torch::Device device = torch::Device(torch::kCPU)) = 0;
const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

Expand All @@ -54,12 +60,12 @@ class EdgeClassificationBase {
/// @param nodeFeatures Node tensor with shape (n_nodes, n_node_features)
/// @param edgeIndex Edge-index tensor with shape (2, n_edges)
/// @param edgeFeatures Edge-feature tensor with shape (n_edges, n_edge_features)
/// @param device Which GPU device to pick. Not relevant for CPU-only builds
/// @param execContext Device & stream information
///
/// @return (node_features, edge_features, edge_index, edge_scores)
virtual std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
torch::Device device = torch::Device(torch::kCPU)) = 0;
const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

Expand All @@ -74,13 +80,13 @@ class TrackBuildingBase {
/// @param edgeIndex Edge-index tensor with shape (2, n_edges)
/// @param edgeScores Scores of the previous edge classification phase
/// @param spacepointIDs IDs of the nodes (must have size=n_nodes)
/// @param device Which GPU device to pick. Not relevant for CPU-only builds
/// @param execContext Device & stream information
///
/// @return tracks (as vectors of node-IDs)
virtual std::vector<std::vector<int>> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeScores,
std::vector<int> &spacepointIDs,
torch::Device device = torch::Device(torch::kCPU)) = 0;
const ExecutionContext &execContext = {}) = 0;

virtual torch::Device device() const = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase {

std::tuple<std::any, std::any, std::any, std::any> operator()(
std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {},
torch::Device device = torch::Device(torch::kCPU)) override;
const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase {
std::tuple<std::any, std::any, std::any> operator()(
std::vector<float> &inputValues, std::size_t numNodes,
const std::vector<std::uint64_t> &moduleIds,
torch::Device device = torch::Device(torch::kCPU)) override;
const ExecutionContext &execContext = {}) override;

Config config() const { return m_cfg; }
torch::Device device() const override { return m_device; };
Expand Down
2 changes: 1 addition & 1 deletion Plugins/ExaTrkX/src/BoostTrackBuilding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ namespace Acts {

std::vector<std::vector<int>> BoostTrackBuilding::operator()(
std::any /*nodes*/, std::any edges, std::any weights,
std::vector<int>& spacepointIDs, torch::Device) {
std::vector<int>& spacepointIDs, const ExecutionContext& execContext) {
ACTS_DEBUG("Start track building");
const auto edgeTensor = std::any_cast<torch::Tensor>(edges).to(torch::kCPU);
const auto edgeWeightTensor =
Expand Down
16 changes: 11 additions & 5 deletions Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,18 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
std::vector<float> &features, const std::vector<std::uint64_t> &moduleIds,
std::vector<int> &spacepointIDs, const ExaTrkXHook &hook,
ExaTrkXTiming *timing) const {
ExecutionContext ctx;
ctx.device = m_graphConstructor->device();
#ifndef ACTS_EXATRKX_CPUONLY
if (ctx.device.type() == torch::kCUDA) {
ctx.stream = c10::cuda::getStreamFromPool(ctx.device.index());
}
#endif

try {
auto t0 = std::chrono::high_resolution_clock::now();
auto [nodeFeatures, edgeIndex, edgeFeatures] =
(*m_graphConstructor)(features, spacepointIDs.size(), moduleIds,
m_graphConstructor->device());
(*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx);
auto t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
Expand All @@ -59,7 +66,7 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(
t0 = std::chrono::high_resolution_clock::now();
auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] =
(*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex),
std::move(edgeFeatures), edgeClassifier->device());
std::move(edgeFeatures), ctx);
t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
Expand All @@ -76,8 +83,7 @@ std::vector<std::vector<int>> ExaTrkXPipeline::run(

t0 = std::chrono::high_resolution_clock::now();
auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex),
std::move(edgeScores), spacepointIDs,
m_trackBuilder->device());
std::move(edgeScores), spacepointIDs, ctx);
t1 = std::chrono::high_resolution_clock::now();

if (timing != nullptr) {
Expand Down
3 changes: 2 additions & 1 deletion Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,8 @@ std::ostream &operator<<(std::ostream &os, Ort::Value &v) {

std::tuple<std::any, std::any, std::any, std::any>
OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges,
std::any inEdgeFeatures, torch::Device) {
std::any inEdgeFeatures,
const ExecutionContext & /*unused*/) {
auto torchDevice = torch::kCPU;
Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, /*device_id*/ 0,
OrtMemTypeDefault);
Expand Down
12 changes: 9 additions & 3 deletions Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,23 @@ TorchEdgeClassifier::~TorchEdgeClassifier() {}

std::tuple<std::any, std::any, std::any, std::any>
TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex,
std::any inEdgeFeatures, torch::Device device) {
decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5;
std::any inEdgeFeatures,
const ExecutionContext& execContext) {
const auto& device = execContext.device;
decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4;
t0 = std::chrono::high_resolution_clock::now();
ACTS_DEBUG("Start edge classification, use " << device);
c10::InferenceMode guard(true);

// add a protection to avoid calling for kCPU
#ifndef ACTS_EXATRKX_CPUONLY
#ifdef ACTS_EXATRKX_CPUONLY
assert(device == torch::Device(torch::kCPU));
#else
std::optional<c10::cuda::CUDAGuard> device_guard;
std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
if (device.is_cuda()) {
device_guard.emplace(device.index());
streamGuard.emplace(execContext.stream.value());
}
#endif

Expand Down
10 changes: 8 additions & 2 deletions Plugins/ExaTrkX/src/TorchMetricLearning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,21 @@ TorchMetricLearning::~TorchMetricLearning() {}

std::tuple<std::any, std::any, std::any> TorchMetricLearning::operator()(
std::vector<float> &inputValues, std::size_t numNodes,
const std::vector<std::uint64_t> & /*moduleIds*/, torch::Device device) {
const std::vector<std::uint64_t> & /*moduleIds*/,
const ExecutionContext &execContext) {
const auto &device = execContext.device;
ACTS_DEBUG("Start graph construction");
c10::InferenceMode guard(true);

// add a protection to avoid calling for kCPU
#ifndef ACTS_EXATRKX_CPUONLY
#ifdef ACTS_EXATRKX_CPUONLY
assert(device == torch::Device(torch::kCPU));
#else
std::optional<c10::cuda::CUDAGuard> device_guard;
std::optional<c10::cuda::CUDAStreamGuard> streamGuard;
if (device.is_cuda()) {
device_guard.emplace(device.index());
streamGuard.emplace(execContext.stream.value());
}
#endif

Expand Down

0 comments on commit c0e65bc

Please sign in to comment.