diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp index a6905d5b06c..95d7f10e3b8 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/BoostTrackBuilding.hpp @@ -25,7 +25,7 @@ class BoostTrackBuilding final : public Acts::TrackBuildingBase { std::vector> operator()( std::any nodes, std::any edges, std::any edge_weights, std::vector &spacepointIDs, - torch::Device device = torch::Device(torch::kCPU)) override; + const ExecutionContext &execContext = {}) override; torch::Device device() const override { return m_device; }; private: diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp index 5682b7e84a9..529e31409d4 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxEdgeClassifier.hpp @@ -35,7 +35,7 @@ class OnnxEdgeClassifier final : public Acts::EdgeClassificationBase { std::tuple operator()( std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, - torch::Device device = torch::Device(torch::kCPU)) override; + const ExecutionContext &execContext = {}) override; Config config() const { return m_cfg; } torch::Device device() const override { return m_device; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp index d78139d0732..0a97ab697c3 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/OnnxMetricLearning.hpp @@ -39,7 +39,7 @@ class OnnxMetricLearning final : public Acts::GraphConstructionBase { std::tuple operator()( std::vector& inputValues, std::size_t numNodes, const std::vector& moduleIds, - torch::Device device = torch::Device(torch::kCPU)) override; + const ExecutionContext& execContext = {}) override; Config config() const { return m_cfg; } torch::Device device() const override { return m_device; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp index 1e35fb08a82..13753713917 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/Stages.hpp @@ -13,6 +13,7 @@ #include #include +#include #include namespace Acts { @@ -20,6 +21,12 @@ namespace Acts { /// Error that is thrown if no edges are found struct NoEdgesError : std::exception {}; +/// Capture the context of the execution +struct ExecutionContext { + torch::Device device{torch::kCPU}; + std::optional stream; +}; + // TODO maybe replace std::any with some kind of variant, // unique_ptr>? // TODO maybe replace input for GraphConstructionBase with some kind of @@ -34,13 +41,12 @@ class GraphConstructionBase { /// then gives the number of features /// @param moduleIds Module IDs of the features (used for module-map-like /// graph construction) - /// @param device Which GPU device to pick. Not relevant for CPU-only builds - /// + /// @param execContext Device & stream information /// @return (node_features, edge_features, edge_index) virtual std::tuple operator()( std::vector &inputValues, std::size_t numNodes, const std::vector &moduleIds, - torch::Device device = torch::Device(torch::kCPU)) = 0; + const ExecutionContext &execContext = {}) = 0; virtual torch::Device device() const = 0; @@ -54,12 +60,12 @@ class EdgeClassificationBase { /// @param nodeFeatures Node tensor with shape (n_nodes, n_node_features) /// @param edgeIndex Edge-index tensor with shape (2, n_edges) /// @param edgeFeatures Edge-feature tensor with shape (n_edges, n_edge_features) - /// @param device Which GPU device to pick. Not relevant for CPU-only builds + /// @param execContext Device & stream information /// /// @return (node_features, edge_features, edge_index, edge_scores) virtual std::tuple operator()( std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, - torch::Device device = torch::Device(torch::kCPU)) = 0; + const ExecutionContext &execContext = {}) = 0; virtual torch::Device device() const = 0; @@ -74,13 +80,13 @@ class TrackBuildingBase { /// @param edgeIndex Edge-index tensor with shape (2, n_edges) /// @param edgeScores Scores of the previous edge classification phase /// @param spacepointIDs IDs of the nodes (must have size=n_nodes) - /// @param device Which GPU device to pick. Not relevant for CPU-only builds + /// @param execContext Device & stream information /// /// @return tracks (as vectors of node-IDs) virtual std::vector> operator()( std::any nodeFeatures, std::any edgeIndex, std::any edgeScores, std::vector &spacepointIDs, - torch::Device device = torch::Device(torch::kCPU)) = 0; + const ExecutionContext &execContext = {}) = 0; virtual torch::Device device() const = 0; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp index 4cf92a7115d..e7c1d04e9d6 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchEdgeClassifier.hpp @@ -42,7 +42,7 @@ class TorchEdgeClassifier final : public Acts::EdgeClassificationBase { std::tuple operator()( std::any nodeFeatures, std::any edgeIndex, std::any edgeFeatures = {}, - torch::Device device = torch::Device(torch::kCPU)) override; + const ExecutionContext &execContext = {}) override; Config config() const { return m_cfg; } torch::Device device() const override { return m_device; }; diff --git a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp index 9d87e5c59d5..dba7d7f2220 100644 --- a/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp +++ b/Plugins/ExaTrkX/include/Acts/Plugins/ExaTrkX/TorchMetricLearning.hpp @@ -44,7 +44,7 @@ class TorchMetricLearning final : public Acts::GraphConstructionBase { std::tuple operator()( std::vector &inputValues, std::size_t numNodes, const std::vector &moduleIds, - torch::Device device = torch::Device(torch::kCPU)) override; + const ExecutionContext &execContext = {}) override; Config config() const { return m_cfg; } torch::Device device() const override { return m_device; }; diff --git a/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp b/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp index 0d75b31cedd..d46bc2f58d5 100644 --- a/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp +++ b/Plugins/ExaTrkX/src/BoostTrackBuilding.cpp @@ -48,7 +48,7 @@ namespace Acts { std::vector> BoostTrackBuilding::operator()( std::any /*nodes*/, std::any edges, std::any weights, - std::vector& spacepointIDs, torch::Device) { + std::vector& spacepointIDs, const ExecutionContext& execContext) { ACTS_DEBUG("Start track building"); const auto edgeTensor = std::any_cast(edges).to(torch::kCPU); const auto edgeWeightTensor = diff --git a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp index 6a05ae96f8e..d81e94ab342 100644 --- a/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp +++ b/Plugins/ExaTrkX/src/ExaTrkXPipeline.cpp @@ -39,11 +39,18 @@ std::vector> ExaTrkXPipeline::run( std::vector &features, const std::vector &moduleIds, std::vector &spacepointIDs, const ExaTrkXHook &hook, ExaTrkXTiming *timing) const { + ExecutionContext ctx; + ctx.device = m_graphConstructor->device(); +#ifndef ACTS_EXATRKX_CPUONLY + if (ctx.device.type() == torch::kCUDA) { + ctx.stream = c10::cuda::getStreamFromPool(ctx.device.index()); + } +#endif + try { auto t0 = std::chrono::high_resolution_clock::now(); auto [nodeFeatures, edgeIndex, edgeFeatures] = - (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, - m_graphConstructor->device()); + (*m_graphConstructor)(features, spacepointIDs.size(), moduleIds, ctx); auto t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { @@ -59,7 +66,7 @@ std::vector> ExaTrkXPipeline::run( t0 = std::chrono::high_resolution_clock::now(); auto [newNodeFeatures, newEdgeIndex, newEdgeFeatures, newEdgeScores] = (*edgeClassifier)(std::move(nodeFeatures), std::move(edgeIndex), - std::move(edgeFeatures), edgeClassifier->device()); + std::move(edgeFeatures), ctx); t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { @@ -76,8 +83,7 @@ std::vector> ExaTrkXPipeline::run( t0 = std::chrono::high_resolution_clock::now(); auto res = (*m_trackBuilder)(std::move(nodeFeatures), std::move(edgeIndex), - std::move(edgeScores), spacepointIDs, - m_trackBuilder->device()); + std::move(edgeScores), spacepointIDs, ctx); t1 = std::chrono::high_resolution_clock::now(); if (timing != nullptr) { diff --git a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp index 59452daec09..71cdd0097b5 100644 --- a/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/OnnxEdgeClassifier.cpp @@ -90,7 +90,8 @@ std::ostream &operator<<(std::ostream &os, Ort::Value &v) { std::tuple OnnxEdgeClassifier::operator()(std::any inputNodes, std::any inputEdges, - std::any inEdgeFeatures, torch::Device) { + std::any inEdgeFeatures, + const ExecutionContext & /*unused*/) { auto torchDevice = torch::kCPU; Ort::MemoryInfo memoryInfo("Cpu", OrtArenaAllocator, /*device_id*/ 0, OrtMemTypeDefault); diff --git a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp index b81335ad0a9..3aff05eb4e2 100644 --- a/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp +++ b/Plugins/ExaTrkX/src/TorchEdgeClassifier.cpp @@ -67,17 +67,23 @@ TorchEdgeClassifier::~TorchEdgeClassifier() {} std::tuple TorchEdgeClassifier::operator()(std::any inNodeFeatures, std::any inEdgeIndex, - std::any inEdgeFeatures, torch::Device device) { - decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4, t5; + std::any inEdgeFeatures, + const ExecutionContext& execContext) { + const auto& device = execContext.device; + decltype(std::chrono::high_resolution_clock::now()) t0, t1, t2, t3, t4; t0 = std::chrono::high_resolution_clock::now(); ACTS_DEBUG("Start edge classification, use " << device); c10::InferenceMode guard(true); // add a protection to avoid calling for kCPU -#ifndef ACTS_EXATRKX_CPUONLY +#ifdef ACTS_EXATRKX_CPUONLY + assert(device == torch::Device(torch::kCPU)); +#else std::optional device_guard; + std::optional streamGuard; if (device.is_cuda()) { device_guard.emplace(device.index()); + streamGuard.emplace(execContext.stream.value()); } #endif diff --git a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp index 08088a5d904..57cb271351a 100644 --- a/Plugins/ExaTrkX/src/TorchMetricLearning.cpp +++ b/Plugins/ExaTrkX/src/TorchMetricLearning.cpp @@ -69,15 +69,21 @@ TorchMetricLearning::~TorchMetricLearning() {} std::tuple TorchMetricLearning::operator()( std::vector &inputValues, std::size_t numNodes, - const std::vector & /*moduleIds*/, torch::Device device) { + const std::vector & /*moduleIds*/, + const ExecutionContext &execContext) { + const auto &device = execContext.device; ACTS_DEBUG("Start graph construction"); c10::InferenceMode guard(true); // add a protection to avoid calling for kCPU -#ifndef ACTS_EXATRKX_CPUONLY +#ifdef ACTS_EXATRKX_CPUONLY + assert(device == torch::Device(torch::kCPU)); +#else std::optional device_guard; + std::optional streamGuard; if (device.is_cuda()) { device_guard.emplace(device.index()); + streamGuard.emplace(execContext.stream.value()); } #endif