From 110a3bca81d358d02f945ee90dc38ec12d52e2d3 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 5 Nov 2024 15:12:52 -0800 Subject: [PATCH 01/53] save work --- .../onnxruntime/core/framework/ortdevice.h | 1 + .../core/providers/qnn/qnn_allocator.cc | 34 ++++++++++ .../core/providers/qnn/qnn_allocator.h | 24 +++++++ .../core/providers/qnn/rpcmem_library.cc | 67 ++++++++++++++++++ .../core/providers/qnn/rpcmem_library.h | 68 +++++++++++++++++++ onnxruntime/core/session/IOBinding.h | 2 +- 6 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 onnxruntime/core/providers/qnn/qnn_allocator.cc create mode 100644 onnxruntime/core/providers/qnn/qnn_allocator.h create mode 100644 onnxruntime/core/providers/qnn/rpcmem_library.cc create mode 100644 onnxruntime/core/providers/qnn/rpcmem_library.h diff --git a/include/onnxruntime/core/framework/ortdevice.h b/include/onnxruntime/core/framework/ortdevice.h index 6f658ab65be20..adade482f6a17 100644 --- a/include/onnxruntime/core/framework/ortdevice.h +++ b/include/onnxruntime/core/framework/ortdevice.h @@ -25,6 +25,7 @@ struct OrtDevice { static const MemoryType CUDA_PINNED = 1; static const MemoryType HIP_PINNED = 2; static const MemoryType CANN_PINNED = 3; + static const MemoryType QNN_HTP_SHARED = 4; }; constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc new file mode 100644 index 0000000000000..d0c26f0aaca6c --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/qnn_allocator.h" + +#include + +#include "core/providers/qnn/rpcmem_library.h" + +namespace onnxruntime::qnn { + +RpcMemAllocator::RpcMemAllocator(const RpcMemApi& rpc_mem_api) + : IAllocator{OrtMemoryInfo{"TODO name the allocator", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device id */ 0}, + 0, OrtMemTypeCPUOutput}}, + rpc_mem_api_{rpc_mem_api} { +} + +void* RpcMemAllocator::Alloc(size_t size) { + // rpcmem_alloc() has an int size parameter. + constexpr size_t max_size = std::numeric_limits::max(); + if (size > max_size) { + return nullptr; + } + + return rpc_mem_api_.alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, + static_cast(size)); +} + +void RpcMemAllocator::Free(void* p) { + rpc_mem_api_.free(p); +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h new file mode 100644 index 0000000000000..a9bda981781fa --- /dev/null +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/allocator.h" + +namespace onnxruntime::qnn { + +struct RpcMemApi; + +class RpcMemAllocator : public IAllocator { + public: + RpcMemAllocator(const RpcMemApi& rpc_mem_api); + + void* Alloc(size_t size) override; + void Free(void* p) override; + // void GetStats(AllocatorStats* stats) override; + + private: + const RpcMemApi& rpc_mem_api_; +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc new file mode 100644 index 0000000000000..234a154a763a6 --- /dev/null +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#include "core/providers/qnn/rpcmem_library.h" + +#include "core/common/logging/logging.h" +#include "core/platform/env.h" + +namespace onnxruntime::qnn { + +namespace { + +const PathChar* GetRpcMemSharedLibraryPath() { +#if defined(_WIN32) + return ORT_TSTR("libcdsprpc.dll"); +#else + return ORT_TSTR("libcdsprpc.so"); +#endif +} + +SharedLibraryHandle LoadSharedLibrary(const PathString& path, bool global_symbols) { + // Custom deleter to unload the shared library. Avoid throwing from it because it may run in dtor. + const auto unload_shared_library = [](void* shared_library_handle) { + if (shared_library_handle == nullptr) { + return; + } + + const auto& env = Env::Default(); + const auto unload_status = env.UnloadDynamicLibrary(shared_library_handle); + + if (!unload_status.IsOK()) { + LOGS_DEFAULT(WARNING) << "Failed to unload shared library. Error: " << unload_status.ErrorMessage(); + } + }; + + const auto& env = Env::Default(); + void* shared_library_handle = nullptr; + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &shared_library_handle)); + + return SharedLibraryHandle{shared_library_handle, unload_shared_library}; +} + +RpcMemApi CreateApi(void* shared_library_handle) { + RpcMemApi api{}; + + const auto& env = Env::Default(); + void* symbol = nullptr; + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_alloc", &symbol)); + api.alloc = static_cast(symbol); + + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_free", &symbol)); + api.free = static_cast(symbol); + + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_to_fd", &symbol)); + api.to_fd = static_cast(symbol); + + return api; +} + +} // namespace + +RpcMemLibrary::RpcMemLibrary() + : shared_library_(LoadSharedLibrary(GetRpcMemSharedLibraryPath(), /* global_symbols */ false)), + api_{CreateApi(shared_library_.get())} { +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h new file mode 100644 index 0000000000000..c9e6b7cf7ec6d --- /dev/null +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#pragma once + +#include +#include + +#include "core/common/common.h" + +namespace onnxruntime::qnn { + +using SharedLibraryHandle = std::unique_ptr; + +// This namespace contains constants and typedefs corresponding to functions from rpcmem.h. +// https://github.com/quic/fastrpc/blob/v0.1.1/inc/rpcmem.h +namespace rpcmem { + +constexpr uint32_t RPCMEM_DEFAULT_FLAGS = 1; + +constexpr int RPCMEM_HEAP_ID_SYSTEM = 25; + +/** + * Allocate a zero-copy buffer for size upto 2 GB with the FastRPC framework. + * Buffers larger than 2 GB must be allocated with rpcmem_alloc2 + * @param[in] heapid Heap ID to use for memory allocation. + * @param[in] flags ION flags to use for memory allocation. + * @param[in] size Buffer size to allocate. + * @return Pointer to the buffer on success; NULL on failure. + */ +using AllocFnPtr = void* (*)(int heapid, uint32_t flags, int size); + +/** + * Free a buffer and ignore invalid buffers. + */ +using FreeFnPtr = void (*)(void* po); + +/** + * Return an associated file descriptor. + * @param[in] po Data pointer for an RPCMEM-allocated buffer. + * @return Buffer file descriptor. + */ +using ToFdFnPtr = int (*)(void* po); + +} // namespace rpcmem + +// RPCMEM API function pointers. +struct RpcMemApi { + rpcmem::AllocFnPtr alloc; + rpcmem::FreeFnPtr free; + rpcmem::ToFdFnPtr to_fd; +}; + +// Loads and provides access to the RPCMEM API functions from a dynamically loaded library. +class RpcMemLibrary { + public: + RpcMemLibrary(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(RpcMemLibrary); + + const RpcMemApi& Api() const { return api_; } + + private: + SharedLibraryHandle shared_library_; + RpcMemApi api_; +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/session/IOBinding.h b/onnxruntime/core/session/IOBinding.h index 1f1b3b8073f96..d5a1e273369a1 100644 --- a/onnxruntime/core/session/IOBinding.h +++ b/onnxruntime/core/session/IOBinding.h @@ -51,7 +51,7 @@ class IOBinding { /** * If the BindInput calls are async this function acts as a barrier to ensure all inputs are fully copied - * before you call the Run() method. There is no point calling Run() if you're inputs are not ready at the + * before you call the Run() method. There is no point calling Run() if your inputs are not ready at the * desired location. * This is a blocking call and is a wrapper over IExecutionProvider::Sync(). * Call InferenceSession::Run() only after calling this method or else you'll end up wasting cycles inside Run(). From 0ba3a2fb46a53378db8eb9dfae91a6ac6844d333 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 8 Nov 2024 18:43:57 -0800 Subject: [PATCH 02/53] save work --- .../onnxruntime/core/framework/allocator.h | 1 + .../core/framework/ortmemoryinfo.h | 1 + .../core/session/onnxruntime_cxx_api.h | 4 +- onnxruntime/core/framework/allocator.cc | 11 +- onnxruntime/core/framework/session_state.cc | 2 +- .../core/providers/qnn/builder/qnn_def.cc | 30 +++ .../core/providers/qnn/builder/qnn_def.h | 2 + .../core/providers/qnn/builder/qnn_model.cc | 175 +++++++++++++----- .../core/providers/qnn/builder/qnn_model.h | 7 +- .../core/providers/qnn/qnn_allocator.cc | 22 ++- .../core/providers/qnn/qnn_allocator.h | 11 +- .../providers/qnn/qnn_execution_provider.cc | 52 +++++- .../providers/qnn/qnn_execution_provider.h | 12 +- .../core/providers/qnn/rpcmem_library.cc | 30 ++- .../core/providers/qnn/rpcmem_library.h | 4 +- 15 files changed, 268 insertions(+), 96 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 57b332ce65b93..525277375830c 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -52,6 +52,7 @@ constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; constexpr const char* OpenVINO_RT = "OpenVINO_RT"; constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU"; +constexpr const char* QNN_HTP_SHARED = "QnnHtpShared"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index 7af5554e25c0b..d060c6546ae27 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -6,6 +6,7 @@ #include #include "core/common/hash_combine.h" +#include "core/framework/ortdevice.h" struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index f3e9758766d00..0a57999246b06 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2130,10 +2130,10 @@ struct KernelContext { explicit KernelContext(OrtKernelContext* context); size_t GetInputCount() const; size_t GetOutputCount() const; - // If input is optional and is not present, the method returns en empty ConstValue + // If input is optional and is not present, the method returns an empty ConstValue // which can be compared to nullptr. ConstValue GetInput(size_t index) const; - // If outout is optional and is not present, the method returns en empty UnownedValue + // If outout is optional and is not present, the method returns an empty UnownedValue // which can be compared to nullptr. UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 26b98b0a04d24..02dbb3e518783 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -155,11 +155,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA mem_type1); } else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), + onnxruntime::CUDA_PINNED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast(id1)), id1, mem_type1); } else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) { *out = new OrtMemoryInfo( - onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), + onnxruntime::HIP_PINNED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast(id1)), + id1, mem_type1); + } else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) { + *out = new OrtMemoryInfo( + onnxruntime::QNN_HTP_SHARED, type, + OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast(id1)), id1, mem_type1); } else { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported."); diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index 943db091b341f..ac1c42da20903 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -101,7 +101,7 @@ SessionState::SessionState(Graph& graph, for (auto& ep : execution_providers_) { auto allocators = ep->CreatePreferredAllocators(); for (auto& alloc : allocators) { - allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key + allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key } } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index c0fc079979822..5af7f024716f1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -208,6 +208,22 @@ void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data) ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } +void SetQnnTensorMemHandle(Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t mem_handle) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + qnn_tensor.v1.memHandle = mem_handle; + return; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + qnn_tensor.v2.memHandle = mem_handle; + return; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { qnn_tensor.v1.quantizeParams = quantize_params; @@ -350,6 +366,20 @@ const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor) ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } +Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return qnn_tensor.v1.memHandle; + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.memHandle; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor) { if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { return qnn_tensor.v1.quantizeParams; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index ffd2dc9b11010..e8e5453afa48b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -105,6 +105,7 @@ void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, const std::vector void SetQnnTensorClientBuf(Qnn_Tensor_t& qnn_tensor, void* buf_data, uint32_t buf_size); void SetQnnTensorClientBufSize(Qnn_Tensor_t& qnn_tensor, uint32_t client_buf_size); void SetQnnTensorClientBufData(Qnn_Tensor_t& qnn_tensor, void* client_buf_data); +void SetQnnTensorMemHandle(Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t mem_handle); void SetQnnTensorQParams(Qnn_Tensor_t& qnn_tensor, const Qnn_QuantizeParams_t& quantize_params); bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, const Qnn_GraphHandle_t& graph, @@ -123,6 +124,7 @@ Qnn_TensorMemType_t GetQnnTensorMemType(const Qnn_Tensor_t& qnn_tensor); uint32_t GetQnnTensorRank(const Qnn_Tensor_t& qnn_tensor); uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor); const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor); +Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor); const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor); /** diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index dc797fef2d42a..67980be8f341b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -4,15 +4,17 @@ #include "qnn_model.h" #include +#include #include "QnnOpDef.h" -#include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group.h" -#include "core/providers/shared/utils/utils.h" #include "core/framework/utils.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/qnn_allocator.h" +#include "core/providers/shared/utils/utils.h" namespace onnxruntime { namespace qnn { @@ -185,7 +187,53 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { return Status::OK(); } -Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) { +static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_interface, + const RpcMemApi* rpcmem_api, + Qnn_ContextHandle_t qnn_context_handle, + const OrtMemoryInfo& ort_value_memory_info, + void* ort_value_data, uint32_t ort_value_data_size, + Qnn_Tensor_t& qnn_tensor, + std::vector& registered_qnn_mem_handles) { + // either set qnn_tensor memHandle or clientBuf + const bool uses_shared_memory = ort_value_memory_info == RpcMemAllocator::MemoryInfo(); + + if (!uses_shared_memory) { + SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_RAW); + SetQnnTensorClientBuf(qnn_tensor, ort_value_data, ort_value_data_size); + } else { + ORT_RETURN_IF(rpcmem_api == nullptr, "RPCMEM API must be available when using shared memory."); + + // get RpcMem file descriptor from shared memory + const auto shared_memory_fd = rpcmem_api->to_fd(ort_value_data); + ORT_RETURN_IF(shared_memory_fd == -1, "rpcmem_to_fd() returned invalid file descriptor."); + + // set up QNN memory descriptor + // note: we only support a single tensor per shared memory buffer (QNN_MEM_TYPE_ION) now + Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; + qnn_mem_descriptor.memShape = {GetQnnTensorRank(qnn_tensor), + GetQnnTensorDims(qnn_tensor), + nullptr}; + qnn_mem_descriptor.dataType = GetQnnTensorDataType(qnn_tensor); + qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; + qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; + + Qnn_MemHandle_t qnn_mem_handle = nullptr; + const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, + &qnn_mem_handle); + // TODO show error message + ORT_RETURN_IF(register_status != QNN_SUCCESS, "qnnInterface.memRegister() failed with error code ", register_status); + + registered_qnn_mem_handles.push_back(qnn_mem_handle); + + SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); + SetQnnTensorMemHandle(qnn_tensor, qnn_mem_handle); + } + + return Status::OK(); +} + +Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi* rpcmem_api, + const logging::Logger& logger) { LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; const size_t num_inputs = context.GetInputCount(); const size_t num_outputs = context.GetOutputCount(); @@ -193,7 +241,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: ORT_RETURN_IF_NOT(qnn_output_infos_.size() == num_outputs, "Inconsistent output sizes"); using namespace qnn::utils; - auto TensorDataSize = [&](auto ort_tensor) -> size_t { + auto TensorDataSize = [](auto ort_tensor) -> size_t { auto tensor_type_and_shape = ort_tensor.GetTensorTypeAndShapeInfo(); size_t length = tensor_type_and_shape.GetElementCount(); ONNXTensorElementDataType element_type = tensor_type_and_shape.GetElementType(); @@ -201,53 +249,84 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging:: return element_size * length; }; - std::vector qnn_inputs; - qnn_inputs.reserve(qnn_input_infos_.size()); - - for (const auto& qnn_input_info : qnn_input_infos_) { - LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() - << " index = " << qnn_input_info.ort_index; - auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); - auto ort_tensor_size = TensorDataSize(ort_input_tensor); - LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; - ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size."); - - qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); - SetQnnTensorClientBuf(qnn_inputs.back(), - const_cast(ort_input_tensor.GetTensorData()), qnn_input_info.tensor_byte_size); - } - - std::vector qnn_outputs; - qnn_outputs.reserve(qnn_output_infos_.size()); - - for (auto& qnn_output_info : qnn_output_infos_) { - const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); - LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; - const auto& ort_output_info = GetOutputInfo(model_output_name); - const std::vector& output_shape = ort_output_info->shape_; - auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); - auto ort_tensor_size = TensorDataSize(ort_output_tensor); - LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size - << "Ort tensor size: " << ort_tensor_size; - ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size"); - - qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); - SetQnnTensorClientBuf(qnn_outputs.back(), - const_cast(ort_output_tensor.GetTensorData()), qnn_output_info.tensor_byte_size); - } - - LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); - auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); - auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; { - // Acquire mutex before calling graphExecute and profiling APIs to support calling session.Run() - // from multiple threads. + // Acquire mutex before calling QNN APIs to support calling session.Run() from multiple threads. std::lock_guard lock(graph_exec_mutex_); + + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + + std::vector registered_qnn_mem_handles{}; + registered_qnn_mem_handles.reserve(qnn_input_infos_.size() + qnn_output_infos_.size()); + + const auto registered_qnn_mem_handle_cleanup = + gsl::finally([®istered_qnn_mem_handles, &qnn_interface, &logger] { + if (!registered_qnn_mem_handles.empty()) { + auto deregister_status = qnn_interface.memDeRegister(registered_qnn_mem_handles.data(), + static_cast(registered_qnn_mem_handles.size())); + if (deregister_status != QNN_SUCCESS) { + LOGS(logger, ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; + } + } + }); + + const Qnn_ContextHandle_t qnn_context_handle = qnn_backend_manager_->GetQnnContext(); + + std::vector qnn_inputs; + qnn_inputs.reserve(qnn_input_infos_.size()); + + for (const auto& qnn_input_info : qnn_input_infos_) { + LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() + << " index = " << qnn_input_info.ort_index; + auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); + auto ort_tensor_size = TensorDataSize(ort_input_tensor); + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size + << " Ort tensor size: " << ort_tensor_size; + ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size."); + + qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + qnn_interface, + rpcmem_api, + qnn_context_handle, + *static_cast(ort_input_tensor.GetTensorMemoryInfo()), + const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, + qnn_inputs.back(), + registered_qnn_mem_handles)); + } + + std::vector qnn_outputs; + qnn_outputs.reserve(qnn_output_infos_.size()); + + for (auto& qnn_output_info : qnn_output_infos_) { + const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); + LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; + const auto& ort_output_info = GetOutputInfo(model_output_name); + const std::vector& output_shape = ort_output_info->shape_; + auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); + auto ort_tensor_size = TensorDataSize(ort_output_tensor); + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size + << " Ort tensor size: " << ort_tensor_size; + ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size"); + + qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + qnn_interface, + rpcmem_api, + qnn_context_handle, + *static_cast(ort_output_tensor.GetTensorMemoryInfo()), + const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, + qnn_outputs.back(), + registered_qnn_mem_handles)); + } + + LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); + auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); execute_status = qnn_interface.graphExecute(graph_info_->Graph(), qnn_inputs.data(), static_cast(qnn_inputs.size()), diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 2e0935391ca78..5fca33759f7f7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -3,15 +3,16 @@ #pragma once +#include #include #include "core/common/status.h" #include "core/framework/node_unit.h" #include "core/graph/graph_viewer.h" -#include #include "core/providers/qnn/builder/qnn_def.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" #include "core/providers/qnn/builder/qnn_backend_manager.h" +#include "core/providers/qnn/rpcmem_library.h" #include "core/session/onnxruntime_cxx_api.h" namespace onnxruntime { @@ -43,7 +44,9 @@ class QnnModel { Status SetupQnnInputOutput(const logging::Logger& logger); - Status ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger); + Status ExecuteGraph(const Ort::KernelContext& context, + const RpcMemApi* rpcmem_api, + const logging::Logger& logger); const OnnxTensorInfo* GetOutputInfo(const std::string& name) const { auto it = outputs_info_.find(name); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index d0c26f0aaca6c..e9320bbcdb5f2 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -5,15 +5,21 @@ #include +#include "core/common/common.h" #include "core/providers/qnn/rpcmem_library.h" namespace onnxruntime::qnn { -RpcMemAllocator::RpcMemAllocator(const RpcMemApi& rpc_mem_api) - : IAllocator{OrtMemoryInfo{"TODO name the allocator", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device id */ 0}, - 0, OrtMemTypeCPUOutput}}, - rpc_mem_api_{rpc_mem_api} { +OrtMemoryInfo RpcMemAllocator::MemoryInfo() { + return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, + OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, + /* id */ 0, OrtMemTypeDefault}; +} + +RpcMemAllocator::RpcMemAllocator(std::shared_ptr rpc_mem_lib) + : IAllocator{MemoryInfo()}, + rpc_mem_lib_{std::move(rpc_mem_lib)} { + ORT_ENFORCE(rpc_mem_lib_ != nullptr, "rpc_mem_lib_ must not be nullptr"); } void* RpcMemAllocator::Alloc(size_t size) { @@ -23,12 +29,12 @@ void* RpcMemAllocator::Alloc(size_t size) { return nullptr; } - return rpc_mem_api_.alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, - static_cast(size)); + return rpc_mem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, + static_cast(size)); } void RpcMemAllocator::Free(void* p) { - rpc_mem_api_.free(p); + rpc_mem_lib_->Api().free(p); } } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index a9bda981781fa..8a38c626cd809 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -3,22 +3,27 @@ #pragma once +#include + #include "core/framework/allocator.h" namespace onnxruntime::qnn { -struct RpcMemApi; +class RpcMemLibrary; class RpcMemAllocator : public IAllocator { public: - RpcMemAllocator(const RpcMemApi& rpc_mem_api); + // Gets the single OrtMemoryInfo value that is associated with this allocator type. + static OrtMemoryInfo MemoryInfo(); + + RpcMemAllocator(std::shared_ptr rpc_mem_lib); void* Alloc(size_t size) override; void Free(void* p) override; // void GetStats(AllocatorStats* stats) override; private: - const RpcMemApi& rpc_mem_api_; + std::shared_ptr rpc_mem_lib_; }; } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 6735528bebbf9..c0ffd14e58001 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -5,24 +5,27 @@ #include #include + #include "core/framework/compute_capability.h" -#include "core/graph/graph_viewer.h" -#include "core/session/onnxruntime_session_options_config_keys.h" -#include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/session/onnxruntime_cxx_api.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" +#include "core/graph/graph_viewer.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" #include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/op_builder_factory.h" -#include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_def.h" -#include "core/providers/qnn/builder/onnx_ctx_model_helper.h" -#include "core/framework/run_options.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group.h" +#include "core/providers/qnn/qnn_allocator.h" +#include "core/providers/qnn/rpcmem_library.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #ifdef _WIN32 #include @@ -386,6 +389,13 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio << "handles the graph I/O quantization/dequantization."; } + static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; + if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { + // Initialize rpcmem_library_. + // This is necessary for RpcMemAllocator to function and also indicates that it is available. + rpcmem_library_ = std::make_shared(); + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level_etw, @@ -814,10 +824,11 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod ORT_UNUSED_PARAMETER(state); }; - compute_info.compute_func = [&logger](FunctionState state, const OrtApi*, OrtKernelContext* context) { + compute_info.compute_func = [this, &logger](FunctionState state, const OrtApi*, OrtKernelContext* context) { Ort::KernelContext ctx(context); + const qnn::RpcMemApi* rpcmem_api = rpcmem_library_ ? &rpcmem_library_->Api() : nullptr; qnn::QnnModel* model = reinterpret_cast(state); - Status result = model->ExecuteGraph(ctx, logger); + Status result = model->ExecuteGraph(ctx, rpcmem_api, logger); return result; }; @@ -1152,4 +1163,25 @@ Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::R return Status::OK(); } + +std::vector QNNExecutionProvider::CreatePreferredAllocators() { + std::vector allocators{}; + + if (IsRpcMemAllocatorAvailable()) { + LOGS_DEFAULT(INFO) << "Creating RpcMemAllocator."; + + AllocatorFactory rpcmem_allocator_factory = [this](OrtDevice::DeviceId) { + return std::make_unique(rpcmem_library_); + }; + + AllocatorCreationInfo rpcmem_allocator_creation_info{rpcmem_allocator_factory, + /* device_id */ 0, + /* use_arena */ false}; + + allocators.emplace_back(CreateAllocator(rpcmem_allocator_creation_info)); + } + + return allocators; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 35c061de6132c..82361adb90349 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -7,13 +7,15 @@ #include "core/framework/session_options.h" #include "core/framework/model_metadef_id_generator.h" #include "core/graph/model.h" -#include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "core/providers/qnn/rpcmem_library.h" #include "HTP/QnnHtpGraph.h" +#include #include #include +#include #include #ifdef _WIN32 #include "core/platform/windows/logging/etw_sink.h" @@ -113,6 +115,8 @@ class QNNExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; + std::vector CreatePreferredAllocators() override; + private: std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, @@ -132,6 +136,8 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); + bool IsRpcMemAllocatorAvailable() const { return rpcmem_library_ != nullptr; } + private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; std::unique_ptr qnn_backend_manager_; @@ -155,6 +161,10 @@ class QNNExecutionProvider : public IExecutionProvider { #endif qnn::ModelSettings model_settings_ = {}; + // Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available. + // It is shared with RpcMemAllocator which is returned by CreatePreferredAllocators(). + std::shared_ptr rpcmem_library_ = nullptr; + class PerThreadContext final { public: PerThreadContext(qnn::QnnBackendManager* qnn_backend_manager, diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 234a154a763a6..77a340ddfcea1 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -18,15 +18,15 @@ const PathChar* GetRpcMemSharedLibraryPath() { #endif } -SharedLibraryHandle LoadSharedLibrary(const PathString& path, bool global_symbols) { +DynamicLibraryHandle LoadDynamicLibrary(const PathString& path, bool global_symbols) { // Custom deleter to unload the shared library. Avoid throwing from it because it may run in dtor. - const auto unload_shared_library = [](void* shared_library_handle) { - if (shared_library_handle == nullptr) { + const auto unload_library = [](void* library_handle) { + if (library_handle == nullptr) { return; } const auto& env = Env::Default(); - const auto unload_status = env.UnloadDynamicLibrary(shared_library_handle); + const auto unload_status = env.UnloadDynamicLibrary(library_handle); if (!unload_status.IsOK()) { LOGS_DEFAULT(WARNING) << "Failed to unload shared library. Error: " << unload_status.ErrorMessage(); @@ -34,25 +34,21 @@ SharedLibraryHandle LoadSharedLibrary(const PathString& path, bool global_symbol }; const auto& env = Env::Default(); - void* shared_library_handle = nullptr; - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &shared_library_handle)); + void* library_handle = nullptr; + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &library_handle)); - return SharedLibraryHandle{shared_library_handle, unload_shared_library}; + return DynamicLibraryHandle{library_handle, unload_library}; } -RpcMemApi CreateApi(void* shared_library_handle) { +RpcMemApi CreateApi(void* library_handle) { RpcMemApi api{}; const auto& env = Env::Default(); - void* symbol = nullptr; - ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_alloc", &symbol)); - api.alloc = static_cast(symbol); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_alloc", (void**)&api.alloc)); - ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_free", &symbol)); - api.free = static_cast(symbol); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_free", (void**)&api.free)); - ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(shared_library_handle, "rpcmem_to_fd", &symbol)); - api.to_fd = static_cast(symbol); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_to_fd", (void**)&api.to_fd)); return api; } @@ -60,8 +56,8 @@ RpcMemApi CreateApi(void* shared_library_handle) { } // namespace RpcMemLibrary::RpcMemLibrary() - : shared_library_(LoadSharedLibrary(GetRpcMemSharedLibraryPath(), /* global_symbols */ false)), - api_{CreateApi(shared_library_.get())} { + : library_handle_(LoadDynamicLibrary(GetRpcMemSharedLibraryPath(), /* global_symbols */ false)), + api_{CreateApi(library_handle_.get())} { } } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h index c9e6b7cf7ec6d..d5697ff298e79 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.h +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -10,7 +10,7 @@ namespace onnxruntime::qnn { -using SharedLibraryHandle = std::unique_ptr; +using DynamicLibraryHandle = std::unique_ptr; // This namespace contains constants and typedefs corresponding to functions from rpcmem.h. // https://github.com/quic/fastrpc/blob/v0.1.1/inc/rpcmem.h @@ -61,7 +61,7 @@ class RpcMemLibrary { const RpcMemApi& Api() const { return api_; } private: - SharedLibraryHandle shared_library_; + DynamicLibraryHandle library_handle_; RpcMemApi api_; }; From 8436b14af6f8d0a52ccd7db333b1672619be59ec Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:46:31 -0800 Subject: [PATCH 03/53] add logging for setting QNN tensor memory, update comment --- onnxruntime/core/providers/qnn/builder/qnn_model.cc | 5 +++++ onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 67980be8f341b..84c7286bae73b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -189,6 +189,7 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_interface, const RpcMemApi* rpcmem_api, + const logging::Logger& logger, Qnn_ContextHandle_t qnn_context_handle, const OrtMemoryInfo& ort_value_memory_info, void* ort_value_data, uint32_t ort_value_data_size, @@ -198,6 +199,7 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in const bool uses_shared_memory = ort_value_memory_info == RpcMemAllocator::MemoryInfo(); if (!uses_shared_memory) { + LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory."; SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_RAW); SetQnnTensorClientBuf(qnn_tensor, ort_value_data, ort_value_data_size); } else { @@ -225,6 +227,7 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in registered_qnn_mem_handles.push_back(qnn_mem_handle); + LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); SetQnnTensorMemHandle(qnn_tensor, qnn_mem_handle); } @@ -291,6 +294,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( qnn_interface, rpcmem_api, + logger, qnn_context_handle, *static_cast(ort_input_tensor.GetTensorMemoryInfo()), const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, @@ -318,6 +322,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( qnn_interface, rpcmem_api, + logger, qnn_context_handle, *static_cast(ort_output_tensor.GetTensorMemoryInfo()), const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index c0ffd14e58001..a4477c7df0cf7 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -392,7 +392,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { // Initialize rpcmem_library_. - // This is necessary for RpcMemAllocator to function and also indicates that it is available. + // This is necessary for RpcMemAllocator to function and also indicates that the allocator is available. rpcmem_library_ = std::make_shared(); } From c9826f44e01d3915697672a6445ee4fea474dc4e Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 11 Nov 2024 11:47:09 -0800 Subject: [PATCH 04/53] add option to enable HTP shared memory allocator to onnxruntime_perf_test --- .../test/perftest/command_args_parser.cc | 2 ++ onnxruntime/test/perftest/ort_test_session.cc | 22 +++++++++++++------ onnxruntime/test/perftest/ort_test_session.h | 2 +- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index e40544d950ed7..43fb22e5c9293 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -100,6 +100,8 @@ namespace perftest { "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" + "\t [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs.\n" + "\t Defaults to '0' (disabled).\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" "\t [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 8f2e5282ede9a..82a6ddd67db1a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -280,7 +280,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device } else if (key == "qnn_saver_path") { // no validation } else if (key == "htp_graph_finalization_optimization_mode") { - std::unordered_set supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"}; + std::set supported_htp_graph_final_opt_modes = {"0", "1", "2", "3"}; if (supported_htp_graph_final_opt_modes.find(value) == supported_htp_graph_final_opt_modes.end()) { std::ostringstream str_stream; std::copy(supported_htp_graph_final_opt_modes.begin(), supported_htp_graph_final_opt_modes.end(), @@ -294,7 +294,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } } else if (key == "htp_arch") { - std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + std::set supported_htp_archs = {"0", "68", "69", "73", "75"}; if (supported_htp_archs.find(value) == supported_htp_archs.end()) { std::ostringstream str_stream; std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), @@ -302,8 +302,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { - std::unordered_set supported_options = {"0", "1"}; + } else if (key == "enable_htp_fp16_precision" || + key == "offload_graph_io_quantization" || + key == "enable_htp_shared_memory_allocator") { + std::set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; std::copy(supported_options.begin(), supported_options.end(), @@ -311,11 +313,17 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device std::string str = str_stream.str(); ORT_THROW("Wrong value for ", key, ". select from: ", str); } + + if (key == "enable_htp_shared_memory_allocator" && value == "1") { + // if this option is set, also use the enabled allocator + device_memory_name_ = "QnnHtpShared"; + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', 'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', -'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization'])"); +'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization', +'enable_htp_shared_memory_allocator'])"); } qnn_options[key] = value; @@ -932,8 +940,8 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); }; } else { Ort::MemoryInfo memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); - custom_allocator_ = std::make_unique(session_, memory_info); - allocator_ = *custom_allocator_; + custom_allocator_ = Ort::Allocator(session_, memory_info); + allocator_ = custom_allocator_; // free dimensions are treated as 1 if not overridden transform_fcn = [](int64_t input) { return (input == -1) ? -input : input; }; diff --git a/onnxruntime/test/perftest/ort_test_session.h b/onnxruntime/test/perftest/ort_test_session.h index 7d5e46983ad41..d6580812da8f0 100644 --- a/onnxruntime/test/perftest/ort_test_session.h +++ b/onnxruntime/test/perftest/ort_test_session.h @@ -39,7 +39,7 @@ class OnnxRuntimeTestSession : public TestSession { std::uniform_int_distribution dist_; std::vector> test_inputs_; OrtAllocator* allocator_ = Ort::AllocatorWithDefaultOptions(); - std::unique_ptr custom_allocator_; + Ort::Allocator custom_allocator_{nullptr}; std::vector outputs_; std::vector output_names_; // The same size with output_names_. From c07c35e5cadcea56f93b9113917ba60e688b5b24 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 11 Nov 2024 19:46:02 -0800 Subject: [PATCH 05/53] hack - try to cache mem handles in QnnModel --- .../core/providers/qnn/builder/qnn_model.cc | 82 ++++++++++--------- .../core/providers/qnn/builder/qnn_model.h | 5 +- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 84c7286bae73b..a656ff4328541 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -194,7 +194,7 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in const OrtMemoryInfo& ort_value_memory_info, void* ort_value_data, uint32_t ort_value_data_size, Qnn_Tensor_t& qnn_tensor, - std::vector& registered_qnn_mem_handles) { + std::unordered_map& qnn_mem_handles) { // either set qnn_tensor memHandle or clientBuf const bool uses_shared_memory = ort_value_memory_info == RpcMemAllocator::MemoryInfo(); @@ -205,27 +205,33 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in } else { ORT_RETURN_IF(rpcmem_api == nullptr, "RPCMEM API must be available when using shared memory."); - // get RpcMem file descriptor from shared memory - const auto shared_memory_fd = rpcmem_api->to_fd(ort_value_data); - ORT_RETURN_IF(shared_memory_fd == -1, "rpcmem_to_fd() returned invalid file descriptor."); - - // set up QNN memory descriptor - // note: we only support a single tensor per shared memory buffer (QNN_MEM_TYPE_ION) now - Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; - qnn_mem_descriptor.memShape = {GetQnnTensorRank(qnn_tensor), - GetQnnTensorDims(qnn_tensor), - nullptr}; - qnn_mem_descriptor.dataType = GetQnnTensorDataType(qnn_tensor); - qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; - qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; - - Qnn_MemHandle_t qnn_mem_handle = nullptr; - const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, - &qnn_mem_handle); - // TODO show error message - ORT_RETURN_IF(register_status != QNN_SUCCESS, "qnnInterface.memRegister() failed with error code ", register_status); - - registered_qnn_mem_handles.push_back(qnn_mem_handle); + Qnn_MemHandle_t qnn_mem_handle; + auto qnn_mem_handle_it = qnn_mem_handles.find(ort_value_data); + if (qnn_mem_handle_it != qnn_mem_handles.end()) { + qnn_mem_handle = qnn_mem_handle_it->second; + } else { + // get RpcMem file descriptor from shared memory + const auto shared_memory_fd = rpcmem_api->to_fd(ort_value_data); + ORT_RETURN_IF(shared_memory_fd == -1, "rpcmem_to_fd() returned invalid file descriptor."); + + // set up QNN memory descriptor + // note: we only support a single tensor per shared memory buffer (QNN_MEM_TYPE_ION) now + Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; + qnn_mem_descriptor.memShape = {GetQnnTensorRank(qnn_tensor), + GetQnnTensorDims(qnn_tensor), + nullptr}; + qnn_mem_descriptor.dataType = GetQnnTensorDataType(qnn_tensor); + qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; + qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; + + qnn_mem_handle = nullptr; + const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, + &qnn_mem_handle); + // TODO show error message + ORT_RETURN_IF(register_status != QNN_SUCCESS, "qnnInterface.memRegister() failed with error code ", register_status); + + qnn_mem_handles.emplace(ort_value_data, qnn_mem_handle); + } LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); @@ -235,6 +241,19 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in return Status::OK(); } +QnnModel::~QnnModel() { + // clean up qnn_mem_handles_ + if (!qnn_mem_handles_.empty()) { + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + for (const auto [addr, qnn_mem_handle] : qnn_mem_handles_) { + auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); + if (deregister_status != QNN_SUCCESS) { + LOGS_DEFAULT(ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; + } + } + } +} + Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi* rpcmem_api, const logging::Logger& logger) { LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; @@ -259,21 +278,6 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi std::lock_guard lock(graph_exec_mutex_); const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - - std::vector registered_qnn_mem_handles{}; - registered_qnn_mem_handles.reserve(qnn_input_infos_.size() + qnn_output_infos_.size()); - - const auto registered_qnn_mem_handle_cleanup = - gsl::finally([®istered_qnn_mem_handles, &qnn_interface, &logger] { - if (!registered_qnn_mem_handles.empty()) { - auto deregister_status = qnn_interface.memDeRegister(registered_qnn_mem_handles.data(), - static_cast(registered_qnn_mem_handles.size())); - if (deregister_status != QNN_SUCCESS) { - LOGS(logger, ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; - } - } - }); - const Qnn_ContextHandle_t qnn_context_handle = qnn_backend_manager_->GetQnnContext(); std::vector qnn_inputs; @@ -299,7 +303,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi *static_cast(ort_input_tensor.GetTensorMemoryInfo()), const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, qnn_inputs.back(), - registered_qnn_mem_handles)); + qnn_mem_handles_)); } std::vector qnn_outputs; @@ -327,7 +331,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi *static_cast(ort_output_tensor.GetTensorMemoryInfo()), const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, qnn_outputs.back(), - registered_qnn_mem_handles)); + qnn_mem_handles_)); } LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 5fca33759f7f7..3d357e3bd41ef 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -31,7 +31,7 @@ class QnnModel { qnn_backend_type_ = qnn_backend_manager_->GetQnnBackendType(); } - ~QnnModel() = default; + ~QnnModel(); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModel); Status ComposeGraph(const GraphViewer& graph_viewer, @@ -145,6 +145,9 @@ class QnnModel { std::vector qnn_output_infos_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; + // shared memory addr to Qnn_MemHandle_t + std::unordered_map qnn_mem_handles_; // TODO find the right place to save mem handles + // Mutex acquired during graph execution to support multi-threaded inference of a single session. std::mutex graph_exec_mutex_; }; From 60dc83748a51dca232ab06baeb58c4b0eefd5e89 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 13 Nov 2024 11:27:57 -0800 Subject: [PATCH 06/53] Remove duplicate include. --- onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index a4477c7df0cf7..1b53ff84b31bf 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -15,7 +15,6 @@ #include "core/platform/env.h" #include "core/providers/common.h" #include "core/providers/partitioning_utils.h" -#include "core/providers/partitioning_utils.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_def.h" From 24e072f06bf310440b323394dddff8f32e8b80c0 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 13 Nov 2024 16:34:47 -0800 Subject: [PATCH 07/53] hack, continued - move cache out to SharedContext --- .../core/providers/qnn/builder/qnn_model.cc | 47 +++----- .../core/providers/qnn/builder/qnn_model.h | 5 +- .../providers/qnn/qnn_execution_provider.cc | 14 +++ .../providers/qnn/qnn_execution_provider.h | 61 ---------- .../core/providers/qnn/shared_context.h | 113 ++++++++++++++++++ 5 files changed, 145 insertions(+), 95 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/shared_context.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index a656ff4328541..a79368cd162df 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -14,6 +14,7 @@ #include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/qnn_allocator.h" +#include "core/providers/qnn/shared_context.h" #include "core/providers/shared/utils/utils.h" namespace onnxruntime { @@ -193,8 +194,7 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in Qnn_ContextHandle_t qnn_context_handle, const OrtMemoryInfo& ort_value_memory_info, void* ort_value_data, uint32_t ort_value_data_size, - Qnn_Tensor_t& qnn_tensor, - std::unordered_map& qnn_mem_handles) { + Qnn_Tensor_t& qnn_tensor) { // either set qnn_tensor memHandle or clientBuf const bool uses_shared_memory = ort_value_memory_info == RpcMemAllocator::MemoryInfo(); @@ -205,14 +205,12 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in } else { ORT_RETURN_IF(rpcmem_api == nullptr, "RPCMEM API must be available when using shared memory."); - Qnn_MemHandle_t qnn_mem_handle; - auto qnn_mem_handle_it = qnn_mem_handles.find(ort_value_data); - if (qnn_mem_handle_it != qnn_mem_handles.end()) { - qnn_mem_handle = qnn_mem_handle_it->second; - } else { + const auto create_mem_handle = [&](const void* addr) { + LOGS(logger, VERBOSE) << "Registering mem handle for addr " << addr; + // get RpcMem file descriptor from shared memory - const auto shared_memory_fd = rpcmem_api->to_fd(ort_value_data); - ORT_RETURN_IF(shared_memory_fd == -1, "rpcmem_to_fd() returned invalid file descriptor."); + const auto shared_memory_fd = rpcmem_api->to_fd(const_cast(addr)); + ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); // set up QNN memory descriptor // note: we only support a single tensor per shared memory buffer (QNN_MEM_TYPE_ION) now @@ -224,14 +222,18 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; - qnn_mem_handle = nullptr; + Qnn_MemHandle_t qnn_mem_handle = nullptr; const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, &qnn_mem_handle); // TODO show error message - ORT_RETURN_IF(register_status != QNN_SUCCESS, "qnnInterface.memRegister() failed with error code ", register_status); + ORT_ENFORCE(register_status == QNN_SUCCESS, + "qnnInterface.memRegister() failed with error code ", register_status); - qnn_mem_handles.emplace(ort_value_data, qnn_mem_handle); - } + return qnn_mem_handle; + }; + + const Qnn_MemHandle_t qnn_mem_handle = + SharedContext::GetInstance().GetSharedMemHandles().GetOrCreate(ort_value_data, create_mem_handle); LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); @@ -241,19 +243,6 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in return Status::OK(); } -QnnModel::~QnnModel() { - // clean up qnn_mem_handles_ - if (!qnn_mem_handles_.empty()) { - const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - for (const auto [addr, qnn_mem_handle] : qnn_mem_handles_) { - auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); - if (deregister_status != QNN_SUCCESS) { - LOGS_DEFAULT(ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; - } - } - } -} - Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi* rpcmem_api, const logging::Logger& logger) { LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; @@ -302,8 +291,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi qnn_context_handle, *static_cast(ort_input_tensor.GetTensorMemoryInfo()), const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, - qnn_inputs.back(), - qnn_mem_handles_)); + qnn_inputs.back())); } std::vector qnn_outputs; @@ -330,8 +318,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi qnn_context_handle, *static_cast(ort_output_tensor.GetTensorMemoryInfo()), const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, - qnn_outputs.back(), - qnn_mem_handles_)); + qnn_outputs.back())); } LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 3d357e3bd41ef..5fca33759f7f7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -31,7 +31,7 @@ class QnnModel { qnn_backend_type_ = qnn_backend_manager_->GetQnnBackendType(); } - ~QnnModel(); + ~QnnModel() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnModel); Status ComposeGraph(const GraphViewer& graph_viewer, @@ -145,9 +145,6 @@ class QnnModel { std::vector qnn_output_infos_; QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; - // shared memory addr to Qnn_MemHandle_t - std::unordered_map qnn_mem_handles_; // TODO find the right place to save mem handles - // Mutex acquired during graph execution to support multi-threaded inference of a single session. std::mutex graph_exec_mutex_; }; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 1b53ff84b31bf..00c2b1f15a30b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -22,6 +22,7 @@ #include "core/providers/qnn/builder/qnn_node_group.h" #include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/rpcmem_library.h" +#include "core/providers/qnn/shared_context.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_run_options_config_keys.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -452,6 +453,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } QNNExecutionProvider::~QNNExecutionProvider() { + // hack: need somewhere to clean up the global shared memory handle state, here might be sufficient for now + // clean up shared memory handles, if any + { + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + const auto deregister_mem_handle = [&qnn_interface](const void* /*addr*/, Qnn_MemHandle_t qnn_mem_handle) { + auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); + if (deregister_status != QNN_SUCCESS) { + LOGS_DEFAULT(ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; + } + }; + SharedContext::GetInstance().GetSharedMemHandles().Clear(deregister_mem_handle); + } + // clean up thread local context caches std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 82361adb90349..53b1cb2a6c77c 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -25,67 +25,6 @@ namespace onnxruntime { void RunOnUnload(std::function function); -class SharedContext { - public: - static SharedContext& GetInstance() { - static SharedContext instance_; - return instance_; - } - - bool HasSharedQnnModels() { - const std::lock_guard lock(mtx_); - return !shared_qnn_models_.empty(); - } - - bool HasQnnModel(const std::string& model_name) { - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - return it != shared_qnn_models_.end(); - } - - std::unique_ptr GetSharedQnnModel(const std::string& model_name) { - const std::lock_guard lock(mtx_); - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - if (it == shared_qnn_models_.end()) { - return nullptr; - } - auto qnn_model = std::move(*it); - shared_qnn_models_.erase(it); - return qnn_model; - } - - bool SetSharedQnnModel(std::vector>&& shared_qnn_models, - std::string& duplicate_graph_names) { - const std::lock_guard lock(mtx_); - bool graph_exist = false; - for (auto& shared_qnn_model : shared_qnn_models) { - auto& model_name = shared_qnn_model->Name(); - auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), - [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); - if (it == shared_qnn_models_.end()) { - shared_qnn_models_.push_back(std::move(shared_qnn_model)); - } else { - duplicate_graph_names.append(model_name + " "); - graph_exist = true; - } - } - - return graph_exist; - } - - private: - SharedContext() = default; - ~SharedContext() = default; - SharedContext(const SharedContext&) = delete; - SharedContext& operator=(const SharedContext&) = delete; - - std::vector> shared_qnn_models_; - // Producer sessions can be in parallel - // Consumer sessions have to be after producer sessions initialized - std::mutex mtx_; -}; - // Logical device representation. class QNNExecutionProvider : public IExecutionProvider { public: diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h new file mode 100644 index 0000000000000..4b38de37ba700 --- /dev/null +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License + +#include +#include +#include +#include + +#include + +#include "core/providers/qnn/builder/qnn_model.h" + +#pragma once + +namespace onnxruntime { + +class SharedMemHandles { + public: + Qnn_MemHandle_t GetOrCreate(const void* addr, const std::function& create_fn) { + std::lock_guard g{mutex_}; + Qnn_MemHandle_t& qnn_mem_handle = qnn_mem_handles_[addr]; + if (qnn_mem_handle == Qnn_MemHandle_t{}) { + qnn_mem_handle = create_fn(addr); + } + return qnn_mem_handle; + } + + void Clear(const std::function& cleanup_fn) { + std::unordered_map qnn_mem_handles_copy; + { + std::lock_guard g{mutex_}; + std::swap(qnn_mem_handles_, qnn_mem_handles_copy); + } + + if (cleanup_fn) { + for (const auto [addr, mem_handle] : qnn_mem_handles_copy) { + cleanup_fn(addr, mem_handle); + } + } + } + + private: + std::unordered_map qnn_mem_handles_; + std::mutex mutex_; +}; + +class SharedContext { + public: + static SharedContext& GetInstance() { + static SharedContext instance_; + return instance_; + } + + bool HasSharedQnnModels() { + const std::lock_guard lock(mtx_); + return !shared_qnn_models_.empty(); + } + + bool HasQnnModel(const std::string& model_name) { + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + return it != shared_qnn_models_.end(); + } + + std::unique_ptr GetSharedQnnModel(const std::string& model_name) { + const std::lock_guard lock(mtx_); + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + if (it == shared_qnn_models_.end()) { + return nullptr; + } + auto qnn_model = std::move(*it); + shared_qnn_models_.erase(it); + return qnn_model; + } + + bool SetSharedQnnModel(std::vector>&& shared_qnn_models, + std::string& duplicate_graph_names) { + const std::lock_guard lock(mtx_); + bool graph_exist = false; + for (auto& shared_qnn_model : shared_qnn_models) { + auto& model_name = shared_qnn_model->Name(); + auto it = find_if(shared_qnn_models_.begin(), shared_qnn_models_.end(), + [&model_name](const std::unique_ptr& qnn_model) { return qnn_model->Name() == model_name; }); + if (it == shared_qnn_models_.end()) { + shared_qnn_models_.push_back(std::move(shared_qnn_model)); + } else { + duplicate_graph_names.append(model_name + " "); + graph_exist = true; + } + } + + return graph_exist; + } + + SharedMemHandles& GetSharedMemHandles() { return shared_mem_handles_; } + + private: + SharedContext() = default; + ~SharedContext() = default; + SharedContext(const SharedContext&) = delete; + SharedContext& operator=(const SharedContext&) = delete; + + std::vector> shared_qnn_models_; + // Producer sessions can be in parallel + // Consumer sessions have to be after producer sessions initialized + std::mutex mtx_; + + // hack: we should tie the mem handle lifetime to the OrtValue with the shared mem data + SharedMemHandles shared_mem_handles_; +}; + +} // namespace onnxruntime From 8c515dabd6aa192bfbf38965bab2c00b13fc4f94 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:57:47 -0800 Subject: [PATCH 08/53] move mem handle registration to allocator --- .../onnxruntime/core/framework/allocator.h | 11 ++ .../core/session/onnxruntime_c_api.h | 3 + onnxruntime/core/framework/allocator.cc | 6 + onnxruntime/core/framework/tensor.cc | 7 +- .../core/providers/qnn/builder/qnn_model.cc | 144 +++++++----------- .../core/providers/qnn/builder/qnn_model.h | 1 - .../core/providers/qnn/qnn_allocator.cc | 130 ++++++++++++++-- .../core/providers/qnn/qnn_allocator.h | 11 +- .../providers/qnn/qnn_execution_provider.cc | 22 +-- .../providers/qnn/qnn_execution_provider.h | 5 +- .../core/providers/qnn/shared_context.h | 37 ++--- .../core/session/allocator_adapters.cc | 39 ++++- onnxruntime/core/session/allocator_adapters.h | 3 + 13 files changed, 265 insertions(+), 154 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 525277375830c..5aaa62f19408b 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -7,10 +7,12 @@ #include "core/common/common.h" #include "core/framework/allocator_stats.h" +#include "core/framework/data_types.h" // some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h #include "core/session/onnxruntime_c_api.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" +#include "core/framework/tensor_shape.h" // This configures the arena based allocator used by ORT // See docs/C_API.md for details on what these mean and how to choose these values @@ -84,6 +86,15 @@ class IAllocator { virtual void Free(void* p) = 0; + /** + * Allocate memory for a tensor of the given shape and element data type. + * If the tensor size is 0, nullptr is returned. + * On other failures, an exception is thrown. + * + * Note: The default implementation will call Alloc(). + */ + virtual void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape); + // Reserve() is an interface exposed for an implementation of IAllocator // to optionally implement some allocation logic that by-passes any arena-based // logic that may be housed in the Alloc() implementation. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index b0c5d2329c428..911bc3955edf6 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -329,6 +329,9 @@ typedef struct OrtAllocator { * those made during session initialization. This allows for separate memory management strategies for these allocations. */ void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes + // TODO docs + void*(ORT_API_CALL* TensorAlloc)(struct OrtAllocator* this_, + const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type); } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 02dbb3e518783..a7eb82148fc49 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -14,9 +14,15 @@ #endif #include "core/framework/bfc_arena.h" +#include "core/framework/tensor.h" namespace onnxruntime { +void* IAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { + const auto size_in_bytes = Tensor::CalculateTensorStorageSize(element_data_type, shape); + return Alloc(size_in_bytes); +} + // private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { bool ok = true; diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index 60d768cc59a5d..ea80f55ac0327 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -87,12 +87,7 @@ Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, cons Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator) : alloc_info_(allocator->Info()) { ORT_ENFORCE(elt_type != nullptr); - size_t len = Tensor::CalculateTensorStorageSize(elt_type, shape); - - void* p_data = nullptr; - if (len > 0) { - p_data = allocator->Alloc(len); - } + void* p_data = allocator->TensorAlloc(elt_type, shape); Init(elt_type, shape, p_data, allocator, 0L); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 0fd6ffde0b8c8..07b01bca3522e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -188,10 +188,7 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { return Status::OK(); } -static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_interface, - const RpcMemApi* rpcmem_api, - const logging::Logger& logger, - Qnn_ContextHandle_t qnn_context_handle, +static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, const OrtMemoryInfo& ort_value_memory_info, void* ort_value_data, uint32_t ort_value_data_size, Qnn_Tensor_t& qnn_tensor) { @@ -203,39 +200,8 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_RAW); SetQnnTensorClientBuf(qnn_tensor, ort_value_data, ort_value_data_size); } else { - ORT_RETURN_IF(rpcmem_api == nullptr, "RPCMEM API must be available when using shared memory."); - - const auto create_mem_handle = [&](const void* addr) { - LOGS(logger, VERBOSE) << "Registering mem handle for addr " << addr; - - // get RpcMem file descriptor from shared memory - const auto shared_memory_fd = rpcmem_api->to_fd(const_cast(addr)); - ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); - - // set up QNN memory descriptor - // note: we only support a single tensor per shared memory buffer (QNN_MEM_TYPE_ION) now - Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; - qnn_mem_descriptor.memShape = {GetQnnTensorRank(qnn_tensor), - GetQnnTensorDims(qnn_tensor), - nullptr}; - qnn_mem_descriptor.dataType = GetQnnTensorDataType(qnn_tensor); - qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; - qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; - - Qnn_MemHandle_t qnn_mem_handle = nullptr; - const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, - &qnn_mem_handle); - // TODO show error message - ORT_ENFORCE(register_status == QNN_SUCCESS, - "qnnInterface.memRegister() failed with error code ", register_status); - - return qnn_mem_handle; - }; - - const Qnn_MemHandle_t qnn_mem_handle = - SharedContext::GetInstance().GetSharedMemHandles().GetOrCreate(ort_value_data, create_mem_handle); - LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; + const Qnn_MemHandle_t qnn_mem_handle = SharedContext::GetInstance().GetSharedMemHandles().Get(ort_value_data); SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); SetQnnTensorMemHandle(qnn_tensor, qnn_mem_handle); } @@ -243,7 +209,7 @@ static Status BindQnnTensorMemoryToOrtValue(const QNN_INTERFACE_VER_TYPE& qnn_in return Status::OK(); } -Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi* rpcmem_api, +Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const logging::Logger& logger) { LOGS(logger, VERBOSE) << "QnnModel::ExecuteGraphs"; const size_t num_inputs = context.GetInputCount(); @@ -260,66 +226,58 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, const RpcMemApi return element_size * length; }; - Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; + std::vector qnn_inputs; + qnn_inputs.reserve(qnn_input_infos_.size()); + + for (const auto& qnn_input_info : qnn_input_infos_) { + LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() + << " index = " << qnn_input_info.ort_index; + auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); + auto ort_tensor_size = TensorDataSize(ort_input_tensor); + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size + << " Ort tensor size: " << ort_tensor_size; + ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size."); + + qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + logger, + *static_cast(ort_input_tensor.GetTensorMemoryInfo()), + const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, + qnn_inputs.back())); + } - { - // Acquire mutex before calling QNN APIs to support calling session.Run() from multiple threads. - std::lock_guard lock(graph_exec_mutex_); + std::vector qnn_outputs; + qnn_outputs.reserve(qnn_output_infos_.size()); + + for (auto& qnn_output_info : qnn_output_infos_) { + const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); + LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; + const auto& ort_output_info = GetOutputInfo(model_output_name); + const std::vector& output_shape = ort_output_info->shape_; + auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); + auto ort_tensor_size = TensorDataSize(ort_output_tensor); + LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size + << " Ort tensor size: " << ort_tensor_size; + ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size"); + + qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); + + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + logger, + *static_cast(ort_output_tensor.GetTensorMemoryInfo()), + const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, + qnn_outputs.back())); + } + Qnn_ErrorHandle_t execute_status = QNN_GRAPH_NO_ERROR; + { const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - const Qnn_ContextHandle_t qnn_context_handle = qnn_backend_manager_->GetQnnContext(); - - std::vector qnn_inputs; - qnn_inputs.reserve(qnn_input_infos_.size()); - - for (const auto& qnn_input_info : qnn_input_infos_) { - LOGS(logger, VERBOSE) << "model_input = " << qnn_input_info.tensor_wrapper->GetName() - << " index = " << qnn_input_info.ort_index; - auto ort_input_tensor = context.GetInput(qnn_input_info.ort_index); - auto ort_tensor_size = TensorDataSize(ort_input_tensor); - LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size - << " Ort tensor size: " << ort_tensor_size; - ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size."); - - qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); - - ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( - qnn_interface, - rpcmem_api, - logger, - qnn_context_handle, - *static_cast(ort_input_tensor.GetTensorMemoryInfo()), - const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, - qnn_inputs.back())); - } - std::vector qnn_outputs; - qnn_outputs.reserve(qnn_output_infos_.size()); - - for (auto& qnn_output_info : qnn_output_infos_) { - const std::string& model_output_name = qnn_output_info.tensor_wrapper->GetName(); - LOGS(logger, VERBOSE) << "model_output = " << model_output_name << " index = " << qnn_output_info.ort_index; - const auto& ort_output_info = GetOutputInfo(model_output_name); - const std::vector& output_shape = ort_output_info->shape_; - auto ort_output_tensor = context.GetOutput(qnn_output_info.ort_index, output_shape.data(), output_shape.size()); - auto ort_tensor_size = TensorDataSize(ort_output_tensor); - LOGS(logger, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size - << " Ort tensor size: " << ort_tensor_size; - ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size"); - - qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); - - ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( - qnn_interface, - rpcmem_api, - logger, - qnn_context_handle, - *static_cast(ort_output_tensor.GetTensorMemoryInfo()), - const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, - qnn_outputs.back())); - } + // Acquire mutex before calling QNN APIs to support calling session.Run() from multiple threads. + std::lock_guard lock(graph_exec_mutex_); LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 5fca33759f7f7..85d50eff09d67 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -45,7 +45,6 @@ class QnnModel { Status SetupQnnInputOutput(const logging::Logger& logger); Status ExecuteGraph(const Ort::KernelContext& context, - const RpcMemApi* rpcmem_api, const logging::Logger& logger); const OnnxTensorInfo* GetOutputInfo(const std::string& name) const { diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index e9320bbcdb5f2..5389af1eb1385 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -3,38 +3,146 @@ #include "core/providers/qnn/qnn_allocator.h" +#include #include +#include + #include "core/common/common.h" -#include "core/providers/qnn/rpcmem_library.h" +#include "core/common/logging/logging.h" +#include "core/common/inlined_containers.h" +#include "core/common/narrow.h" +#include "core/framework/tensor.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/shared_context.h" // for shared mem handle access namespace onnxruntime::qnn { +namespace { + +Qnn_MemHandle_t RegisterQnnMemHandle(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ContextHandle_t qnn_context_handle, + int shared_memory_fd, + MLDataType element_data_type, const TensorShape& shape) { + auto qnn_shape = [shape_span = shape.GetDims()]() { + InlinedVector qnn_shape; + std::transform(shape_span.begin(), shape_span.end(), std::back_inserter(qnn_shape), + [](int64_t dim) { return narrow(dim); }); + return qnn_shape; + }(); + + const auto qnn_data_type = [element_data_type]() { + Qnn_DataType_t qnn_data_type; + ORT_ENFORCE(element_data_type->IsPrimitiveDataType()); + const auto onnx_data_type = element_data_type->AsPrimitiveDataType()->GetDataType(); + const bool is_quantized = false; // TODO how should we set this? + if (!utils::OnnxDataTypeToQnnDataType(onnx_data_type, qnn_data_type, is_quantized)) { + ORT_THROW("Unable to get QNN data type from ONNX data type: ", onnx_data_type); + } + return qnn_data_type; + }(); + + // set up QNN memory descriptor + Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; + qnn_mem_descriptor.memShape = {narrow(qnn_shape.size()), + qnn_shape.data(), + nullptr}; + qnn_mem_descriptor.dataType = qnn_data_type; + qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; + qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; + + Qnn_MemHandle_t qnn_mem_handle = nullptr; + const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, + &qnn_mem_handle); + // TODO show error message + ORT_ENFORCE(register_status == QNN_SUCCESS, + "qnn_interface.memRegister() failed with error code ", register_status); + + return qnn_mem_handle; +} + +void DeregisterQnnMemHandle(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_MemHandle_t qnn_mem_handle) { + const auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); + // TODO show error message + if (deregister_status != QNN_SUCCESS) { + LOGS_DEFAULT(ERROR) << "qnn_interface.memDeRegister() failed with error code " << deregister_status; + } +} + +using RpcMemUniquePtr = std::unique_ptr; + +RpcMemUniquePtr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, const RpcMemApi& rpcmem_api) { + return {shared_memory_raw, rpcmem_api.free}; +} + +} // namespace + OrtMemoryInfo RpcMemAllocator::MemoryInfo() { return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, /* id */ 0, OrtMemTypeDefault}; } -RpcMemAllocator::RpcMemAllocator(std::shared_ptr rpc_mem_lib) +RpcMemAllocator::RpcMemAllocator(std::shared_ptr rpcmem_lib, + std::shared_ptr qnn_backend_manager) : IAllocator{MemoryInfo()}, - rpc_mem_lib_{std::move(rpc_mem_lib)} { - ORT_ENFORCE(rpc_mem_lib_ != nullptr, "rpc_mem_lib_ must not be nullptr"); + rpcmem_lib_{std::move(rpcmem_lib)}, + qnn_backend_manager_{std::move(qnn_backend_manager)} { + ORT_ENFORCE(rpcmem_lib_ != nullptr); + ORT_ENFORCE(qnn_backend_manager_ != nullptr); +} + +void* RpcMemAllocator::Alloc(size_t /* size */) { + LOGS_DEFAULT(ERROR) << "hey this ain't right"; + std::exit(1); + ORT_THROW("RpcMemAllocator::Alloc() is not implemented. Use RpcMemAllocator::TensorAlloc() instead."); } -void* RpcMemAllocator::Alloc(size_t size) { - // rpcmem_alloc() has an int size parameter. - constexpr size_t max_size = std::numeric_limits::max(); - if (size > max_size) { +void* RpcMemAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { + const auto size_in_bytes = Tensor::CalculateTensorStorageSize(element_data_type, shape); + + if (size_in_bytes == 0) { return nullptr; } - return rpc_mem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, - static_cast(size)); + // rpcmem_alloc() has an int size parameter. make sure we don't overflow. + constexpr size_t max_size_in_bytes = std::numeric_limits::max(); + ORT_ENFORCE(size_in_bytes <= max_size_in_bytes, + "Allocation size (", size_in_bytes, ") is larger than maximum allowed (", max_size_in_bytes, ")."); + + // allocate shared memory + void* shared_memory_raw = rpcmem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, + static_cast(size_in_bytes)); + + auto shared_memory = WrapSharedMemoryWithUniquePtr(shared_memory_raw, rpcmem_lib_->Api()); + + // get shared memory fd + const auto shared_memory_fd = rpcmem_lib_->Api().to_fd(shared_memory.get()); + ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); + + // register mem handle + // TODO synchronize calls to qnn_interface.memRegister()? + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + const auto qnn_context_handle = qnn_backend_manager_->GetQnnContext(); + const auto qnn_mem_handle = RegisterQnnMemHandle(qnn_interface, qnn_context_handle, + shared_memory_fd, element_data_type, shape); + + // save mem handle. for now, the global SharedContext will do... + SharedContext::GetInstance().GetSharedMemHandles().Add(shared_memory.get(), qnn_mem_handle); + + return shared_memory.release(); } void RpcMemAllocator::Free(void* p) { - rpc_mem_lib_->Api().free(p); + // take ownership of shared memory and free at end of scope + auto shared_memory = WrapSharedMemoryWithUniquePtr(p, rpcmem_lib_->Api()); + + // deregister mem handle + // TODO synchronize calls to qnn_interface.memDeRegister()? + const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); + const auto qnn_mem_handle = SharedContext::GetInstance().GetSharedMemHandles().GetAndRemove(p); + DeregisterQnnMemHandle(qnn_interface, qnn_mem_handle); } } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index 8a38c626cd809..6866189c5a084 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -7,8 +7,12 @@ #include "core/framework/allocator.h" +#include "core/providers/qnn/builder/qnn_backend_manager.h" +#include "core/providers/qnn/rpcmem_library.h" + namespace onnxruntime::qnn { +class QnnBackendManager; class RpcMemLibrary; class RpcMemAllocator : public IAllocator { @@ -16,14 +20,17 @@ class RpcMemAllocator : public IAllocator { // Gets the single OrtMemoryInfo value that is associated with this allocator type. static OrtMemoryInfo MemoryInfo(); - RpcMemAllocator(std::shared_ptr rpc_mem_lib); + RpcMemAllocator(std::shared_ptr rpcmem_lib, + std::shared_ptr qnn_backend_manager); void* Alloc(size_t size) override; + void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; // void GetStats(AllocatorStats* stats) override; private: - std::shared_ptr rpc_mem_lib_; + std::shared_ptr rpcmem_lib_; + std::shared_ptr qnn_backend_manager_; }; } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 00c2b1f15a30b..bdccc64a3b8dd 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -396,7 +396,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio rpcmem_library_ = std::make_shared(); } - qnn_backend_manager_ = std::make_unique( + qnn_backend_manager_ = std::make_shared( std::move(backend_path), profiling_level_etw, profiling_level, @@ -453,19 +453,6 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } QNNExecutionProvider::~QNNExecutionProvider() { - // hack: need somewhere to clean up the global shared memory handle state, here might be sufficient for now - // clean up shared memory handles, if any - { - const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - const auto deregister_mem_handle = [&qnn_interface](const void* /*addr*/, Qnn_MemHandle_t qnn_mem_handle) { - auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); - if (deregister_status != QNN_SUCCESS) { - LOGS_DEFAULT(ERROR) << "qnnInterface.memDeRegister() failed with error code " << deregister_status; - } - }; - SharedContext::GetInstance().GetSharedMemHandles().Clear(deregister_mem_handle); - } - // clean up thread local context caches std::lock_guard lock(context_state_.mutex); for (const auto& cache_weak : context_state_.caches_to_update_on_destruction) { @@ -837,11 +824,10 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod ORT_UNUSED_PARAMETER(state); }; - compute_info.compute_func = [this, &logger](FunctionState state, const OrtApi*, OrtKernelContext* context) { + compute_info.compute_func = [&logger](FunctionState state, const OrtApi*, OrtKernelContext* context) { Ort::KernelContext ctx(context); - const qnn::RpcMemApi* rpcmem_api = rpcmem_library_ ? &rpcmem_library_->Api() : nullptr; qnn::QnnModel* model = reinterpret_cast(state); - Status result = model->ExecuteGraph(ctx, rpcmem_api, logger); + Status result = model->ExecuteGraph(ctx, logger); return result; }; @@ -1184,7 +1170,7 @@ std::vector QNNExecutionProvider::CreatePreferredAllocators() { LOGS_DEFAULT(INFO) << "Creating RpcMemAllocator."; AllocatorFactory rpcmem_allocator_factory = [this](OrtDevice::DeviceId) { - return std::make_unique(rpcmem_library_); + return std::make_unique(rpcmem_library_, qnn_backend_manager_); }; AllocatorCreationInfo rpcmem_allocator_creation_info{rpcmem_allocator_factory, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 53b1cb2a6c77c..18fdef9a7e3f5 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -79,7 +79,8 @@ class QNNExecutionProvider : public IExecutionProvider { private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; - std::unique_ptr qnn_backend_manager_; + // This is potentially shared with RpcMemAllocator which may be returned by CreatePreferredAllocators(). + std::shared_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; @@ -101,7 +102,7 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ModelSettings model_settings_ = {}; // Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available. - // It is shared with RpcMemAllocator which is returned by CreatePreferredAllocators(). + // This is potentially shared with RpcMemAllocator which may be returned by CreatePreferredAllocators(). std::shared_ptr rpcmem_library_ = nullptr; class PerThreadContext final { diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 4b38de37ba700..4ce4aa15029a3 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -8,6 +8,7 @@ #include +#include "core/common/common.h" #include "core/providers/qnn/builder/qnn_model.h" #pragma once @@ -16,27 +17,27 @@ namespace onnxruntime { class SharedMemHandles { public: - Qnn_MemHandle_t GetOrCreate(const void* addr, const std::function& create_fn) { + Qnn_MemHandle_t Get(const void* addr) { std::lock_guard g{mutex_}; - Qnn_MemHandle_t& qnn_mem_handle = qnn_mem_handles_[addr]; - if (qnn_mem_handle == Qnn_MemHandle_t{}) { - qnn_mem_handle = create_fn(addr); - } - return qnn_mem_handle; + const auto it = qnn_mem_handles_.find(addr); + ORT_ENFORCE(it != qnn_mem_handles_.end(), "Failed to find mem handle associated with address (", addr, ")."); + return it->second; } - void Clear(const std::function& cleanup_fn) { - std::unordered_map qnn_mem_handles_copy; - { - std::lock_guard g{mutex_}; - std::swap(qnn_mem_handles_, qnn_mem_handles_copy); - } + void Add(const void* addr, Qnn_MemHandle_t mem_handle) { + std::lock_guard g{mutex_}; + auto [it, added] = qnn_mem_handles_.emplace(addr, mem_handle); + ORT_ENFORCE(added, + "There is already a mem handle (", mem_handle, ") associated with the address (", addr, ")."); + } - if (cleanup_fn) { - for (const auto [addr, mem_handle] : qnn_mem_handles_copy) { - cleanup_fn(addr, mem_handle); - } - } + Qnn_MemHandle_t GetAndRemove(const void* addr) { + std::lock_guard g{mutex_}; + const auto it = qnn_mem_handles_.find(addr); + ORT_ENFORCE(it != qnn_mem_handles_.end(), "Failed to find mem handle associated with address (", addr, ")."); + const auto qnn_mem_handle = it->second; + qnn_mem_handles_.erase(it); + return qnn_mem_handle; } private: @@ -106,7 +107,7 @@ class SharedContext { // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; - // hack: we should tie the mem handle lifetime to the OrtValue with the shared mem data + // TODO can we avoid keeping mem handles in SharedContext? SharedMemHandles shared_mem_handles_; }; diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index ac5ea75453558..2397b128e8163 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -2,12 +2,19 @@ // Licensed under the MIT License. #include "allocator_adapters.h" +#include "core/framework/data_types.h" +#include "core/framework/error_code_helper.h" #include "core/session/inference_session.h" #include "core/session/ort_env.h" #include "core/session/ort_apis.h" -#include "core/framework/error_code_helper.h" namespace onnxruntime { + +namespace { +constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; +constexpr uint32_t kOrtAllocatorTensorAllocMinVersion = 21; +} // namespace + OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) : i_allocator_(std::move(i_allocator)) { OrtAllocator::version = ORT_API_VERSION; @@ -17,10 +24,17 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti [](OrtAllocator* this_, void* p) { static_cast(this_)->Free(p); }; OrtAllocator::Info = [](const OrtAllocator* this_) { return static_cast(this_)->Info(); }; - if (OrtAllocator::version >= 18) { + if (OrtAllocator::version >= kOrtAllocatorReserveMinVersion) { OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; } + if (OrtAllocator::version >= kOrtAllocatorTensorAllocMinVersion) { + OrtAllocator::TensorAlloc = + [](OrtAllocator* this_, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type) { + return static_cast(this_)->TensorAlloc(shape, shape_len, + element_data_type); + }; + } } void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) { @@ -31,6 +45,13 @@ void* OrtAllocatorImplWrappingIAllocator::Reserve(size_t size) { return i_allocator_->Reserve(size); } +void* OrtAllocatorImplWrappingIAllocator::TensorAlloc(const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType onnx_element_data_type) { + const auto tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_data_type); + const TensorShape tensor_shape(gsl::span{shape, shape_len}); + return i_allocator_->TensorAlloc(tensor_type->GetElementType(), tensor_shape); +} + void OrtAllocatorImplWrappingIAllocator::Free(void* p) { i_allocator_->Free(p); } @@ -51,13 +72,25 @@ void* IAllocatorImplWrappingOrtAllocator::Alloc(size_t size) { } void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { - if (ort_allocator_->version >= 18 && ort_allocator_->Reserve) { + if (ort_allocator_->version >= kOrtAllocatorReserveMinVersion && ort_allocator_->Reserve) { return ort_allocator_->Reserve(ort_allocator_, size); } return ort_allocator_->Alloc(ort_allocator_, size); } +void* IAllocatorImplWrappingOrtAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { + if (ort_allocator_->version >= kOrtAllocatorTensorAllocMinVersion && ort_allocator_->TensorAlloc) { + const auto shape_span = shape.GetDims(); + ORT_ENFORCE(element_data_type->IsPrimitiveDataType()); + const auto onnx_element_data_type = + static_cast(element_data_type->AsPrimitiveDataType()->GetDataType()); + return ort_allocator_->TensorAlloc(ort_allocator_, shape_span.data(), shape_span.size(), onnx_element_data_type); + } + + return IAllocator::TensorAlloc(element_data_type, shape); +} + void IAllocatorImplWrappingOrtAllocator::Free(void* p) { return ort_allocator_->Free(ort_allocator_, p); } diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index 48f4ea03118c8..a8f3b6460574f 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -29,6 +29,8 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { const OrtMemoryInfo* Info() const; void* Reserve(size_t size); + void* TensorAlloc(const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator); onnxruntime::AllocatorPtr GetWrappedIAllocator(); @@ -45,6 +47,7 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator { void* Alloc(size_t size) override; void* Reserve(size_t size) override; + void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; From 18e2780b4f6055090f9f6b2d1d81adc8ada1efbe Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:58:05 -0800 Subject: [PATCH 09/53] hook up some test code --- .../optimizer/graph_transform_test_builder.h | 18 ++++++--- .../test/providers/qnn/max_min_op_test.cc | 37 +++++++++++++++++-- .../test/providers/qnn/qnn_test_utils.h | 33 ++++++++++------- 3 files changed, 64 insertions(+), 24 deletions(-) diff --git a/onnxruntime/test/optimizer/graph_transform_test_builder.h b/onnxruntime/test/optimizer/graph_transform_test_builder.h index f641c597acf07..88ad49329f929 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_builder.h +++ b/onnxruntime/test/optimizer/graph_transform_test_builder.h @@ -82,7 +82,11 @@ class ModelTestBuilder { } template - NodeArg* MakeInput(const std::vector& shape, const std::vector& data) { + NodeArg* MakeInput(const std::vector& shape, const std::vector& data, + AllocatorPtr allocator = nullptr) { + if (!allocator) { + allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + } ONNX_NAMESPACE::TypeProto type_proto; type_proto.mutable_tensor_type()->set_elem_type(utils::ToTensorProtoElementType()); @@ -93,7 +97,7 @@ class ModelTestBuilder { } OrtValue input_value; - CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + CreateMLValue(allocator, shape, data, &input_value); @@ -104,17 +108,19 @@ class ModelTestBuilder { } template - NodeArg* MakeInput(const std::vector& shape, T min, T max) { - return MakeInput(shape, rand_gen_.Uniform(shape, min, max)); + NodeArg* MakeInput(const std::vector& shape, T min, T max, + AllocatorPtr allocator = nullptr) { + return MakeInput(shape, rand_gen_.Uniform(shape, min, max), allocator); } - NodeArg* MakeInputBool(const std::vector& shape) { + NodeArg* MakeInputBool(const std::vector& shape, + AllocatorPtr allocator = nullptr) { std::vector data_uint8 = rand_gen_.Uniform(shape, 0, 1); std::vector data; for (uint8_t x : data_uint8) { data.push_back(x != 0); } - return MakeInput(shape, data); + return MakeInput(shape, data, allocator); } template diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 3deff121f3c72..6e0f9f191cf47 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -39,20 +39,30 @@ template static void RunQDQMinOrMaxOpTest(const std::string& op_type, const std::vector>& input_defs, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13) { + int opset = 13, + AllocatorPtr io_allocator = nullptr, + const ProviderOptions& extra_provider_options = {}) { ProviderOptions provider_options; + if (!extra_provider_options.empty()) { + provider_options.insert(extra_provider_options.begin(), extra_provider_options.end()); + } + #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain, + io_allocator), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain, /* use_contrib_qdq*/ false, + io_allocator), // QDQ model provider_options, opset, - expected_ep_assignment); + expected_ep_assignment, + {}, + logging::Severity::kVERBOSE); } // @@ -128,6 +138,25 @@ TEST_F(QnnHTPBackendTests, Max_2Inputs) { ExpectedEPNodeAssignment::All, 13); } +// Test accuracy of 8-bit Q/DQ Min with 2 inputs on HTP backend. +TEST_F(QnnHTPBackendTests, Min_2Inputs_HtpSharedMemoryAllocator) { + ProviderOptions qnn_ep_options{ + {"enable_htp_shared_memory_allocator", "1"}, + {"backend_path", "libQnnHtp.so"}, + }; + + AllocatorPtr htp_shared_memory_allocator = + QnnExecutionProviderWithOptions(qnn_ep_options)->CreatePreferredAllocators()[0]; + + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQMinOrMaxOpTest("Min", + {TestInputDef({1, 3, 4, 4}, false, input_data), + TestInputDef({1, 3, 4, 4}, false, input_data)}, + ExpectedEPNodeAssignment::All, 13, + htp_shared_memory_allocator, + qnn_ep_options); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index a8670252ff9e0..6c8ae5392bee4 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -904,7 +904,8 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, * \return A pointer to the new input. */ template -inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def) { +inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, + AllocatorPtr io_allocator = nullptr) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -915,7 +916,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, raw_data); } else { - input = builder.MakeInput(shape, raw_data); + input = builder.MakeInput(shape, raw_data, io_allocator); } } else { // Random data const auto& rand_info = input_def.GetRandomDataInfo(); @@ -923,7 +924,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, rand_info.min, rand_info.max); } else { - input = builder.MakeInput(shape, rand_info.min, rand_info.max); + input = builder.MakeInput(shape, rand_info.min, rand_info.max, io_allocator); } } @@ -931,7 +932,8 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& } template <> -inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def) { +inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, + AllocatorPtr io_allocator) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -942,13 +944,13 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef(shape, raw_data); + input = builder.MakeInput(shape, raw_data, io_allocator); } } else { // Random data if (is_initializer) { input = builder.MakeRandInitializerBool(shape); } else { - input = builder.MakeInputBool(shape); + input = builder.MakeInputBool(shape, io_allocator); } } @@ -980,18 +982,19 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, const std::vector>& input_defs_1, const std::vector>& input_defs_2, const std::vector& attrs, - const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { + const std::string& op_domain = kOnnxDomain, + AllocatorPtr io_allocator = nullptr) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain, io_allocator](ModelTestBuilder& builder) { std::vector op_inputs; op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); for (const auto& input_def : input_defs_1) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, io_allocator); op_inputs.push_back(input); } for (const auto& input_def : input_defs_2) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, io_allocator); op_inputs.push_back(input); } @@ -1021,15 +1024,17 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( const std::vector>& non_quant_input_defs, const std::vector& attrs, const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { + bool use_contrib_qdq = false, + AllocatorPtr io_allocator = nullptr) { return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { + use_contrib_qdq, io_allocator]( + ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); // Create QDQ inputs for (const auto& input_def : quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, io_allocator); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq); @@ -1038,7 +1043,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( // Create non-QDQ inputs for (const auto& input_def : non_quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* input = MakeTestInput(builder, input_def, io_allocator); op_inputs.push_back(input); } From a65bb71c2979c1a742f959ecc42866bb89582af9 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 27 Nov 2024 11:31:15 -0800 Subject: [PATCH 10/53] rename to RpcMemAllocator to HtpSharedMemoryAllocator --- .../core/providers/qnn/builder/qnn_model.cc | 2 +- onnxruntime/core/providers/qnn/qnn_allocator.cc | 14 +++++++------- onnxruntime/core/providers/qnn/qnn_allocator.h | 6 +++--- .../core/providers/qnn/qnn_execution_provider.cc | 6 +++--- .../core/providers/qnn/qnn_execution_provider.h | 4 ++-- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 07b01bca3522e..d991759f1a731 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -193,7 +193,7 @@ static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, void* ort_value_data, uint32_t ort_value_data_size, Qnn_Tensor_t& qnn_tensor) { // either set qnn_tensor memHandle or clientBuf - const bool uses_shared_memory = ort_value_memory_info == RpcMemAllocator::MemoryInfo(); + const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::MemoryInfo(); if (!uses_shared_memory) { LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory."; diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 5389af1eb1385..6798de0d5527b 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -78,14 +78,14 @@ RpcMemUniquePtr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, const Rpc } // namespace -OrtMemoryInfo RpcMemAllocator::MemoryInfo() { +OrtMemoryInfo HtpSharedMemoryAllocator::MemoryInfo() { return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, /* id */ 0, OrtMemTypeDefault}; } -RpcMemAllocator::RpcMemAllocator(std::shared_ptr rpcmem_lib, - std::shared_ptr qnn_backend_manager) +HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + std::shared_ptr qnn_backend_manager) : IAllocator{MemoryInfo()}, rpcmem_lib_{std::move(rpcmem_lib)}, qnn_backend_manager_{std::move(qnn_backend_manager)} { @@ -93,13 +93,13 @@ RpcMemAllocator::RpcMemAllocator(std::shared_ptr rpcmem_lib, ORT_ENFORCE(qnn_backend_manager_ != nullptr); } -void* RpcMemAllocator::Alloc(size_t /* size */) { +void* HtpSharedMemoryAllocator::Alloc(size_t /* size */) { LOGS_DEFAULT(ERROR) << "hey this ain't right"; std::exit(1); - ORT_THROW("RpcMemAllocator::Alloc() is not implemented. Use RpcMemAllocator::TensorAlloc() instead."); + ORT_THROW("HtpSharedMemoryAllocator::Alloc() is not implemented. Use HtpSharedMemoryAllocator::TensorAlloc() instead."); } -void* RpcMemAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { +void* HtpSharedMemoryAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { const auto size_in_bytes = Tensor::CalculateTensorStorageSize(element_data_type, shape); if (size_in_bytes == 0) { @@ -134,7 +134,7 @@ void* RpcMemAllocator::TensorAlloc(MLDataType element_data_type, const TensorSha return shared_memory.release(); } -void RpcMemAllocator::Free(void* p) { +void HtpSharedMemoryAllocator::Free(void* p) { // take ownership of shared memory and free at end of scope auto shared_memory = WrapSharedMemoryWithUniquePtr(p, rpcmem_lib_->Api()); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index 6866189c5a084..0e80df5c2a175 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -15,13 +15,13 @@ namespace onnxruntime::qnn { class QnnBackendManager; class RpcMemLibrary; -class RpcMemAllocator : public IAllocator { +class HtpSharedMemoryAllocator : public IAllocator { public: // Gets the single OrtMemoryInfo value that is associated with this allocator type. static OrtMemoryInfo MemoryInfo(); - RpcMemAllocator(std::shared_ptr rpcmem_lib, - std::shared_ptr qnn_backend_manager); + HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + std::shared_ptr qnn_backend_manager); void* Alloc(size_t size) override; void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index bdccc64a3b8dd..f8af1752bbc62 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -392,7 +392,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { // Initialize rpcmem_library_. - // This is necessary for RpcMemAllocator to function and also indicates that the allocator is available. + // This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available. rpcmem_library_ = std::make_shared(); } @@ -1167,10 +1167,10 @@ std::vector QNNExecutionProvider::CreatePreferredAllocators() { std::vector allocators{}; if (IsRpcMemAllocatorAvailable()) { - LOGS_DEFAULT(INFO) << "Creating RpcMemAllocator."; + LOGS_DEFAULT(INFO) << "Creating HtpSharedMemoryAllocator."; AllocatorFactory rpcmem_allocator_factory = [this](OrtDevice::DeviceId) { - return std::make_unique(rpcmem_library_, qnn_backend_manager_); + return std::make_unique(rpcmem_library_, qnn_backend_manager_); }; AllocatorCreationInfo rpcmem_allocator_creation_info{rpcmem_allocator_factory, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 18fdef9a7e3f5..bb6bae688d669 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -79,7 +79,7 @@ class QNNExecutionProvider : public IExecutionProvider { private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; - // This is potentially shared with RpcMemAllocator which may be returned by CreatePreferredAllocators(). + // This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators(). std::shared_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; @@ -102,7 +102,7 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ModelSettings model_settings_ = {}; // Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available. - // This is potentially shared with RpcMemAllocator which may be returned by CreatePreferredAllocators(). + // This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators(). std::shared_ptr rpcmem_library_ = nullptr; class PerThreadContext final { From f179a0d86d97e7c1b5a92b319dcf81bdb5a4a899 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 2 Dec 2024 18:45:51 -0800 Subject: [PATCH 11/53] remove onnx protobuf dependency from allocator.h, add shared provider declarations and definitions for IAllocator::TensorAlloc(). --- include/onnxruntime/core/framework/allocator.h | 9 +++++++-- onnxruntime/core/framework/allocator.cc | 1 + .../providers/shared_library/provider_bridge_provider.cc | 1 + .../core/providers/shared_library/provider_interfaces.h | 1 + onnxruntime/core/session/provider_bridge_ort.cc | 1 + 5 files changed, 11 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 5aaa62f19408b..7eebd8fb6e23f 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -7,12 +7,10 @@ #include "core/common/common.h" #include "core/framework/allocator_stats.h" -#include "core/framework/data_types.h" // some enums are defined in session/onnxruntime_c_api.h but used in ortdevice.h/ortmemory.h #include "core/session/onnxruntime_c_api.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" -#include "core/framework/tensor_shape.h" // This configures the arena based allocator used by ORT // See docs/C_API.md for details on what these mean and how to choose these values @@ -71,6 +69,12 @@ void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_res template using IAllocatorUniquePtr = std::unique_ptr>; +// Note: Re-declare these from core/framework/data_types.h to avoid including the ONNX protobuf header. +class DataTypeImpl; +using MLDataType = const DataTypeImpl*; + +class TensorShape; + class IAllocator { public: IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {} @@ -269,6 +273,7 @@ class CPUAllocator : public IAllocator { CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} void* Alloc(size_t size) override; + void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; }; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index a7eb82148fc49..cd63ad98ab10b 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -13,6 +13,7 @@ #include #endif +#include "core/framework/data_types.h" #include "core/framework/bfc_arena.h" #include "core/framework/tensor.h" diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d3b12f9728135..00efc10a1fbc5 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -114,6 +114,7 @@ struct OnUnload { } g_on_unload; +void* IAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { return g_host->IAllocator__TensorAlloc(this, element_data_type, shape); } void* CPUAllocator::Alloc(size_t size) { return g_host->CPUAllocator__Alloc(this, size); } void CPUAllocator::Free(void* p) { g_host->CPUAllocator__Free(this, p); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f9f2bb69a9d1a..ae75ad7d55131 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -158,6 +158,7 @@ struct ProviderHost { virtual std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) = 0; + virtual void* IAllocator__TensorAlloc(IAllocator* p, MLDataType element_data_type, const TensorShape& shape) = 0; virtual void* CPUAllocator__Alloc(CPUAllocator* p, size_t size) = 0; virtual void CPUAllocator__Free(CPUAllocator* p, void* allocation) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d55fd34d5a8f2..eb8ad28f0a146 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -231,6 +231,7 @@ struct ProviderHostImpl : ProviderHost { AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) override { return onnxruntime::CreateAllocator(info); } std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) override { return std::make_unique(memory_info); }; + void* IAllocator__TensorAlloc(IAllocator* p, MLDataType element_data_type, const TensorShape& shape) override { return p->IAllocator::TensorAlloc(element_data_type, shape); } void* CPUAllocator__Alloc(CPUAllocator* p, size_t size) override { return p->CPUAllocator::Alloc(size); } void CPUAllocator__Free(CPUAllocator* p, void* allocation) override { return p->CPUAllocator::Free(allocation); } From 7645ef458a51d3c2b6f5f7227be80c2d323012bb Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 5 Dec 2024 13:12:17 -0800 Subject: [PATCH 12/53] remove unused CPUAllocator::TensorAlloc declaration --- include/onnxruntime/core/framework/allocator.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 7eebd8fb6e23f..449baa4383b6d 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -273,7 +273,6 @@ class CPUAllocator : public IAllocator { CPUAllocator() : IAllocator(OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator)) {} void* Alloc(size_t size) override; - void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; }; From 104373282d292a89b580d9737eba971903d313d7 Mon Sep 17 00:00:00 2001 From: Baiju Meswani Date: Thu, 5 Dec 2024 15:49:02 -0800 Subject: [PATCH 13/53] Check for nullptr when trying to free --- onnxruntime/core/providers/qnn/qnn_allocator.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 6798de0d5527b..cf134b81e7a60 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -135,6 +135,10 @@ void* HtpSharedMemoryAllocator::TensorAlloc(MLDataType element_data_type, const } void HtpSharedMemoryAllocator::Free(void* p) { + if (!p) { + return; + } + // take ownership of shared memory and free at end of scope auto shared_memory = WrapSharedMemoryWithUniquePtr(p, rpcmem_lib_->Api()); From 022f4bcb2967103f9903f025ddba985a9b1441f2 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 9 Dec 2024 16:53:15 -0800 Subject: [PATCH 14/53] move mem handle management to QNN backend manager --- .../qnn/builder/qnn_backend_manager.cc | 72 ++++- .../qnn/builder/qnn_backend_manager.h | 15 + .../builder/qnn_context_mem_handle_manager.cc | 125 ++++++++ .../builder/qnn_context_mem_handle_manager.h | 59 ++++ .../core/providers/qnn/builder/qnn_def.h | 9 +- .../core/providers/qnn/builder/qnn_model.cc | 42 +-- .../core/providers/qnn/builder/qnn_model.h | 4 - .../qnn/builder/qnn_model_wrapper.cc | 10 +- .../providers/qnn/builder/qnn_model_wrapper.h | 6 +- .../core/providers/qnn/builder/qnn_utils.cc | 12 +- .../core/providers/qnn/builder/qnn_utils.h | 5 + .../core/providers/qnn/qnn_allocator.cc | 294 ++++++++++++------ .../core/providers/qnn/qnn_allocator.h | 65 +++- .../providers/qnn/qnn_execution_provider.cc | 6 +- .../providers/qnn/qnn_execution_provider.h | 5 +- 15 files changed, 575 insertions(+), 154 deletions(-) create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc create mode 100644 onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index f37c91aa0413b..bc917684e62ce 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -7,17 +7,18 @@ #include #include #include "QnnOpDef.h" -#include "HTP/QnnHtpPerfInfrastructure.h" #include "CPU/QnnCpuCommon.h" // TODO: not exist for Windows yet // #include "GPU/QnnGpuCommon.h" #include "DSP/QnnDspCommon.h" #include "HTP/QnnHtpCommon.h" #include "HTP/QnnHtpContext.h" +#include "HTP/QnnHtpPerfInfrastructure.h" #include "Saver/QnnSaver.h" #include #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" +#include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" @@ -549,10 +550,11 @@ Status QnnBackendManager::CreateContext() { device_handle_, context_configs, &context); - contexts_.push_back(context); ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result)); + ORT_RETURN_IF_ERROR(AddQnnContext(context)); // TODO use RAII type for context handle? + context_created_ = true; return Status::OK(); } @@ -562,6 +564,8 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } + ORT_RETURN_IF_ERROR(ReleaseQnnContextMemHandles()); + bool failed = false; for (auto context : contexts_) { Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr); @@ -674,7 +678,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t &context, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); - contexts_.push_back(context); + ORT_RETURN_IF_ERROR(AddQnnContext(context)); if (1 == graph_count) { // in case the EPContext node is generated from script // the graph name from the context binary may not match the EPContext node name @@ -1564,5 +1568,67 @@ void* QnnBackendManager::LibFunction(void* handle, const char* symbol, std::stri #endif } +Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { + ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set."); + + auto mem_handle_manager = std::make_unique(GetQnnInterface(), context, *logger_); + auto mem_handle_record = ContextMemHandleRecord{std::move(mem_handle_manager), {}}; + const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_record)).second; + ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", context); + + contexts_.push_back(context); + + return Status::OK(); +} + +Status QnnBackendManager::ReleaseQnnContextMemHandles() { + // remove outstanding allocation clean up callbacks + for (auto& [context_handle, context_mem_handle_record] : context_mem_handles_) { + for (const auto [shared_memory_address, idx] : + context_mem_handle_record.outstanding_allocation_clean_up_callbacks) { + ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::RemoveAllocationCleanUp(shared_memory_address, idx, + /* allocation_clean_up */ nullptr)); + } + } + + context_mem_handles_.clear(); + + return Status::OK(); +} + +Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, + const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& mem_handle) { + const auto context_mem_handles_it = context_mem_handles_.find(context); + ORT_RETURN_IF_NOT(context_mem_handles_it != context_mem_handles_.end(), "QNN context not found: ", context); + + auto& context_mem_handle_record = context_mem_handles_it->second; + auto& context_mem_handle_manager = *context_mem_handle_record.mem_handle_manager; + bool did_register{}; + ORT_RETURN_IF_ERROR(context_mem_handle_manager.GetOrRegister(shared_memory_address, qnn_tensor, + mem_handle, did_register)); + + if (did_register) { + HtpSharedMemoryAllocator::AllocationCleanUpFn allocation_clean_up = + [&logger = *logger_, &context_mem_handle_manager](void* shared_memory_address) { + auto unregister_status = context_mem_handle_manager.Unregister(shared_memory_address); + if (!unregister_status.IsOK()) { + LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " + << shared_memory_address << ", error: " << unregister_status.ErrorMessage(); + } + }; + + size_t allocation_clean_up_idx{}; + ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address, + std::move(allocation_clean_up), + allocation_clean_up_idx)); + + context_mem_handle_record.outstanding_allocation_clean_up_callbacks.emplace_back(shared_memory_address, + allocation_clean_up_idx); + } + + return Status::OK(); +} + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 43007d4a5c244..96e4d2d667569 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -24,6 +24,7 @@ #include "core/common/status.h" #include "core/common/logging/logging.h" #include "core/common/path_string.h" +#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -163,6 +164,10 @@ class QnnBackendManager { Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); + Status GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, + const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& mem_handle); + private: void* LoadLib(const char* file_name, int flags, std::string& error_msg); @@ -233,6 +238,9 @@ class QnnBackendManager { const char* eventIdentifier); #endif + Status AddQnnContext(Qnn_ContextHandle_t context); + Status ReleaseQnnContextMemHandles(); + private: const std::string backend_path_; std::mutex logger_mutex_; @@ -246,6 +254,13 @@ class QnnBackendManager { Qnn_LogHandle_t log_handle_ = nullptr; Qnn_DeviceHandle_t device_handle_ = nullptr; std::vector contexts_; + + struct ContextMemHandleRecord { + std::unique_ptr mem_handle_manager; + InlinedVector> outstanding_allocation_clean_up_callbacks; + }; + + std::unordered_map context_mem_handles_; ProfilingLevel profiling_level_etw_; ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc new file mode 100644 index 0000000000000..de77b309c0105 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" + +#include "HTP/QnnHtpMem.h" + +#include "core/common/common.h" +#include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/qnn_allocator.h" + +namespace onnxruntime::qnn { + +QnnContextMemHandleManager::QnnContextMemHandleManager(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ContextHandle_t context, + const logging::Logger& logger) + : qnn_interface_{qnn_interface}, + context_{context}, + logger_{logger} { +} + +QnnContextMemHandleManager::~QnnContextMemHandleManager() { + Clear(); +} + +Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& qnn_mem_handle, bool& did_register) { + const auto qnn_tensor_rank = GetQnnTensorRank(qnn_tensor); + auto* const qnn_tensor_dims = GetQnnTensorDims(qnn_tensor); + const auto qnn_tensor_data_type = GetQnnTensorDataType(qnn_tensor); + + const size_t qnn_tensor_data_size = + utils::GetQnnTensorDataSize(gsl::span{qnn_tensor_dims, size_t{qnn_tensor_rank}}, qnn_tensor_data_type); + + { + std::scoped_lock g{mem_handles_mutex_}; + + // find existing mem handle + if (const auto mem_handles_it = mem_handles_.find(shared_memory_address); + mem_handles_it != mem_handles_.end()) { + const auto& mem_handle_record = mem_handles_it->second; + + // check that actual tensor size is less than or equal to registered tensor size + ORT_RETURN_IF_NOT(qnn_tensor_data_size <= mem_handle_record.registered_tensor_data_size, + "Actual tensor data size (", qnn_tensor_data_size, + ") is larger than registered tensor data size (", mem_handle_record.registered_tensor_data_size, + ")."); + + qnn_mem_handle = mem_handle_record.mem_handle.get(); + did_register = false; + return Status::OK(); + } + + // register a new mem handle + HtpSharedMemoryAllocator::SharedMemoryInfo shared_memory_info{}; + ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(shared_memory_address, + shared_memory_info)); + + Qnn_MemDescriptor_t mem_descriptor{}; + mem_descriptor.memShape.dimSize = qnn_tensor_dims; + mem_descriptor.memShape.numDim = qnn_tensor_rank; + mem_descriptor.memShape.shapeConfig = nullptr; + mem_descriptor.dataType = qnn_tensor_data_type; + mem_descriptor.memType = QNN_MEM_TYPE_CUSTOM; + + QnnMemHtp_Descriptor_t htp_mem_descriptor{}; + htp_mem_descriptor.type = QNN_HTP_MEM_SHARED_BUFFER; + htp_mem_descriptor.size = shared_memory_info.total_size; + htp_mem_descriptor.sharedBufferConfig.fd = shared_memory_info.fd; + htp_mem_descriptor.sharedBufferConfig.offset = shared_memory_info.offset; + + mem_descriptor.customInfo = &htp_mem_descriptor; + + LOGS(logger_, VERBOSE) << "Registering QNN mem handle for context: " << context_ + << ", shared memory (address: " << shared_memory_address + << ", offset: " << shared_memory_info.offset + << ", fd: " << shared_memory_info.fd + << ")"; + + Qnn_MemHandle_t raw_mem_handle{}; + const auto register_result = qnn_interface_.memRegister(context_, &mem_descriptor, 1, &raw_mem_handle); + ORT_RETURN_IF_NOT(register_result == QNN_SUCCESS, + "qnn_interface.memRegister() failed: ", register_result); // TODO get error message + + LOGS(logger_, VERBOSE) << "Registered QNN mem handle. mem_handle: " << raw_mem_handle; + + const auto unregister_mem_handle = [this](Qnn_MemHandle_t raw_mem_handle) { + LOGS(logger_, VERBOSE) << "Unregistering QNN mem handle. mem_handle: " << raw_mem_handle; + + const auto unregister_result = qnn_interface_.memDeRegister(&raw_mem_handle, 1); + if (unregister_result != QNN_SUCCESS) { + LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: " << unregister_result; + return; + } + }; + + UniqueQnnMemHandle mem_handle(raw_mem_handle, unregister_mem_handle); + MemHandleRecord mem_handle_record{qnn_tensor_data_size, std::move(mem_handle)}; + mem_handles_.emplace(shared_memory_address, std::move(mem_handle_record)); + + qnn_mem_handle = raw_mem_handle; + did_register = true; + return Status::OK(); + } +} + +Status QnnContextMemHandleManager::Unregister(void* shared_memory_address) { + std::scoped_lock g{mem_handles_mutex_}; + + auto mem_handles_it = mem_handles_.find(shared_memory_address); + ORT_RETURN_IF_NOT(mem_handles_it != mem_handles_.end(), + "No mem handle found for address (", shared_memory_address, ")."); + + mem_handles_.erase(mem_handles_it); + + return Status::OK(); +} + +void QnnContextMemHandleManager::Clear() { + std::scoped_lock g{mem_handles_mutex_}; + mem_handles_.clear(); +} + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h new file mode 100644 index 0000000000000..acb33d7175061 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "QnnInterface.h" + +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" +#include "core/common/status.h" + +namespace onnxruntime::qnn { + +// This class manages QNN mem handles (Qnn_MemHandle_t) associated with a QNN context (Qnn_ContextHandle_t). +// In particular, it handles the registration and deregistration of mem handles. +// The associated QNN context is expected to be in scope for the lifetime of the QnnContextMemHandleManager. +class QnnContextMemHandleManager { + public: + QnnContextMemHandleManager(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ContextHandle_t qnn_context, + const logging::Logger& logger); + + ~QnnContextMemHandleManager(); + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnContextMemHandleManager); + + Status GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, + Qnn_MemHandle_t& qnn_mem_handle, bool& did_register); + + Status Unregister(void* shared_memory_address); + + void Clear(); + + private: + const QNN_INTERFACE_VER_TYPE& qnn_interface_; + Qnn_ContextHandle_t context_; + const logging::Logger& logger_; + + // assume Qnn_MemHandle_t is a pointer and able to be wrapped with std::unique_ptr + static_assert(std::is_pointer_v); + + using UniqueQnnMemHandle = + std::unique_ptr, std::function>; + + struct MemHandleRecord { + size_t registered_tensor_data_size; + UniqueQnnMemHandle mem_handle; + }; + + // shared memory address -> associated mem handle record + InlinedHashMap mem_handles_; + std::mutex mem_handles_mutex_; // synchronize access to mem_handles_ +}; + +} // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index e8e5453afa48b..b3b6b392d7857 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -467,11 +467,13 @@ class QnnOpProperty { class GraphInfo { public: - GraphInfo(const Qnn_GraphHandle_t graph, + GraphInfo(Qnn_GraphHandle_t graph, const std::string& name, + Qnn_ContextHandle_t graph_context, std::vector&& input_tensors, std::vector&& output_tensors) : graph_name_(name), graph_(graph), + graph_context_(graph_context), input_tensors_(std::move(input_tensors)), output_tensors_(std::move(output_tensors)) { } @@ -481,12 +483,15 @@ class GraphInfo { const std::string& Name() const { return graph_name_; } const std::vector& InputTensors() const { return input_tensors_; } const std::vector& OutputTensors() const { return output_tensors_; } - const Qnn_GraphHandle_t& Graph() const { return graph_; } + Qnn_GraphHandle_t Graph() const { return graph_; } + Qnn_ContextHandle_t GraphContext() const { return graph_context_; } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphInfo); private: std::string graph_name_; Qnn_GraphHandle_t graph_; + // QNN context that holds the QNN graph referenced by `graph_` + Qnn_ContextHandle_t graph_context_; std::vector input_tensors_; std::vector output_tensors_; }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index d991759f1a731..23a9f515aec0a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -20,17 +20,14 @@ namespace onnxruntime { namespace qnn { -bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger) { +bool QnnModel::GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& /* logger */) { bool rt = true; graph_info_ = std::make_unique(model_wrapper.GetQnnGraph(), model_wrapper.GetQnnGraphName(), + model_wrapper.GetQnnGraphContext(), std::move(model_wrapper.GetGraphInputTensorWrappers()), std::move(model_wrapper.GetGraphOutputTensorWrappers())); - if (graph_info_ == nullptr) { - LOGS(logger, ERROR) << "GetGraphInfoFromModel() failed to allocate GraphInfo."; - return false; - } return rt; } @@ -189,11 +186,13 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { } static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, + QnnBackendManager& qnn_backend_manager, const OrtMemoryInfo& ort_value_memory_info, void* ort_value_data, uint32_t ort_value_data_size, + Qnn_ContextHandle_t qnn_context, Qnn_Tensor_t& qnn_tensor) { // either set qnn_tensor memHandle or clientBuf - const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::MemoryInfo(); + const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::AssociatedMemoryInfo(); if (!uses_shared_memory) { LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t clientBuf to ORT tensor memory."; @@ -201,7 +200,9 @@ static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, SetQnnTensorClientBuf(qnn_tensor, ort_value_data, ort_value_data_size); } else { LOGS(logger, VERBOSE) << "Setting Qnn_Tensor_t memHandle to ORT tensor shared memory."; - const Qnn_MemHandle_t qnn_mem_handle = SharedContext::GetInstance().GetSharedMemHandles().Get(ort_value_data); + Qnn_MemHandle_t qnn_mem_handle{}; + ORT_RETURN_IF_ERROR(qnn_backend_manager.GetOrRegisterContextMemHandle(qnn_context, ort_value_data, qnn_tensor, + qnn_mem_handle)); SetQnnTensorMemType(qnn_tensor, QNN_TENSORMEMTYPE_MEMHANDLE); SetQnnTensorMemHandle(qnn_tensor, qnn_mem_handle); } @@ -243,8 +244,10 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( logger, + *qnn_backend_manager_, *static_cast(ort_input_tensor.GetTensorMemoryInfo()), const_cast(ort_input_tensor.GetTensorRawData()), qnn_input_info.tensor_byte_size, + graph_info_->GraphContext(), qnn_inputs.back())); } @@ -267,8 +270,10 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( logger, + *qnn_backend_manager_, *static_cast(ort_output_tensor.GetTensorMemoryInfo()), const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, + graph_info_->GraphContext(), qnn_outputs.back())); } @@ -308,20 +313,6 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, return Status::OK(); } -Status QnnModel::GetQnnTensorDataLength(const std::vector& dims, - Qnn_DataType_t data_type, - size_t& data_length) const { - ORT_RETURN_IF(dims.empty(), "Tensor dimensions is nullptr"); - - data_length = utils::GetElementSizeByType(data_type); - - for (size_t r = 0; r < dims.size(); r++) { - data_length *= dims[r]; - } - - return Status::OK(); -} - // Setup information for Qnn inputs/outputs used during execution. Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, const std::vector& tensor_wrappers, @@ -331,11 +322,8 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_infos.resize(tensor_count); for (auto& tensor_wrapper : tensor_wrappers) { - size_t length = 0; - using namespace qnn::utils; - ORT_RETURN_IF_ERROR(GetQnnTensorDataLength(tensor_wrapper.GetTensorDims(), - tensor_wrapper.GetTensorDataType(), - length)); + const size_t length = utils::GetQnnTensorDataSize(tensor_wrapper.GetTensorDims(), + tensor_wrapper.GetTensorDataType()); const auto& tensor_name = tensor_wrapper.GetName(); auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name); auto ort_index = is_input ? GetOrtInputIndex(tensor_name) : qnn_index; @@ -405,9 +393,9 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph graph_info_ = std::make_unique(graph, graph_name, + context, std::move(input_tensor_wrappers), std::move(output_tensor_wrappers)); - ORT_RETURN_IF(graph_info_ == nullptr, "Failed to allocate GraphInfo"); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 85d50eff09d67..2f220e708c50e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -113,10 +113,6 @@ class QnnModel { const std::unordered_map& node_unit_map) const; bool GetGraphInfoFromModel(QnnModelWrapper& model_wrapper, const logging::Logger& logger); - Status GetQnnTensorDataLength(const std::vector& dims, - Qnn_DataType_t data_type, - size_t& data_length) const; - Status SetupTensors(std::vector& tensors, const std::vector& tensor_wrappers, bool is_input = true); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 2c7f3c8b22ddd..c2e3e9516150f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -30,21 +30,23 @@ bool QnnModelWrapper::CreateQnnGraph(const Qnn_ContextHandle_t& context, return false; } if (graph_name.length() == 0) { - LOGS(logger_, ERROR) << "Empty grpah name."; + LOGS(logger_, ERROR) << "Empty graph name."; return false; } - graph_name_ = graph_name; - auto rt = qnn_interface_.graphCreate(context, graph_name_.c_str(), graph_configs, &graph_); + auto rt = qnn_interface_.graphCreate(context, graph_name.c_str(), graph_configs, &graph_); if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) { - rt = qnn_interface_.graphRetrieve(context, graph_name_.c_str(), &graph_); + rt = qnn_interface_.graphRetrieve(context, graph_name.c_str(), &graph_); if (rt != QNN_GRAPH_NO_ERROR || graph_ == nullptr) { LOGS(logger_, ERROR) << "Failed to create Qnn graph: " << graph_name; return false; } } + LOGS(logger_, VERBOSE) << "Created Qnn graph: " << graph_name; + graph_name_ = graph_name; + graph_context_ = context; return true; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f3e52050e79e0..6e165a5f95afe 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -93,10 +93,12 @@ class QnnModelWrapper { bool ComposeQnnGraph(); - Qnn_GraphHandle_t GetQnnGraph() { return graph_; } + Qnn_GraphHandle_t GetQnnGraph() const { return graph_; } std::string GetQnnGraphName() const { return graph_name_; } + Qnn_ContextHandle_t GetQnnGraphContext() const { return graph_context_; } + // Move input tensor wrappers to GraphInfo, QnnModelWrapper end of live std::vector&& GetGraphInputTensorWrappers() { GetGraphInputOutputTensorWrapper(model_input_names_, model_input_tensor_wrappers_); @@ -270,6 +272,8 @@ class QnnModelWrapper { const Qnn_BackendHandle_t& backend_handle_; Qnn_GraphHandle_t graph_ = nullptr; std::string graph_name_ = ""; + // QNN context that holds the QNN graph referenced by `graph_` + Qnn_ContextHandle_t graph_context_ = nullptr; std::vector model_input_names_; std::vector model_output_names_; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 8d2cb5bdb6da0..39b18ccc55fb7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1,15 +1,17 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/providers/qnn/builder/qnn_utils.h" + #include +#include #include #include #include -#include #include "core/common/common.h" +#include "core/common/safeint.h" #include "core/framework/data_types.h" -#include "qnn_utils.h" #include "core/providers/qnn/builder/qnn_def.h" namespace onnxruntime { @@ -63,6 +65,12 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { return pos->second; } +size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_type) { + ORT_ENFORCE(!shape.empty(), "Empty shape not allowed."); // TODO can we just treat empty shape as a scalar? + SafeInt data_length = GetElementSizeByType(element_type); + return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{}); +} + std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { switch (scalar.dataType) { case QNN_DATATYPE_INT_8: diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index aa4a27460563f..ac299706b8588 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -8,6 +8,8 @@ #include #include +#include + #include "QnnTypes.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/node_unit.h" @@ -22,6 +24,9 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type); size_t GetElementSizeByType(ONNXTensorElementDataType elem_type); +// Gets tensor data size in bytes. +size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_data_type); + // TODO: make these work with Wrappers? std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param); std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor); diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index cf134b81e7a60..d06c4b95584e4 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -3,150 +3,256 @@ #include "core/providers/qnn/qnn_allocator.h" +#include +#include #include #include -#include - #include "core/common/common.h" -#include "core/common/logging/logging.h" -#include "core/common/inlined_containers.h" -#include "core/common/narrow.h" -#include "core/framework/tensor.h" -#include "core/providers/qnn/builder/qnn_utils.h" -#include "core/providers/qnn/shared_context.h" // for shared mem handle access +#include "core/mlas/inc/mlas.h" // for MlasGetPreferredBufferAlignment() namespace onnxruntime::qnn { namespace { -Qnn_MemHandle_t RegisterQnnMemHandle(const QNN_INTERFACE_VER_TYPE& qnn_interface, - Qnn_ContextHandle_t qnn_context_handle, - int shared_memory_fd, - MLDataType element_data_type, const TensorShape& shape) { - auto qnn_shape = [shape_span = shape.GetDims()]() { - InlinedVector qnn_shape; - std::transform(shape_span.begin(), shape_span.end(), std::back_inserter(qnn_shape), - [](int64_t dim) { return narrow(dim); }); - return qnn_shape; - }(); +struct AllocationHeader { + static constexpr std::array kAllocationHeaderMarker{'o', 'r', 't', 'a', 'l', 'l', 'o', 'c'}; - const auto qnn_data_type = [element_data_type]() { - Qnn_DataType_t qnn_data_type; - ORT_ENFORCE(element_data_type->IsPrimitiveDataType()); - const auto onnx_data_type = element_data_type->AsPrimitiveDataType()->GetDataType(); - const bool is_quantized = false; // TODO how should we set this? - if (!utils::OnnxDataTypeToQnnDataType(onnx_data_type, qnn_data_type, is_quantized)) { - ORT_THROW("Unable to get QNN data type from ONNX data type: ", onnx_data_type); - } - return qnn_data_type; - }(); + // Marker bytes to verify as a sanity check. + std::array marker; + + // Pointer to the allocating allocator instance. + // Note: A critical assumption here is that the allocating allocator is not destroyed before the allocation is freed. + HtpSharedMemoryAllocator* allocator_ptr; - // set up QNN memory descriptor - Qnn_MemDescriptor_t qnn_mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; - qnn_mem_descriptor.memShape = {narrow(qnn_shape.size()), - qnn_shape.data(), - nullptr}; - qnn_mem_descriptor.dataType = qnn_data_type; - qnn_mem_descriptor.memType = QNN_MEM_TYPE_ION; - qnn_mem_descriptor.ionInfo.fd = shared_memory_fd; - - Qnn_MemHandle_t qnn_mem_handle = nullptr; - const auto register_status = qnn_interface.memRegister(qnn_context_handle, &qnn_mem_descriptor, 1, - &qnn_mem_handle); - // TODO show error message - ORT_ENFORCE(register_status == QNN_SUCCESS, - "qnn_interface.memRegister() failed with error code ", register_status); - - return qnn_mem_handle; -} - -void DeregisterQnnMemHandle(const QNN_INTERFACE_VER_TYPE& qnn_interface, - Qnn_MemHandle_t qnn_mem_handle) { - const auto deregister_status = qnn_interface.memDeRegister(&qnn_mem_handle, 1); - // TODO show error message - if (deregister_status != QNN_SUCCESS) { - LOGS_DEFAULT(ERROR) << "qnn_interface.memDeRegister() failed with error code " << deregister_status; + AllocationHeader(HtpSharedMemoryAllocator* allocator_ptr) + : marker{kAllocationHeaderMarker}, + allocator_ptr{allocator_ptr} { } + + ~AllocationHeader() { + marker.fill('\0'); + allocator_ptr = nullptr; + } +}; + +size_t AllocationAlignment() { + return std::max(alignof(AllocationHeader), MlasGetPreferredBufferAlignment()); +} + +size_t DivRoundUp(size_t a, size_t b) { // TODO is there already a helper function somewhere for this? + return (a + b - 1) / b; +} + +bool IsAligned(const void* address, size_t alignment) { + assert((alignment & alignment - 1) == 0); + return (reinterpret_cast(address) & (alignment - 1)) == 0; +} + +size_t AllocationOffsetFromStartOfHeader() { + const size_t allocation_alignment = AllocationAlignment(); + const size_t offset = DivRoundUp(sizeof(AllocationHeader), allocation_alignment) * allocation_alignment; + return offset; +} + +std::byte* GetAllocationHeaderAddress(void* allocation_address) { + auto* allocation_header_address = reinterpret_cast(allocation_address) - sizeof(AllocationHeader); + return allocation_header_address; } -using RpcMemUniquePtr = std::unique_ptr; +AllocationHeader& ValidateAllocationAddressAndGetHeader(void* allocation_address) { + const size_t allocation_alignment = AllocationAlignment(); + ORT_ENFORCE(IsAligned(allocation_address, allocation_alignment), + "Allocation address (", allocation_address, ") does not have required alignment (", + allocation_alignment, " bytes)."); + + auto* allocation_header = reinterpret_cast(GetAllocationHeaderAddress(allocation_address)); + ORT_ENFORCE(allocation_header->marker == AllocationHeader::kAllocationHeaderMarker, + "AllocationHeader for allocation address (", allocation_address, + ") does not have the expected marker bytes."); + + return *allocation_header; +} -RpcMemUniquePtr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, const RpcMemApi& rpcmem_api) { +std::unique_ptr WrapSharedMemoryWithUniquePtr(void* shared_memory_raw, + const RpcMemApi& rpcmem_api) { return {shared_memory_raw, rpcmem_api.free}; } } // namespace -OrtMemoryInfo HtpSharedMemoryAllocator::MemoryInfo() { +OrtMemoryInfo HtpSharedMemoryAllocator::AssociatedMemoryInfo() { return OrtMemoryInfo{QNN_HTP_SHARED, OrtAllocatorType::OrtDeviceAllocator, OrtDevice{OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, /* device_id */ 0}, /* id */ 0, OrtMemTypeDefault}; } -HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, - std::shared_ptr qnn_backend_manager) - : IAllocator{MemoryInfo()}, - rpcmem_lib_{std::move(rpcmem_lib)}, - qnn_backend_manager_{std::move(qnn_backend_manager)} { +HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib) + : IAllocator{AssociatedMemoryInfo()}, + rpcmem_lib_{std::move(rpcmem_lib)} { ORT_ENFORCE(rpcmem_lib_ != nullptr); - ORT_ENFORCE(qnn_backend_manager_ != nullptr); } -void* HtpSharedMemoryAllocator::Alloc(size_t /* size */) { - LOGS_DEFAULT(ERROR) << "hey this ain't right"; - std::exit(1); - ORT_THROW("HtpSharedMemoryAllocator::Alloc() is not implemented. Use HtpSharedMemoryAllocator::TensorAlloc() instead."); -} - -void* HtpSharedMemoryAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { - const auto size_in_bytes = Tensor::CalculateTensorStorageSize(element_data_type, shape); - - if (size_in_bytes == 0) { - return nullptr; - } +void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { + const size_t allocation_offset = AllocationOffsetFromStartOfHeader(); + const size_t shared_memory_block_size_in_bytes = allocation_offset + requested_size; // rpcmem_alloc() has an int size parameter. make sure we don't overflow. constexpr size_t max_size_in_bytes = std::numeric_limits::max(); - ORT_ENFORCE(size_in_bytes <= max_size_in_bytes, - "Allocation size (", size_in_bytes, ") is larger than maximum allowed (", max_size_in_bytes, ")."); + ORT_ENFORCE(shared_memory_block_size_in_bytes <= max_size_in_bytes, + "Allocation size (", shared_memory_block_size_in_bytes, ") is larger than maximum allowed (", + max_size_in_bytes, ")."); // allocate shared memory void* shared_memory_raw = rpcmem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, - static_cast(size_in_bytes)); + static_cast(shared_memory_block_size_in_bytes)); auto shared_memory = WrapSharedMemoryWithUniquePtr(shared_memory_raw, rpcmem_lib_->Api()); + const size_t allocation_alignment = AllocationAlignment(); + ORT_ENFORCE(IsAligned(shared_memory_raw, allocation_alignment), + "Shared memory address (", shared_memory_raw, ") does not have required alignment (", + allocation_alignment, " bytes)."); + // get shared memory fd const auto shared_memory_fd = rpcmem_lib_->Api().to_fd(shared_memory.get()); ORT_ENFORCE(shared_memory_fd != -1, "rpcmem_to_fd() returned invalid file descriptor."); - // register mem handle - // TODO synchronize calls to qnn_interface.memRegister()? - const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - const auto qnn_context_handle = qnn_backend_manager_->GetQnnContext(); - const auto qnn_mem_handle = RegisterQnnMemHandle(qnn_interface, qnn_context_handle, - shared_memory_fd, element_data_type, shape); + std::byte* allocation_address = reinterpret_cast(shared_memory_raw) + allocation_offset; + + // store allocation record + { + SharedMemoryInfo shared_memory_info{}; + shared_memory_info.fd = shared_memory_fd; + shared_memory_info.offset = allocation_offset; + shared_memory_info.total_size = shared_memory_block_size_in_bytes; + + AllocationRecord allocation_record{}; + allocation_record.shared_memory_info = std::move(shared_memory_info); - // save mem handle. for now, the global SharedContext will do... - SharedContext::GetInstance().GetSharedMemHandles().Add(shared_memory.get(), qnn_mem_handle); + std::scoped_lock g{allocations_mutex_}; + const bool inserted = allocations_.emplace(allocation_address, std::move(allocation_record)).second; + ORT_ENFORCE(inserted, "Allocation info already exists for address (", allocation_address, ")."); + } + + // initialize header + { + std::byte* allocation_header_address = GetAllocationHeaderAddress(allocation_address); + new (allocation_header_address) AllocationHeader(this); + } - return shared_memory.release(); + shared_memory.release(); + return allocation_address; } -void HtpSharedMemoryAllocator::Free(void* p) { - if (!p) { +void HtpSharedMemoryAllocator::Free(void* allocation_address) { + if (allocation_address == nullptr) { return; } + // TODO should we throw exceptions at all from Free()? + + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + ORT_ENFORCE(allocation_header.allocator_ptr == this, + "AllocationHeader points to a different allocator (", allocation_header.allocator_ptr, + ") than this one (", this, ")."); + + const auto allocation_node = [this, allocation_address]() { + std::scoped_lock g{allocations_mutex_}; + return allocations_.extract(allocation_address); + }(); + + ORT_ENFORCE(!allocation_node.empty(), "Failed to get allocation info for address (", allocation_address, ")."); + // take ownership of shared memory and free at end of scope - auto shared_memory = WrapSharedMemoryWithUniquePtr(p, rpcmem_lib_->Api()); + auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + + // destroy header + allocation_header.~AllocationHeader(); + + // clean up allocation record + const auto& allocation_info = allocation_node.mapped(); + for (auto& clean_up_fn : allocation_info.clean_up_fns) { + clean_up_fn(allocation_address); // TODO handle exceptions? + } +} + +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(void* allocation_address, + SharedMemoryInfo& allocation_info) { + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + return allocation_header.allocator_ptr->GetAllocationSharedMemoryInfoForThisAllocator(allocation_address, + allocation_info); +} + +Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* allocation_address, + AllocationCleanUpFn&& allocation_clean_up, + size_t& allocation_clean_up_idx) { + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + return allocation_header.allocator_ptr->AddAllocationCleanUpForThisAllocator(allocation_address, + std::move(allocation_clean_up), + allocation_clean_up_idx); +} + +Status HtpSharedMemoryAllocator::RemoveAllocationCleanUp(void* allocation_address, + size_t allocation_clean_up_idx, + AllocationCleanUpFn* allocation_clean_up) { + auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); + return allocation_header.allocator_ptr->RemoveAllocationCleanUpForThisAllocator(allocation_address, + allocation_clean_up_idx, + allocation_clean_up); +} + +Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, + SharedMemoryInfo& allocation_info) { + std::scoped_lock g{allocations_mutex_}; + const auto allocation_infos_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_infos_it == allocations_.end(), + "Failed to get allocation info for address (", allocation_address, ")."); + + allocation_info = allocation_infos_it->second.shared_memory_info; + return Status::OK(); +} + +Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allocation_address, + AllocationCleanUpFn&& allocation_clean_up, + size_t& allocation_clean_up_idx) { + ORT_RETURN_IF(allocation_clean_up == nullptr, "allocation_clean_up should not be empty."); + + std::scoped_lock g{allocations_mutex_}; + const auto allocation_infos_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_infos_it == allocations_.end(), + "Failed to get allocation info for address (", allocation_address, ")."); + + auto& clean_up_fns = allocation_infos_it->second.clean_up_fns; + clean_up_fns.emplace_back(std::move(allocation_clean_up)); + allocation_clean_up_idx = clean_up_fns.size() - 1; + return Status::OK(); +} + +Status HtpSharedMemoryAllocator::RemoveAllocationCleanUpForThisAllocator(void* allocation_address, + size_t allocation_clean_up_idx, + AllocationCleanUpFn* allocation_clean_up) { + std::scoped_lock g{allocations_mutex_}; + const auto allocation_infos_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_infos_it == allocations_.end(), + "Failed to get allocation info for address (", allocation_address, ")."); + + auto& clean_up_fns = allocation_infos_it->second.clean_up_fns; + ORT_RETURN_IF_NOT(allocation_clean_up_idx < clean_up_fns.size(), + "Invalid allocation_clean_up_idx: ", allocation_clean_up_idx); + + AllocationCleanUpFn& clean_up_fn = clean_up_fns[allocation_clean_up_idx]; + ORT_RETURN_IF(clean_up_fn == nullptr, + "Allocation clean up has already been removed at allocation_clean_up_idx: ", allocation_clean_up_idx); + + AllocationCleanUpFn removed_clean_up_fn = nullptr; + removed_clean_up_fn.swap(clean_up_fn); + + if (allocation_clean_up != nullptr) { + *allocation_clean_up = std::move(removed_clean_up_fn); + } - // deregister mem handle - // TODO synchronize calls to qnn_interface.memDeRegister()? - const auto& qnn_interface = qnn_backend_manager_->GetQnnInterface(); - const auto qnn_mem_handle = SharedContext::GetInstance().GetSharedMemHandles().GetAndRemove(p); - DeregisterQnnMemHandle(qnn_interface, qnn_mem_handle); + return Status::OK(); } } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index 0e80df5c2a175..c7619657c92d1 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -4,33 +4,76 @@ #pragma once #include +#include +#include "core/common/common.h" +#include "core/common/inlined_containers.h" +#include "core/common/status.h" #include "core/framework/allocator.h" - -#include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/rpcmem_library.h" namespace onnxruntime::qnn { -class QnnBackendManager; -class RpcMemLibrary; - class HtpSharedMemoryAllocator : public IAllocator { public: - // Gets the single OrtMemoryInfo value that is associated with this allocator type. - static OrtMemoryInfo MemoryInfo(); + // Gets the OrtMemoryInfo value that is associated with this allocator type. + static OrtMemoryInfo AssociatedMemoryInfo(); + + HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib); - HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, - std::shared_ptr qnn_backend_manager); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(HtpSharedMemoryAllocator); + + // IAllocator overrides void* Alloc(size_t size) override; - void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; // void GetStats(AllocatorStats* stats) override; + struct SharedMemoryInfo { + int fd; + uint64_t offset; + uint64_t total_size; + }; + + // Get an allocation's shared memory info. + // `allocation_address` must be an address returned by Alloc() which has not yet been freed. + static Status GetAllocationSharedMemoryInfo(void* allocation_address, + SharedMemoryInfo& allocation_info); + + using AllocationCleanUpFn = std::function; + + // Add allocation clean up callback to call when the allocation is freed. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been freed. + // `allocation_clean_up` is the clean up callback. This call takes ownership. + // `allocation_clean_up_idx` identifies this clean up callback. It can be passed to RemoveAllocationCleanUp() to remove this callback later. + static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up, + size_t& allocation_clean_up_idx); + + // Remove allocation clean up callback that was previously added. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been freed. + // `allocation_clean_up_idx` identifies this clean up callback. + // `allocation_clean_up` is optional and, if provided, will contain the removed allocation clean up callback. + static Status RemoveAllocationCleanUp(void* allocation_address, size_t allocation_clean_up_idx, + AllocationCleanUpFn* allocation_clean_up); + private: + Status GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, + SharedMemoryInfo& allocation_info); + Status AddAllocationCleanUpForThisAllocator(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up, + size_t& allocation_clean_up_idx); + Status RemoveAllocationCleanUpForThisAllocator(void* allocation_address, size_t allocation_clean_up_idx, + AllocationCleanUpFn* allocation_clean_up); + + struct AllocationRecord { + SharedMemoryInfo shared_memory_info; + InlinedVector clean_up_fns; + }; + + // allocation address -> corresponding allocation record + InlinedHashMap allocations_; + std::mutex allocations_mutex_; // synchronize access to allocation_ + std::shared_ptr rpcmem_lib_; - std::shared_ptr qnn_backend_manager_; }; } // namespace onnxruntime::qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index f8af1752bbc62..1eedaec54f5c8 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -396,7 +396,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio rpcmem_library_ = std::make_shared(); } - qnn_backend_manager_ = std::make_shared( + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level_etw, profiling_level, @@ -1166,11 +1166,11 @@ Status QNNExecutionProvider::OnRunEnd(bool /*sync_stream*/, const onnxruntime::R std::vector QNNExecutionProvider::CreatePreferredAllocators() { std::vector allocators{}; - if (IsRpcMemAllocatorAvailable()) { + if (IsHtpSharedMemoryAllocatorAvailable()) { LOGS_DEFAULT(INFO) << "Creating HtpSharedMemoryAllocator."; AllocatorFactory rpcmem_allocator_factory = [this](OrtDevice::DeviceId) { - return std::make_unique(rpcmem_library_, qnn_backend_manager_); + return std::make_unique(rpcmem_library_); }; AllocatorCreationInfo rpcmem_allocator_creation_info{rpcmem_allocator_factory, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index bb6bae688d669..89e79326a60b2 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -75,12 +75,11 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); - bool IsRpcMemAllocatorAvailable() const { return rpcmem_library_ != nullptr; } + bool IsHtpSharedMemoryAllocatorAvailable() const { return rpcmem_library_ != nullptr; } private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; - // This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators(). - std::shared_ptr qnn_backend_manager_; + std::unique_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; From c527dee22771d7fa996fc6049645d89dac5cdb8f Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:15:40 -0800 Subject: [PATCH 15/53] remove IAllocator::TensorAlloc() --- .../onnxruntime/core/framework/allocator.h | 15 ---------- .../core/session/onnxruntime_c_api.h | 3 -- onnxruntime/core/framework/allocator.cc | 7 ----- onnxruntime/core/framework/tensor.cc | 7 ++++- .../provider_bridge_provider.cc | 1 - .../core/session/allocator_adapters.cc | 28 ------------------- onnxruntime/core/session/allocator_adapters.h | 3 -- .../core/session/provider_bridge_ort.cc | 1 - 8 files changed, 6 insertions(+), 59 deletions(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 449baa4383b6d..525277375830c 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -69,12 +69,6 @@ void* AllocateBufferWithOptions(IAllocator& allocator, size_t size, bool use_res template using IAllocatorUniquePtr = std::unique_ptr>; -// Note: Re-declare these from core/framework/data_types.h to avoid including the ONNX protobuf header. -class DataTypeImpl; -using MLDataType = const DataTypeImpl*; - -class TensorShape; - class IAllocator { public: IAllocator(const OrtMemoryInfo& info) : memory_info_(info) {} @@ -90,15 +84,6 @@ class IAllocator { virtual void Free(void* p) = 0; - /** - * Allocate memory for a tensor of the given shape and element data type. - * If the tensor size is 0, nullptr is returned. - * On other failures, an exception is thrown. - * - * Note: The default implementation will call Alloc(). - */ - virtual void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape); - // Reserve() is an interface exposed for an implementation of IAllocator // to optionally implement some allocation logic that by-passes any arena-based // logic that may be housed in the Alloc() implementation. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index a65cfc7e72a57..b1a79f5921328 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -329,9 +329,6 @@ typedef struct OrtAllocator { * those made during session initialization. This allows for separate memory management strategies for these allocations. */ void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes - // TODO docs - void*(ORT_API_CALL* TensorAlloc)(struct OrtAllocator* this_, - const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type); } OrtAllocator; typedef void(ORT_API_CALL* OrtLoggingFunction)( diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index cd63ad98ab10b..02dbb3e518783 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -13,17 +13,10 @@ #include #endif -#include "core/framework/data_types.h" #include "core/framework/bfc_arena.h" -#include "core/framework/tensor.h" namespace onnxruntime { -void* IAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { - const auto size_in_bytes = Tensor::CalculateTensorStorageSize(element_data_type, shape); - return Alloc(size_in_bytes); -} - // private helper for calculation so SafeInt usage doesn't bleed into the public allocator.h header bool IAllocator::CalcMemSizeForArrayWithAlignment(size_t nmemb, size_t size, size_t alignment, size_t* out) noexcept { bool ok = true; diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index ea80f55ac0327..60d768cc59a5d 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -87,7 +87,12 @@ Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, cons Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator) : alloc_info_(allocator->Info()) { ORT_ENFORCE(elt_type != nullptr); - void* p_data = allocator->TensorAlloc(elt_type, shape); + size_t len = Tensor::CalculateTensorStorageSize(elt_type, shape); + + void* p_data = nullptr; + if (len > 0) { + p_data = allocator->Alloc(len); + } Init(elt_type, shape, p_data, allocator, 0L); } diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 00efc10a1fbc5..d3b12f9728135 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -114,7 +114,6 @@ struct OnUnload { } g_on_unload; -void* IAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { return g_host->IAllocator__TensorAlloc(this, element_data_type, shape); } void* CPUAllocator::Alloc(size_t size) { return g_host->CPUAllocator__Alloc(this, size); } void CPUAllocator::Free(void* p) { g_host->CPUAllocator__Free(this, p); } diff --git a/onnxruntime/core/session/allocator_adapters.cc b/onnxruntime/core/session/allocator_adapters.cc index 2397b128e8163..bebf6e98ff3fa 100644 --- a/onnxruntime/core/session/allocator_adapters.cc +++ b/onnxruntime/core/session/allocator_adapters.cc @@ -2,7 +2,6 @@ // Licensed under the MIT License. #include "allocator_adapters.h" -#include "core/framework/data_types.h" #include "core/framework/error_code_helper.h" #include "core/session/inference_session.h" #include "core/session/ort_env.h" @@ -12,7 +11,6 @@ namespace onnxruntime { namespace { constexpr uint32_t kOrtAllocatorReserveMinVersion = 18; -constexpr uint32_t kOrtAllocatorTensorAllocMinVersion = 21; } // namespace OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxruntime::AllocatorPtr&& i_allocator) @@ -28,13 +26,6 @@ OrtAllocatorImplWrappingIAllocator::OrtAllocatorImplWrappingIAllocator(onnxrunti OrtAllocator::Reserve = [](OrtAllocator* this_, size_t size) { return static_cast(this_)->Reserve(size); }; } - if (OrtAllocator::version >= kOrtAllocatorTensorAllocMinVersion) { - OrtAllocator::TensorAlloc = - [](OrtAllocator* this_, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type) { - return static_cast(this_)->TensorAlloc(shape, shape_len, - element_data_type); - }; - } } void* OrtAllocatorImplWrappingIAllocator::Alloc(size_t size) { @@ -45,13 +36,6 @@ void* OrtAllocatorImplWrappingIAllocator::Reserve(size_t size) { return i_allocator_->Reserve(size); } -void* OrtAllocatorImplWrappingIAllocator::TensorAlloc(const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType onnx_element_data_type) { - const auto tensor_type = DataTypeImpl::TensorTypeFromONNXEnum(onnx_element_data_type); - const TensorShape tensor_shape(gsl::span{shape, shape_len}); - return i_allocator_->TensorAlloc(tensor_type->GetElementType(), tensor_shape); -} - void OrtAllocatorImplWrappingIAllocator::Free(void* p) { i_allocator_->Free(p); } @@ -79,18 +63,6 @@ void* IAllocatorImplWrappingOrtAllocator::Reserve(size_t size) { return ort_allocator_->Alloc(ort_allocator_, size); } -void* IAllocatorImplWrappingOrtAllocator::TensorAlloc(MLDataType element_data_type, const TensorShape& shape) { - if (ort_allocator_->version >= kOrtAllocatorTensorAllocMinVersion && ort_allocator_->TensorAlloc) { - const auto shape_span = shape.GetDims(); - ORT_ENFORCE(element_data_type->IsPrimitiveDataType()); - const auto onnx_element_data_type = - static_cast(element_data_type->AsPrimitiveDataType()->GetDataType()); - return ort_allocator_->TensorAlloc(ort_allocator_, shape_span.data(), shape_span.size(), onnx_element_data_type); - } - - return IAllocator::TensorAlloc(element_data_type, shape); -} - void IAllocatorImplWrappingOrtAllocator::Free(void* p) { return ort_allocator_->Free(ort_allocator_, p); } diff --git a/onnxruntime/core/session/allocator_adapters.h b/onnxruntime/core/session/allocator_adapters.h index a8f3b6460574f..48f4ea03118c8 100644 --- a/onnxruntime/core/session/allocator_adapters.h +++ b/onnxruntime/core/session/allocator_adapters.h @@ -29,8 +29,6 @@ struct OrtAllocatorImplWrappingIAllocator final : public OrtAllocatorImpl { const OrtMemoryInfo* Info() const; void* Reserve(size_t size); - void* TensorAlloc(const int64_t* shape, size_t shape_len, ONNXTensorElementDataType element_data_type); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtAllocatorImplWrappingIAllocator); onnxruntime::AllocatorPtr GetWrappedIAllocator(); @@ -47,7 +45,6 @@ class IAllocatorImplWrappingOrtAllocator final : public IAllocator { void* Alloc(size_t size) override; void* Reserve(size_t size) override; - void* TensorAlloc(MLDataType element_data_type, const TensorShape& shape) override; void Free(void* p) override; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index eb8ad28f0a146..d55fd34d5a8f2 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -231,7 +231,6 @@ struct ProviderHostImpl : ProviderHost { AllocatorPtr CreateAllocator(const AllocatorCreationInfo& info) override { return onnxruntime::CreateAllocator(info); } std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) override { return std::make_unique(memory_info); }; - void* IAllocator__TensorAlloc(IAllocator* p, MLDataType element_data_type, const TensorShape& shape) override { return p->IAllocator::TensorAlloc(element_data_type, shape); } void* CPUAllocator__Alloc(CPUAllocator* p, size_t size) override { return p->CPUAllocator::Alloc(size); } void CPUAllocator__Free(CPUAllocator* p, void* allocation) override { return p->CPUAllocator::Free(allocation); } From e4f72b36b12d7403783d2d198d0e6168fc35a49a Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:22:58 -0800 Subject: [PATCH 16/53] document IAllocator::Free --- include/onnxruntime/core/framework/allocator.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 525277375830c..523d2a9d1a8be 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -82,6 +82,10 @@ class IAllocator { */ virtual void* Alloc(size_t size) = 0; + /** + * Free memory at p. + * If p is nullptr, do nothing. + */ virtual void Free(void* p) = 0; // Reserve() is an interface exposed for an implementation of IAllocator From 39ff9012cc2b86c77e17aa4f3a56726dba9189bf Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 9 Dec 2024 17:25:04 -0800 Subject: [PATCH 17/53] remove IAllocator__TensorAlloc --- onnxruntime/core/providers/shared_library/provider_interfaces.h | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index ae75ad7d55131..f9f2bb69a9d1a 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -158,7 +158,6 @@ struct ProviderHost { virtual std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info) = 0; - virtual void* IAllocator__TensorAlloc(IAllocator* p, MLDataType element_data_type, const TensorShape& shape) = 0; virtual void* CPUAllocator__Alloc(CPUAllocator* p, size_t size) = 0; virtual void CPUAllocator__Free(CPUAllocator* p, void* allocation) = 0; From d70db84e33e7e73898d8d482340e5c78a63c8d97 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 10 Dec 2024 11:15:18 -0800 Subject: [PATCH 18/53] fix android build warning --- onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 29a735159f398..9ef5db78af0eb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1657,7 +1657,7 @@ Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { Status QnnBackendManager::ReleaseQnnContextMemHandles() { // remove outstanding allocation clean up callbacks for (auto& [context_handle, context_mem_handle_record] : context_mem_handles_) { - for (const auto [shared_memory_address, idx] : + for (const auto& [shared_memory_address, idx] : context_mem_handle_record.outstanding_allocation_clean_up_callbacks) { ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::RemoveAllocationCleanUp(shared_memory_address, idx, /* allocation_clean_up */ nullptr)); From 45ef88371923f4597f9e897d36137f59aa3738e8 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 11 Dec 2024 11:05:21 -0800 Subject: [PATCH 19/53] remove shared mem handles from shared context --- .../core/providers/qnn/shared_context.h | 38 ------------------- 1 file changed, 38 deletions(-) diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index 4ce4aa15029a3..fdd3e411e0b7e 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -1,13 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License -#include #include #include #include -#include - #include "core/common/common.h" #include "core/providers/qnn/builder/qnn_model.h" @@ -15,36 +12,6 @@ namespace onnxruntime { -class SharedMemHandles { - public: - Qnn_MemHandle_t Get(const void* addr) { - std::lock_guard g{mutex_}; - const auto it = qnn_mem_handles_.find(addr); - ORT_ENFORCE(it != qnn_mem_handles_.end(), "Failed to find mem handle associated with address (", addr, ")."); - return it->second; - } - - void Add(const void* addr, Qnn_MemHandle_t mem_handle) { - std::lock_guard g{mutex_}; - auto [it, added] = qnn_mem_handles_.emplace(addr, mem_handle); - ORT_ENFORCE(added, - "There is already a mem handle (", mem_handle, ") associated with the address (", addr, ")."); - } - - Qnn_MemHandle_t GetAndRemove(const void* addr) { - std::lock_guard g{mutex_}; - const auto it = qnn_mem_handles_.find(addr); - ORT_ENFORCE(it != qnn_mem_handles_.end(), "Failed to find mem handle associated with address (", addr, ")."); - const auto qnn_mem_handle = it->second; - qnn_mem_handles_.erase(it); - return qnn_mem_handle; - } - - private: - std::unordered_map qnn_mem_handles_; - std::mutex mutex_; -}; - class SharedContext { public: static SharedContext& GetInstance() { @@ -94,8 +61,6 @@ class SharedContext { return graph_exist; } - SharedMemHandles& GetSharedMemHandles() { return shared_mem_handles_; } - private: SharedContext() = default; ~SharedContext() = default; @@ -106,9 +71,6 @@ class SharedContext { // Producer sessions can be in parallel // Consumer sessions have to be after producer sessions initialized std::mutex mtx_; - - // TODO can we avoid keeping mem handles in SharedContext? - SharedMemHandles shared_mem_handles_; }; } // namespace onnxruntime From d2e7b3c5a88656e54cea08336be545f127251f9b Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 16 Dec 2024 15:17:17 -0800 Subject: [PATCH 20/53] remove allocation clean up callback removal, use weak_ptrs in allocation clean up callback --- .../qnn/builder/qnn_backend_manager.cc | 57 ++++++++----------- .../qnn/builder/qnn_backend_manager.h | 12 ++-- .../core/providers/qnn/qnn_allocator.cc | 45 +-------------- .../core/providers/qnn/qnn_allocator.h | 17 +----- .../providers/qnn/qnn_execution_provider.h | 4 +- 5 files changed, 38 insertions(+), 97 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 9ef5db78af0eb..ab1dcc299709e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -554,7 +554,7 @@ Status QnnBackendManager::CreateContext() { ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result)); - ORT_RETURN_IF_ERROR(AddQnnContext(context)); // TODO use RAII type for context handle? + ORT_RETURN_IF_ERROR(AddQnnContext(context)); context_created_ = true; return Status::OK(); @@ -565,7 +565,8 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } - ORT_RETURN_IF_ERROR(ReleaseQnnContextMemHandles()); + // release context mem handles + context_mem_handles_.clear(); bool failed = false; for (auto context : contexts_) { @@ -1644,9 +1645,8 @@ void* QnnBackendManager::LibFunction(void* handle, const char* symbol, std::stri Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set."); - auto mem_handle_manager = std::make_unique(GetQnnInterface(), context, *logger_); - auto mem_handle_record = ContextMemHandleRecord{std::move(mem_handle_manager), {}}; - const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_record)).second; + auto mem_handle_manager = std::make_shared(GetQnnInterface(), context, *logger_); + const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_manager)).second; ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", context); contexts_.push_back(context); @@ -1654,50 +1654,43 @@ Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { return Status::OK(); } -Status QnnBackendManager::ReleaseQnnContextMemHandles() { - // remove outstanding allocation clean up callbacks - for (auto& [context_handle, context_mem_handle_record] : context_mem_handles_) { - for (const auto& [shared_memory_address, idx] : - context_mem_handle_record.outstanding_allocation_clean_up_callbacks) { - ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::RemoveAllocationCleanUp(shared_memory_address, idx, - /* allocation_clean_up */ nullptr)); - } - } - - context_mem_handles_.clear(); - - return Status::OK(); -} - Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t& mem_handle) { const auto context_mem_handles_it = context_mem_handles_.find(context); ORT_RETURN_IF_NOT(context_mem_handles_it != context_mem_handles_.end(), "QNN context not found: ", context); - auto& context_mem_handle_record = context_mem_handles_it->second; - auto& context_mem_handle_manager = *context_mem_handle_record.mem_handle_manager; + auto& context_mem_handle_manager = context_mem_handles_it->second; bool did_register{}; - ORT_RETURN_IF_ERROR(context_mem_handle_manager.GetOrRegister(shared_memory_address, qnn_tensor, - mem_handle, did_register)); + ORT_RETURN_IF_ERROR(context_mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor, + mem_handle, did_register)); if (did_register) { HtpSharedMemoryAllocator::AllocationCleanUpFn allocation_clean_up = - [&logger = *logger_, &context_mem_handle_manager](void* shared_memory_address) { - auto unregister_status = context_mem_handle_manager.Unregister(shared_memory_address); + [&logger = *logger_, + weak_backend_manager = weak_from_this(), + weak_context_mem_handle_manager = std::weak_ptr{context_mem_handle_manager}]( + void* shared_memory_address) { + // get QnnBackendManager shared_ptr to ensure that qnn_interface is still valid + auto backend_manager = weak_backend_manager.lock(); + if (!backend_manager) { + return; + } + + auto context_mem_handle_manager = weak_context_mem_handle_manager.lock(); + if (!context_mem_handle_manager) { + return; + } + + auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address); if (!unregister_status.IsOK()) { LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " << shared_memory_address << ", error: " << unregister_status.ErrorMessage(); } }; - size_t allocation_clean_up_idx{}; ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address, - std::move(allocation_clean_up), - allocation_clean_up_idx)); - - context_mem_handle_record.outstanding_allocation_clean_up_callbacks.emplace_back(shared_memory_address, - allocation_clean_up_idx); + std::move(allocation_clean_up))); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 728d9e2fcddd1..cddeffd21f32e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -32,7 +32,7 @@ namespace qnn { class QnnModel; -class QnnBackendManager { +class QnnBackendManager : public std::enable_shared_from_this { public: QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level_etw, @@ -261,13 +261,9 @@ class QnnBackendManager { Qnn_LogHandle_t log_handle_ = nullptr; Qnn_DeviceHandle_t device_handle_ = nullptr; std::vector contexts_; - - struct ContextMemHandleRecord { - std::unique_ptr mem_handle_manager; - InlinedVector> outstanding_allocation_clean_up_callbacks; - }; - - std::unordered_map context_mem_handles_; + // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a + // HtpSharedMemoryAllocator allocation cleanup callback. + std::unordered_map> context_mem_handles_; ProfilingLevel profiling_level_etw_; ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index d06c4b95584e4..a013cf627b829 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -185,21 +185,10 @@ Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(void* allocation_ } Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* allocation_address, - AllocationCleanUpFn&& allocation_clean_up, - size_t& allocation_clean_up_idx) { + AllocationCleanUpFn&& allocation_clean_up) { auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); return allocation_header.allocator_ptr->AddAllocationCleanUpForThisAllocator(allocation_address, - std::move(allocation_clean_up), - allocation_clean_up_idx); -} - -Status HtpSharedMemoryAllocator::RemoveAllocationCleanUp(void* allocation_address, - size_t allocation_clean_up_idx, - AllocationCleanUpFn* allocation_clean_up) { - auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); - return allocation_header.allocator_ptr->RemoveAllocationCleanUpForThisAllocator(allocation_address, - allocation_clean_up_idx, - allocation_clean_up); + std::move(allocation_clean_up)); } Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, @@ -214,8 +203,7 @@ Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(v } Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allocation_address, - AllocationCleanUpFn&& allocation_clean_up, - size_t& allocation_clean_up_idx) { + AllocationCleanUpFn&& allocation_clean_up) { ORT_RETURN_IF(allocation_clean_up == nullptr, "allocation_clean_up should not be empty."); std::scoped_lock g{allocations_mutex_}; @@ -225,33 +213,6 @@ Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allo auto& clean_up_fns = allocation_infos_it->second.clean_up_fns; clean_up_fns.emplace_back(std::move(allocation_clean_up)); - allocation_clean_up_idx = clean_up_fns.size() - 1; - return Status::OK(); -} - -Status HtpSharedMemoryAllocator::RemoveAllocationCleanUpForThisAllocator(void* allocation_address, - size_t allocation_clean_up_idx, - AllocationCleanUpFn* allocation_clean_up) { - std::scoped_lock g{allocations_mutex_}; - const auto allocation_infos_it = allocations_.find(allocation_address); - ORT_RETURN_IF(allocation_infos_it == allocations_.end(), - "Failed to get allocation info for address (", allocation_address, ")."); - - auto& clean_up_fns = allocation_infos_it->second.clean_up_fns; - ORT_RETURN_IF_NOT(allocation_clean_up_idx < clean_up_fns.size(), - "Invalid allocation_clean_up_idx: ", allocation_clean_up_idx); - - AllocationCleanUpFn& clean_up_fn = clean_up_fns[allocation_clean_up_idx]; - ORT_RETURN_IF(clean_up_fn == nullptr, - "Allocation clean up has already been removed at allocation_clean_up_idx: ", allocation_clean_up_idx); - - AllocationCleanUpFn removed_clean_up_fn = nullptr; - removed_clean_up_fn.swap(clean_up_fn); - - if (allocation_clean_up != nullptr) { - *allocation_clean_up = std::move(removed_clean_up_fn); - } - return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index c7619657c92d1..0436362b20154 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -45,24 +45,13 @@ class HtpSharedMemoryAllocator : public IAllocator { // Add allocation clean up callback to call when the allocation is freed. // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been freed. // `allocation_clean_up` is the clean up callback. This call takes ownership. - // `allocation_clean_up_idx` identifies this clean up callback. It can be passed to RemoveAllocationCleanUp() to remove this callback later. - static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up, - size_t& allocation_clean_up_idx); - - // Remove allocation clean up callback that was previously added. - // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been freed. - // `allocation_clean_up_idx` identifies this clean up callback. - // `allocation_clean_up` is optional and, if provided, will contain the removed allocation clean up callback. - static Status RemoveAllocationCleanUp(void* allocation_address, size_t allocation_clean_up_idx, - AllocationCleanUpFn* allocation_clean_up); + static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); private: Status GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, SharedMemoryInfo& allocation_info); - Status AddAllocationCleanUpForThisAllocator(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up, - size_t& allocation_clean_up_idx); - Status RemoveAllocationCleanUpForThisAllocator(void* allocation_address, size_t allocation_clean_up_idx, - AllocationCleanUpFn* allocation_clean_up); + + Status AddAllocationCleanUpForThisAllocator(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); struct AllocationRecord { SharedMemoryInfo shared_memory_info; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 9695a64cdd109..317b34e66a6e4 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -79,7 +79,9 @@ class QNNExecutionProvider : public IExecutionProvider { private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; - std::unique_ptr qnn_backend_manager_; + // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a + // HtpSharedMemoryAllocator allocation cleanup callback. + std::shared_ptr qnn_backend_manager_; std::unordered_map> qnn_models_; bool context_cache_enabled_ = false; std::string context_cache_path_cfg_ = ""; From c892c18ee886d14cb911ba0ffcbfbc9f1378fc8a Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 16 Dec 2024 16:59:22 -0800 Subject: [PATCH 21/53] some clean up --- .../core/providers/qnn/qnn_allocator.cc | 46 ++++++++++++------- .../core/providers/qnn/qnn_allocator.h | 18 +++++--- 2 files changed, 41 insertions(+), 23 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index a013cf627b829..29b2cd6682fe2 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -45,7 +45,7 @@ size_t DivRoundUp(size_t a, size_t b) { // TODO is there already a helper funct } bool IsAligned(const void* address, size_t alignment) { - assert((alignment & alignment - 1) == 0); + assert((alignment & alignment - 1) == 0); // alignment must be a power of two return (reinterpret_cast(address) & (alignment - 1)) == 0; } @@ -87,9 +87,11 @@ OrtMemoryInfo HtpSharedMemoryAllocator::AssociatedMemoryInfo() { /* id */ 0, OrtMemTypeDefault}; } -HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib) +HtpSharedMemoryAllocator::HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + const logging::Logger* logger) : IAllocator{AssociatedMemoryInfo()}, - rpcmem_lib_{std::move(rpcmem_lib)} { + rpcmem_lib_{std::move(rpcmem_lib)}, + logger_(logger != nullptr ? *logger : logging::LoggingManager::DefaultLogger()) { ORT_ENFORCE(rpcmem_lib_ != nullptr); } @@ -106,7 +108,7 @@ void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { // allocate shared memory void* shared_memory_raw = rpcmem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, static_cast(shared_memory_block_size_in_bytes)); - + ORT_ENFORCE(shared_memory_raw != nullptr, "rpcmem_alloc() failed to allocate and returned nullptr."); auto shared_memory = WrapSharedMemoryWithUniquePtr(shared_memory_raw, rpcmem_lib_->Api()); const size_t allocation_alignment = AllocationAlignment(); @@ -132,7 +134,7 @@ void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { std::scoped_lock g{allocations_mutex_}; const bool inserted = allocations_.emplace(allocation_address, std::move(allocation_record)).second; - ORT_ENFORCE(inserted, "Allocation info already exists for address (", allocation_address, ")."); + ORT_ENFORCE(inserted, "Allocation record already exists for address (", allocation_address, ")."); } // initialize header @@ -150,8 +152,6 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { return; } - // TODO should we throw exceptions at all from Free()? - auto& allocation_header = ValidateAllocationAddressAndGetHeader(allocation_address); ORT_ENFORCE(allocation_header.allocator_ptr == this, "AllocationHeader points to a different allocator (", allocation_header.allocator_ptr, @@ -164,16 +164,28 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { ORT_ENFORCE(!allocation_node.empty(), "Failed to get allocation info for address (", allocation_address, ")."); - // take ownership of shared memory and free at end of scope - auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); - - // destroy header - allocation_header.~AllocationHeader(); - - // clean up allocation record - const auto& allocation_info = allocation_node.mapped(); - for (auto& clean_up_fn : allocation_info.clean_up_fns) { - clean_up_fn(allocation_address); // TODO handle exceptions? + // At this point, we have a valid allocation to free. + // Avoid throwing exceptions as this may be running from a destructor. + try { + // take ownership of shared memory and free at end of scope + auto shared_memory = WrapSharedMemoryWithUniquePtr(allocation_address, rpcmem_lib_->Api()); + + // destroy header + allocation_header.~AllocationHeader(); + + // clean up allocation record + const auto& allocation_record = allocation_node.mapped(); + for (auto& clean_up_fn : allocation_record.clean_up_fns) { + // attempt to run each clean_up_fn even if exceptions are thrown + try { + clean_up_fn(allocation_address); + } catch (const std::exception& e) { + LOGS(logger_, ERROR) << "Caught exception while running clean up callback for address (" << allocation_address + << "): " << e.what(); + } + } + } catch(const std::exception& e) { + LOGS(logger_, ERROR) << "Caught exception while freeing address (" << allocation_address << "): " << e.what(); } } diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index 0436362b20154..5b854a70fc00f 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -8,6 +8,7 @@ #include "core/common/common.h" #include "core/common/inlined_containers.h" +#include "core/common/logging/logging.h" #include "core/common/status.h" #include "core/framework/allocator.h" #include "core/providers/qnn/rpcmem_library.h" @@ -19,7 +20,8 @@ class HtpSharedMemoryAllocator : public IAllocator { // Gets the OrtMemoryInfo value that is associated with this allocator type. static OrtMemoryInfo AssociatedMemoryInfo(); - HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib); + HtpSharedMemoryAllocator(std::shared_ptr rpcmem_lib, + const logging::Logger* logger = nullptr); ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(HtpSharedMemoryAllocator); @@ -27,7 +29,7 @@ class HtpSharedMemoryAllocator : public IAllocator { void* Alloc(size_t size) override; void Free(void* p) override; - // void GetStats(AllocatorStats* stats) override; + // void GetStats(AllocatorStats* stats) override; // TODO override struct SharedMemoryInfo { int fd; @@ -36,15 +38,17 @@ class HtpSharedMemoryAllocator : public IAllocator { }; // Get an allocation's shared memory info. - // `allocation_address` must be an address returned by Alloc() which has not yet been freed. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been + // freed. static Status GetAllocationSharedMemoryInfo(void* allocation_address, SharedMemoryInfo& allocation_info); using AllocationCleanUpFn = std::function; // Add allocation clean up callback to call when the allocation is freed. - // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been freed. - // `allocation_clean_up` is the clean up callback. This call takes ownership. + // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been + // freed. + // `allocation_clean_up` is the clean up callback. The associated allocator takes ownership of the callback. static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); private: @@ -60,9 +64,11 @@ class HtpSharedMemoryAllocator : public IAllocator { // allocation address -> corresponding allocation record InlinedHashMap allocations_; - std::mutex allocations_mutex_; // synchronize access to allocation_ + std::mutex allocations_mutex_; // synchronize access to allocations_ std::shared_ptr rpcmem_lib_; + + const logging::Logger& logger_; }; } // namespace onnxruntime::qnn From b295eef03a612ea370cd244043810b463c8ae3d7 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:54:04 -0800 Subject: [PATCH 22/53] more clean up --- onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc | 4 ++++ .../providers/qnn/builder/qnn_context_mem_handle_manager.cc | 1 - 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index ab1dcc299709e..d65e4631921dc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1682,6 +1682,10 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont return; } + // TODO should also ensure that the QNN context handle is still valid. + // This *should* be true as long as the QNN contexts are not freed from anywhere other than + // ~QnnBackendManager(). If we are able to lock weak_backend_manager, we haven't gotten to the dtor yet. + auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address); if (!unregister_status.IsOK()) { LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc index de77b309c0105..18be779f50910 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -91,7 +91,6 @@ Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, co const auto unregister_result = qnn_interface_.memDeRegister(&raw_mem_handle, 1); if (unregister_result != QNN_SUCCESS) { LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: " << unregister_result; - return; } }; From 13f5e30883f665059587f588b6e00a2b502ebfa0 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 16 Dec 2024 19:28:53 -0800 Subject: [PATCH 23/53] add helper to get qnn error message --- .../qnn/builder/qnn_backend_manager.cc | 8 ++----- .../builder/qnn_context_mem_handle_manager.cc | 6 ++++-- .../core/providers/qnn/builder/qnn_utils.cc | 21 +++++++++++++++++++ .../core/providers/qnn/builder/qnn_utils.h | 10 +++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index d65e4631921dc..98576c5903eda 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -22,6 +22,7 @@ #include "core/providers/qnn/qnn_allocator.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "core/providers/qnn/builder/qnn_utils.h" #ifdef _WIN32 #include @@ -1404,12 +1405,7 @@ const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error) } const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) { - // From QNN SDK: The memory is statically owned and should not be freed by the caller. - const char* error_msg = nullptr; - if (QNN_SUCCESS == qnn_interface_.errorGetMessage(error, &error_msg)) { - return error_msg; - } - return "Unknown"; + return utils::GetQnnErrorMessage(qnn_interface_, error); } const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc index 18be779f50910..73d433942b575 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -81,7 +81,8 @@ Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, co Qnn_MemHandle_t raw_mem_handle{}; const auto register_result = qnn_interface_.memRegister(context_, &mem_descriptor, 1, &raw_mem_handle); ORT_RETURN_IF_NOT(register_result == QNN_SUCCESS, - "qnn_interface.memRegister() failed: ", register_result); // TODO get error message + "qnn_interface.memRegister() failed: ", + utils::GetVerboseQnnErrorMessage(qnn_interface_, register_result)); LOGS(logger_, VERBOSE) << "Registered QNN mem handle. mem_handle: " << raw_mem_handle; @@ -90,7 +91,8 @@ Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, co const auto unregister_result = qnn_interface_.memDeRegister(&raw_mem_handle, 1); if (unregister_result != QNN_SUCCESS) { - LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: " << unregister_result; + LOGS(logger_, ERROR) << "qnn_interface.memDeRegister() failed: " + << utils::GetVerboseQnnErrorMessage(qnn_interface_, unregister_result); } }; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 39b18ccc55fb7..ad6f48a6d2c48 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -578,6 +578,27 @@ Status Quantize(const double double_value, return Status::OK(); } +const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ErrorHandle_t qnn_error_handle) { + // From QNN SDK: The memory is statically owned and should not be freed by the caller. + const char* error_msg = nullptr; + if (qnn_interface.errorGetMessage(qnn_error_handle, &error_msg) == QNN_SUCCESS) { + return error_msg; + } + return "Unknown error."; +} + +std::string GetVerboseQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle) { + const char* error_msg = nullptr; + if (qnn_interface.errorGetVerboseMessage(qnn_error_handle, &error_msg) == QNN_SUCCESS) { + auto free_error_msg = gsl::finally([&qnn_interface, error_msg] { + qnn_interface.errorFreeVerboseMessage(error_msg); + }); + return error_msg; + } + return "Unknown error."; +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index ac299706b8588..e07ee64ce33bd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -10,7 +10,9 @@ #include +#include "QnnInterface.h" #include "QnnTypes.h" + #include "core/session/onnxruntime_cxx_api.h" #include "core/framework/node_unit.h" #include "core/util/qmath.h" @@ -109,6 +111,14 @@ Status Quantize(const double double_value, const Qnn_DataType_t qnn_data_type, int& quant_value); +// Gets error message associated with QNN error handle value. +const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle); + +// Gets verbose error message associated with QNN error handle value. +std::string GetVerboseQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle); + } // namespace utils } // namespace qnn } // namespace onnxruntime From d5eace13bf9e4148e6d6758b84c127c8a5674097 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 17 Dec 2024 13:54:38 -0800 Subject: [PATCH 24/53] use make_shared for QnnBackendManager --- onnxruntime/core/providers/qnn/qnn_execution_provider.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index eab82768a1f0f..e14f6fb8aba57 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -400,7 +400,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio rpcmem_library_ = std::make_shared(); } - qnn_backend_manager_ = std::make_unique( + qnn_backend_manager_ = std::make_shared( std::move(backend_path), profiling_level_etw, profiling_level, From bacbcdc1f0133dd82a0c0c201608fe5ad5ff6932 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 17 Dec 2024 15:27:59 -0800 Subject: [PATCH 25/53] add test to qnn_basic_test.cc, document allocator parameter. --- .../test/providers/qnn/qnn_basic_test.cc | 32 +++++++++++++++++-- .../test/providers/qnn/qnn_test_utils.h | 32 +++++++++++-------- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index e8282dbad9f72..9084ec70fbd6c 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -5,11 +5,12 @@ #include #include +#include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU +#include "core/providers/qnn/qnn_allocator.h" +#include "core/session/inference_session.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/onnxruntime_run_options_config_keys.h" -#include "core/providers/cpu/cpu_provider_factory.h" // For OrtSessionOptionsAppendExecutionProvider_CPU -#include "core/session/inference_session.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -1098,6 +1099,33 @@ TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { } } +TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + provider_options["enable_htp_shared_memory_allocator"] = "1"; + + AllocatorPtr htp_shared_memory_allocator{}; + { + auto allocators = QnnExecutionProviderWithOptions(provider_options)->CreatePreferredAllocators(); + ASSERT_FALSE(allocators.empty()); + auto& allocator = allocators[0]; + ASSERT_EQ(allocator->Info(), qnn::HtpSharedMemoryAllocator::AssociatedMemoryInfo()); + htp_shared_memory_allocator = std::move(allocator); + } + + auto input_defs = {TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f), + TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f)}; + RunQnnModelTest(BuildOpTestCase("Add", input_defs, {}, {}, kOnnxDomain, htp_shared_memory_allocator), + provider_options, + 13, + ExpectedEPNodeAssignment::All, + 0.008f); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 6c8ae5392bee4..676460e108b0e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -901,11 +901,12 @@ inline void TestFp16ModelAccuracy(const GetTestModelFn& f32_model_fn, * * \param builder Model builder object used to build the model's inputs, outputs, and nodes. * \param input_def Input definition that describes what kind of input to create. + * \param allocator Optional allocator to use to allocate the input ORT value. * \return A pointer to the new input. */ template inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, - AllocatorPtr io_allocator = nullptr) { + AllocatorPtr allocator = nullptr) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -916,7 +917,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, raw_data); } else { - input = builder.MakeInput(shape, raw_data, io_allocator); + input = builder.MakeInput(shape, raw_data, allocator); } } else { // Random data const auto& rand_info = input_def.GetRandomDataInfo(); @@ -924,7 +925,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& if (is_initializer) { input = builder.MakeInitializer(shape, rand_info.min, rand_info.max); } else { - input = builder.MakeInput(shape, rand_info.min, rand_info.max, io_allocator); + input = builder.MakeInput(shape, rand_info.min, rand_info.max, allocator); } } @@ -933,7 +934,7 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& template <> inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef& input_def, - AllocatorPtr io_allocator) { + AllocatorPtr allocator) { NodeArg* input = nullptr; const auto& shape = input_def.GetShape(); const bool is_initializer = input_def.IsInitializer(); @@ -944,13 +945,13 @@ inline NodeArg* MakeTestInput(ModelTestBuilder& builder, const TestInputDef(shape, raw_data, io_allocator); + input = builder.MakeInput(shape, raw_data, allocator); } } else { // Random data if (is_initializer) { input = builder.MakeRandInitializerBool(shape); } else { - input = builder.MakeInputBool(shape, io_allocator); + input = builder.MakeInputBool(shape, allocator); } } @@ -975,6 +976,7 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef @@ -983,18 +985,18 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, const std::vector>& input_defs_2, const std::vector& attrs, const std::string& op_domain = kOnnxDomain, - AllocatorPtr io_allocator = nullptr) { - return [op_type, input_defs_1, input_defs_2, attrs, op_domain, io_allocator](ModelTestBuilder& builder) { + AllocatorPtr input_allocator = nullptr) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain, input_allocator](ModelTestBuilder& builder) { std::vector op_inputs; op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); for (const auto& input_def : input_defs_1) { - NodeArg* input = MakeTestInput(builder, input_def, io_allocator); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); } for (const auto& input_def : input_defs_2) { - NodeArg* input = MakeTestInput(builder, input_def, io_allocator); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); } @@ -1015,6 +1017,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param input_defs List of input definitions. * \param attrs List of operator attributes. * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). + * \param use_contrib_qdq Whether to use Q/DQ ops from the MS domain instead of the ONNX domain. + * \param input_allocator Optional allocator to use to allocate input ORT values. * \returns A model building function. */ template @@ -1025,16 +1029,16 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( const std::vector& attrs, const std::string& op_domain = kOnnxDomain, bool use_contrib_qdq = false, - AllocatorPtr io_allocator = nullptr) { + AllocatorPtr input_allocator = nullptr) { return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, - use_contrib_qdq, io_allocator]( + use_contrib_qdq, input_allocator]( ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); // Create QDQ inputs for (const auto& input_def : quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def, io_allocator); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); QuantParams input_qparams = GetTestInputQuantParams(input_def); NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, use_contrib_qdq); @@ -1043,7 +1047,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase( // Create non-QDQ inputs for (const auto& input_def : non_quant_input_defs) { - NodeArg* input = MakeTestInput(builder, input_def, io_allocator); + NodeArg* input = MakeTestInput(builder, input_def, input_allocator); op_inputs.push_back(input); } From b29ab6106cbd3530be45a7e4be9a8ffd75e98619 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:02:59 -0800 Subject: [PATCH 26/53] rename variables --- onnxruntime/core/providers/qnn/qnn_allocator.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 29b2cd6682fe2..65ca0b9c9efb3 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -206,11 +206,11 @@ Status HtpSharedMemoryAllocator::AddAllocationCleanUp(void* allocation_address, Status HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfoForThisAllocator(void* allocation_address, SharedMemoryInfo& allocation_info) { std::scoped_lock g{allocations_mutex_}; - const auto allocation_infos_it = allocations_.find(allocation_address); - ORT_RETURN_IF(allocation_infos_it == allocations_.end(), + const auto allocation_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_it == allocations_.end(), "Failed to get allocation info for address (", allocation_address, ")."); - allocation_info = allocation_infos_it->second.shared_memory_info; + allocation_info = allocation_it->second.shared_memory_info; return Status::OK(); } @@ -219,11 +219,11 @@ Status HtpSharedMemoryAllocator::AddAllocationCleanUpForThisAllocator(void* allo ORT_RETURN_IF(allocation_clean_up == nullptr, "allocation_clean_up should not be empty."); std::scoped_lock g{allocations_mutex_}; - const auto allocation_infos_it = allocations_.find(allocation_address); - ORT_RETURN_IF(allocation_infos_it == allocations_.end(), + const auto allocation_it = allocations_.find(allocation_address); + ORT_RETURN_IF(allocation_it == allocations_.end(), "Failed to get allocation info for address (", allocation_address, ")."); - auto& clean_up_fns = allocation_infos_it->second.clean_up_fns; + auto& clean_up_fns = allocation_it->second.clean_up_fns; clean_up_fns.emplace_back(std::move(allocation_clean_up)); return Status::OK(); } From 67a54b89c2fb43aca0c5fc7e9123cd937f1241c6 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 17 Dec 2024 17:05:04 -0800 Subject: [PATCH 27/53] revert changes to onnxruntime/test/providers/qnn/max_min_op_test.cc --- .../test/providers/qnn/max_min_op_test.cc | 37 ++----------------- 1 file changed, 4 insertions(+), 33 deletions(-) diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 6e0f9f191cf47..3deff121f3c72 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -39,30 +39,20 @@ template static void RunQDQMinOrMaxOpTest(const std::string& op_type, const std::vector>& input_defs, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 13, - AllocatorPtr io_allocator = nullptr, - const ProviderOptions& extra_provider_options = {}) { + int opset = 13) { ProviderOptions provider_options; - if (!extra_provider_options.empty()) { - provider_options.insert(extra_provider_options.begin(), extra_provider_options.end()); - } - #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain, - io_allocator), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain, /* use_contrib_qdq*/ false, - io_allocator), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model provider_options, opset, - expected_ep_assignment, - {}, - logging::Severity::kVERBOSE); + expected_ep_assignment); } // @@ -138,25 +128,6 @@ TEST_F(QnnHTPBackendTests, Max_2Inputs) { ExpectedEPNodeAssignment::All, 13); } -// Test accuracy of 8-bit Q/DQ Min with 2 inputs on HTP backend. -TEST_F(QnnHTPBackendTests, Min_2Inputs_HtpSharedMemoryAllocator) { - ProviderOptions qnn_ep_options{ - {"enable_htp_shared_memory_allocator", "1"}, - {"backend_path", "libQnnHtp.so"}, - }; - - AllocatorPtr htp_shared_memory_allocator = - QnnExecutionProviderWithOptions(qnn_ep_options)->CreatePreferredAllocators()[0]; - - std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); - RunQDQMinOrMaxOpTest("Min", - {TestInputDef({1, 3, 4, 4}, false, input_data), - TestInputDef({1, 3, 4, 4}, false, input_data)}, - ExpectedEPNodeAssignment::All, 13, - htp_shared_memory_allocator, - qnn_ep_options); -} - #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test } // namespace onnxruntime From c0569e2259b716e40015ed85c66d5157d8dfa8fe Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:33:00 -0800 Subject: [PATCH 28/53] fix formatting --- onnxruntime/core/providers/qnn/qnn_allocator.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 65ca0b9c9efb3..84d67615ff3c9 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -184,7 +184,7 @@ void HtpSharedMemoryAllocator::Free(void* allocation_address) { << "): " << e.what(); } } - } catch(const std::exception& e) { + } catch (const std::exception& e) { LOGS(logger_, ERROR) << "Caught exception while freeing address (" << allocation_address << "): " << e.what(); } } From dd45c84b6528a06a7a4992fcba7f1f0926c93ad8 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:33:31 -0800 Subject: [PATCH 29/53] skip test if not android and not windows --- onnxruntime/test/providers/qnn/qnn_basic_test.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 9084ec70fbd6c..90ddf6b7a6ade 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1100,6 +1100,11 @@ TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { } TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { +#if !defined(__ANDROID__) && !defined(_WIN32) + // TODO there's probably a better way to check that we are on a Qualcomm device + GTEST_SKIP() << "Test should be run on Qualcomm device."; +#endif + ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; From 959d8df03948741d2f91472ce12513c7d7b90a28 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 18 Dec 2024 17:36:10 -0800 Subject: [PATCH 30/53] update comment --- onnxruntime/test/providers/qnn/qnn_basic_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 90ddf6b7a6ade..ed21ebbccc923 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1102,7 +1102,7 @@ TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { #if !defined(__ANDROID__) && !defined(_WIN32) // TODO there's probably a better way to check that we are on a Qualcomm device - GTEST_SKIP() << "Test should be run on Qualcomm device."; + GTEST_SKIP() << "Test is only supported on a Qualcomm device."; #endif ProviderOptions provider_options; From ab48516be5f1fc08de7986964af782d129fedd89 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 18 Dec 2024 18:08:17 -0800 Subject: [PATCH 31/53] remove QnnBackendManager::ReleaseQnnContextMemHandles declaration, update comments --- .../core/providers/qnn/builder/qnn_backend_manager.h | 1 - onnxruntime/core/providers/qnn/qnn_allocator.h | 8 ++++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index cddeffd21f32e..45a8f47a38f1e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -246,7 +246,6 @@ class QnnBackendManager : public std::enable_shared_from_this #endif Status AddQnnContext(Qnn_ContextHandle_t context); - Status ReleaseQnnContextMemHandles(); private: const std::string backend_path_; diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.h b/onnxruntime/core/providers/qnn/qnn_allocator.h index 5b854a70fc00f..f642368697aae 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.h +++ b/onnxruntime/core/providers/qnn/qnn_allocator.h @@ -37,17 +37,17 @@ class HtpSharedMemoryAllocator : public IAllocator { uint64_t total_size; }; - // Get an allocation's shared memory info. + // Gets an allocation's shared memory info. // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been - // freed. + // freed. static Status GetAllocationSharedMemoryInfo(void* allocation_address, SharedMemoryInfo& allocation_info); using AllocationCleanUpFn = std::function; - // Add allocation clean up callback to call when the allocation is freed. + // Adds allocation clean up callback to call when the allocation is freed. // `allocation_address` identifies the allocation. It must be an address returned by Alloc() which has not yet been - // freed. + // freed. // `allocation_clean_up` is the clean up callback. The associated allocator takes ownership of the callback. static Status AddAllocationCleanUp(void* allocation_address, AllocationCleanUpFn&& allocation_clean_up); From 4a3f6c39ba1ae8e140be7eeb45444e01dd24435b Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 6 Jan 2025 10:21:48 -0800 Subject: [PATCH 32/53] add onnxruntime_c_api.h include to ortmemoryinfo.h --- include/onnxruntime/core/framework/ortmemoryinfo.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/onnxruntime/core/framework/ortmemoryinfo.h b/include/onnxruntime/core/framework/ortmemoryinfo.h index d060c6546ae27..82f581e994904 100644 --- a/include/onnxruntime/core/framework/ortmemoryinfo.h +++ b/include/onnxruntime/core/framework/ortmemoryinfo.h @@ -7,6 +7,7 @@ #include "core/common/hash_combine.h" #include "core/framework/ortdevice.h" +#include "core/session/onnxruntime_c_api.h" // for OrtMemType, OrtAllocatorType struct OrtMemoryInfo { OrtMemoryInfo() = default; // to allow default construction of Tensor From ff1254132ab9096080b8324725cba7de04ffcd2d Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 6 Jan 2025 11:41:20 -0800 Subject: [PATCH 33/53] rename GetQnnTensorDataSize to GetQnnTensorDataSizeInBytes --- .../providers/qnn/builder/qnn_context_mem_handle_manager.cc | 2 +- onnxruntime/core/providers/qnn/builder/qnn_model.cc | 4 ++-- onnxruntime/core/providers/qnn/builder/qnn_utils.cc | 2 +- onnxruntime/core/providers/qnn/builder/qnn_utils.h | 3 +-- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc index 73d433942b575..a1f28762be48f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -31,7 +31,7 @@ Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, co const auto qnn_tensor_data_type = GetQnnTensorDataType(qnn_tensor); const size_t qnn_tensor_data_size = - utils::GetQnnTensorDataSize(gsl::span{qnn_tensor_dims, size_t{qnn_tensor_rank}}, qnn_tensor_data_type); + utils::GetQnnTensorDataSizeInBytes(gsl::span{qnn_tensor_dims, size_t{qnn_tensor_rank}}, qnn_tensor_data_type); { std::scoped_lock g{mem_handles_mutex_}; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 0bbb046605604..9ca1b86de9e40 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -322,8 +322,8 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_infos.resize(tensor_count); for (auto& tensor_wrapper : tensor_wrappers) { - const size_t length = utils::GetQnnTensorDataSize(tensor_wrapper.GetTensorDims(), - tensor_wrapper.GetTensorDataType()); + const size_t length = utils::GetQnnTensorDataSizeInBytes(tensor_wrapper.GetTensorDims(), + tensor_wrapper.GetTensorDataType()); const auto& tensor_name = tensor_wrapper.GetName(); auto qnn_index = is_input ? GetGraphInputIndex(tensor_name) : GetOutputIndex(tensor_name); auto ort_index = is_input ? GetOrtInputIndex(tensor_name) : qnn_index; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index ad6f48a6d2c48..444dbafe2d4cc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -65,7 +65,7 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type) { return pos->second; } -size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_type) { +size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType_t element_type) { ORT_ENFORCE(!shape.empty(), "Empty shape not allowed."); // TODO can we just treat empty shape as a scalar? SafeInt data_length = GetElementSizeByType(element_type); return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{}); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index e07ee64ce33bd..08c20177821e6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -26,8 +26,7 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type); size_t GetElementSizeByType(ONNXTensorElementDataType elem_type); -// Gets tensor data size in bytes. -size_t GetQnnTensorDataSize(gsl::span shape, Qnn_DataType_t element_data_type); +size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType_t element_data_type); // TODO: make these work with Wrappers? std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param); From 5e6e103973cb7d1a64d5d7a54394bb51508e9780 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 6 Jan 2025 12:59:04 -0800 Subject: [PATCH 34/53] add QnnBackendManager::Create function to ensure shared_ptr usage --- .../qnn/builder/qnn_backend_manager.h | 56 ++++++++++++------- .../providers/qnn/qnn_execution_provider.cc | 22 ++++---- 2 files changed, 47 insertions(+), 31 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 0749c60a5e547..2b7cd97b3026c 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -32,29 +32,45 @@ namespace qnn { class QnnModel; +// configuration values for QnnBackendManager creation +struct QnnBackendManagerConfig { + std::string backend_path; + ProfilingLevel profiling_level_etw; + ProfilingLevel profiling_level; + std::string profiling_file_path; + ContextPriority context_priority; + std::string qnn_saver_path; + uint32_t device_id; + QnnHtpDevice_Arch_t htp_arch; + uint32_t soc_model; + bool enable_htp_weight_sharing; +}; + class QnnBackendManager : public std::enable_shared_from_this { + private: + // private tag to pass to constructor to ensure that constructor cannot be directly called externally + struct PrivateConstructorTag {}; + public: - QnnBackendManager(std::string&& backend_path, - ProfilingLevel profiling_level_etw, - ProfilingLevel profiling_level, - std::string&& profiling_file_path, - ContextPriority context_priority, - std::string&& qnn_saver_path, - uint32_t device_id, - QnnHtpDevice_Arch_t htp_arch, - uint32_t soc_model, - bool enable_htp_weight_sharing) - : backend_path_(backend_path), - profiling_level_etw_(profiling_level_etw), - profiling_level_(profiling_level), - profiling_file_path_(profiling_file_path), - context_priority_(context_priority), - qnn_saver_path_(qnn_saver_path), - device_id_(device_id), - htp_arch_(htp_arch), - soc_model_(soc_model), - enable_htp_weight_sharing_(enable_htp_weight_sharing) { + static std::shared_ptr Create(const QnnBackendManagerConfig& config) { + return std::make_shared(config, PrivateConstructorTag{}); } + + // Note: creation should be done via Create() + QnnBackendManager(const QnnBackendManagerConfig& config, PrivateConstructorTag) + : backend_path_(config.backend_path), + profiling_level_etw_(config.profiling_level_etw), + profiling_level_(config.profiling_level), + profiling_file_path_(config.profiling_file_path), + context_priority_(config.context_priority), + qnn_saver_path_(config.qnn_saver_path), + device_id_(config.device_id), + htp_arch_(config.htp_arch), + soc_model_(config.soc_model), + enable_htp_weight_sharing_(config.enable_htp_weight_sharing) { + } + + public: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); ~QnnBackendManager(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index cc9a26d361962..04e7d15d62fb3 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -400,17 +400,17 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio rpcmem_library_ = std::make_shared(); } - qnn_backend_manager_ = std::make_shared( - std::move(backend_path), - profiling_level_etw, - profiling_level, - std::move(profiling_file_path), - context_priority, - std::move(qnn_saver_path), - device_id_, - htp_arch, - soc_model, - enable_htp_weight_sharing); + qnn_backend_manager_ = qnn::QnnBackendManager::Create( + qnn::QnnBackendManagerConfig{backend_path, + profiling_level_etw, + profiling_level, + profiling_file_path, + context_priority, + qnn_saver_path, + device_id_, + htp_arch, + soc_model, + enable_htp_weight_sharing}); #ifdef _WIN32 auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); From 78e86cc78a35e6f8e348c4a20b951f0133d48142 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 6 Jan 2025 14:33:41 -0800 Subject: [PATCH 35/53] make some QnnBackendManager member functions private, update comment --- .../qnn/builder/qnn_backend_manager.cc | 17 +++++-- .../qnn/builder/qnn_backend_manager.h | 50 +++++++------------ 2 files changed, 30 insertions(+), 37 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 1695d50bd2388..33116d7d9f884 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -48,6 +48,14 @@ static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnSystemInterface_t* qnn_i return qnn_interface->systemApiVersion; } +static char* DlError() { +#ifdef _WIN32 + return ""; +#else + return ::dlerror(); +#endif +} + template Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path, const char* interface_provider_name, @@ -1693,7 +1701,10 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont weak_backend_manager = weak_from_this(), weak_context_mem_handle_manager = std::weak_ptr{context_mem_handle_manager}]( void* shared_memory_address) { - // get QnnBackendManager shared_ptr to ensure that qnn_interface is still valid + // Get QnnBackendManager shared_ptr to ensure that: + // - QNN interface is still valid. + // - QNN context handle is still valid. This should be true as long as QNN contexts are not freed from + // anywhere other than the destructor. auto backend_manager = weak_backend_manager.lock(); if (!backend_manager) { return; @@ -1704,10 +1715,6 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont return; } - // TODO should also ensure that the QNN context handle is still valid. - // This *should* be true as long as the QNN contexts are not freed from anywhere other than - // ~QnnBackendManager(). If we are able to lock weak_backend_manager, we haven't gotten to the dtor yet. - auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address); if (!unregister_status.IsOK()) { LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 2b7cd97b3026c..1aec1e0765101 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -70,41 +70,9 @@ class QnnBackendManager : public std::enable_shared_from_this enable_htp_weight_sharing_(config.enable_htp_weight_sharing) { } - public: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); ~QnnBackendManager(); - char* DlError() { -#ifdef _WIN32 - return ""; -#else - return ::dlerror(); -#endif - } - - Status LoadBackend(); - - Status InitializeBackend(); - - Status CreateDevice(); - - Status ReleaseDevice(); - - Status ShutdownBackend(); - - Status InitializeProfiling(); - - Status ReleaseProfilehandle(); - - Status CreateContext(); - - Status ReleaseContext(); - - Status ResetContext() { - ORT_RETURN_IF_ERROR(ReleaseContext()); - - return CreateContext(); - } std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); @@ -170,6 +138,24 @@ class QnnBackendManager : public std::enable_shared_from_this Qnn_MemHandle_t& mem_handle); private: + Status LoadBackend(); + + Status InitializeBackend(); + + Status CreateDevice(); + + Status ReleaseDevice(); + + Status ShutdownBackend(); + + Status InitializeProfiling(); + + Status ReleaseProfilehandle(); + + Status CreateContext(); + + Status ReleaseContext(); + // Sets the ORT logger and creates a corresponding QNN logger with the same log level. // NOTE: caller must lock the `logger_mutex_` before calling this function. Status InitializeQnnLog(const logging::Logger& logger); From e665a2b81d21f06f339f99ad0c7cc9c8cb15e28b Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:41:20 -0800 Subject: [PATCH 36/53] document GetOrRegister functions --- onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h | 2 ++ .../core/providers/qnn/builder/qnn_context_mem_handle_manager.h | 2 ++ 2 files changed, 4 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 1aec1e0765101..61d6712db241e 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -133,6 +133,8 @@ class QnnBackendManager : public std::enable_shared_from_this uint64_t buffer_length, uint64_t& max_spill_fill_buffer_size); + // Gets an existing QNN mem handle or registers a new one. + // `mem_handle` is set to the QNN mem handle. Status GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t& mem_handle); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h index acb33d7175061..397ea8bad6d9a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.h @@ -28,6 +28,8 @@ class QnnContextMemHandleManager { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnContextMemHandleManager); + // Gets an existing QNN mem handle or registers a new one. + // `qnn_mem_handle` is set to the QNN mem handle and `did_register` is true if `qnn_mem_handle` was newly registered. Status GetOrRegister(void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t& qnn_mem_handle, bool& did_register); From 425023b297b303f1ba8dcf901c62325477446ed3 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Tue, 7 Jan 2025 17:05:12 -0800 Subject: [PATCH 37/53] add enable_htp_shared_memory_allocator to available_keys --- onnxruntime/test/perftest/ort_test_session.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index a0cac47585b48..a952e1c6ae4b2 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -202,7 +202,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device {"backend_path", "profiling_file_path", "profiling_level", "rpc_control_latency", "vtcm_mb", "soc_model", "device_id", "htp_performance_mode", "qnn_saver_path", "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", - "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer"}); + "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer", + "enable_htp_shared_memory_allocator"}); for (const auto& provider_option : provider_options) { const std::string& key = provider_option.first; const std::string& value = provider_option.second; From 4d292081cb860395544ad63b38f9de31bd56030f Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:09:02 -0800 Subject: [PATCH 38/53] make DlError return const char* --- onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 90f5a77487831..39d0dcef9f167 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -48,7 +48,7 @@ static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnSystemInterface_t* qnn_i return qnn_interface->systemApiVersion; } -static char* DlError() { +static const char* DlError() { #ifdef _WIN32 return ""; #else From 568c9a73d22a30ff9f9a56583dd37a2dc6eda66f Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:31:28 -0800 Subject: [PATCH 39/53] Use ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE for SharedContext --- onnxruntime/core/providers/qnn/shared_context.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/qnn/shared_context.h b/onnxruntime/core/providers/qnn/shared_context.h index fdd3e411e0b7e..a111e57038304 100644 --- a/onnxruntime/core/providers/qnn/shared_context.h +++ b/onnxruntime/core/providers/qnn/shared_context.h @@ -64,8 +64,8 @@ class SharedContext { private: SharedContext() = default; ~SharedContext() = default; - SharedContext(const SharedContext&) = delete; - SharedContext& operator=(const SharedContext&) = delete; + + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SharedContext); std::vector> shared_qnn_models_; // Producer sessions can be in parallel From 8b955358f147a4cfe80eba7a3e3e3532468509f6 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 10:50:29 -0800 Subject: [PATCH 40/53] use safeint instead of manually checking against int max --- onnxruntime/core/providers/qnn/qnn_allocator.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 84d67615ff3c9..6ef2a3da58a39 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -6,9 +6,9 @@ #include #include #include -#include #include "core/common/common.h" +#include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" // for MlasGetPreferredBufferAlignment() namespace onnxruntime::qnn { @@ -100,14 +100,13 @@ void* HtpSharedMemoryAllocator::Alloc(size_t requested_size) { const size_t shared_memory_block_size_in_bytes = allocation_offset + requested_size; // rpcmem_alloc() has an int size parameter. make sure we don't overflow. - constexpr size_t max_size_in_bytes = std::numeric_limits::max(); - ORT_ENFORCE(shared_memory_block_size_in_bytes <= max_size_in_bytes, - "Allocation size (", shared_memory_block_size_in_bytes, ") is larger than maximum allowed (", - max_size_in_bytes, ")."); + // TODO switch to rpcmem_alloc2() which has size_t size parameter. + // need to verify that rpcmem_alloc2() is available in all environments we care about. + const SafeInt shared_memory_block_size_in_bytes_int = shared_memory_block_size_in_bytes; // allocate shared memory void* shared_memory_raw = rpcmem_lib_->Api().alloc(rpcmem::RPCMEM_HEAP_ID_SYSTEM, rpcmem::RPCMEM_DEFAULT_FLAGS, - static_cast(shared_memory_block_size_in_bytes)); + shared_memory_block_size_in_bytes_int); ORT_ENFORCE(shared_memory_raw != nullptr, "rpcmem_alloc() failed to allocate and returned nullptr."); auto shared_memory = WrapSharedMemoryWithUniquePtr(shared_memory_raw, rpcmem_lib_->Api()); From 515999c297eaffdb4b383f2db5ea861580b3f97e Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 12:54:07 -0800 Subject: [PATCH 41/53] add/update doc for enable_htp_shared_memory_allocator option --- include/onnxruntime/core/session/onnxruntime_c_api.h | 4 ++++ onnxruntime/test/perftest/command_args_parser.cc | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 3d995e21e0017..ba4bfc7d6c28e 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3670,6 +3670,10 @@ struct OrtApi { * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary. * - "0": Default. Disabled. * - "1": Enabled. + * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to be + * available. + * - "0": Default. Disabled. + * - "1": Enabled. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 70b5a07539554..5031d557ee2f0 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -102,7 +102,7 @@ namespace perftest { "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). \n" "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while generating QNN context binary.\n" - "\t [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs.\n" + "\t [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs. Requires libcdsprpc.so/dll to be available.\n" "\t Defaults to '0' (disabled).\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n" "\n" From 6986839ea9295b0eccab016fcffda9c20ec446b1 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:15:48 -0800 Subject: [PATCH 42/53] formatting --- include/onnxruntime/core/session/onnxruntime_c_api.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index ba4bfc7d6c28e..6fef2448be0fe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -3670,8 +3670,8 @@ struct OrtApi { * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context binary. * - "0": Default. Disabled. * - "1": Enabled. - * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to be - * available. + * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to + * be available. * - "0": Default. Disabled. * - "1": Enabled. * From 00b286b9f2f7bd594716ccd1afbce5717ec40f89 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 13:16:52 -0800 Subject: [PATCH 43/53] add some comments about HtpSharedmemoryAllocator impl --- onnxruntime/core/providers/qnn/qnn_allocator.cc | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/onnxruntime/core/providers/qnn/qnn_allocator.cc b/onnxruntime/core/providers/qnn/qnn_allocator.cc index 6ef2a3da58a39..e0a8cac599db6 100644 --- a/onnxruntime/core/providers/qnn/qnn_allocator.cc +++ b/onnxruntime/core/providers/qnn/qnn_allocator.cc @@ -13,6 +13,21 @@ namespace onnxruntime::qnn { +/** + * HtpSharedMemoryAllocator allocation details + * + * The HTP shared memory allocator will allocate a block of shared memory larger than the amount requested in order to + * hold some additional info. + * Each allocation returned by HtpSharedMemoryAllocator::Alloc() is preceded by an AllocationHeader structure. + * + * For example, if Alloc(num_requested_bytes) is called, this is what the memory layout looks like: + * | AllocationHeader bytes | num_requested_bytes bytes | + * ^- address returned by Alloc() + * + * The AllocationHeader can be used to obtain the owning allocator instance, which in turn can be used to do other + * operations with that allocation, such as retrieving more info about the allocation. + */ + namespace { struct AllocationHeader { From 88dec645f2bb595e1c8097cc6cd596d92d3745d8 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 16:40:46 -0800 Subject: [PATCH 44/53] initialize with QNN_MEM_DESRIPTOR_INIT --- .../providers/qnn/builder/qnn_context_mem_handle_manager.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc index a1f28762be48f..7ce539c21ef94 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_context_mem_handle_manager.cc @@ -57,7 +57,7 @@ Status QnnContextMemHandleManager::GetOrRegister(void* shared_memory_address, co ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::GetAllocationSharedMemoryInfo(shared_memory_address, shared_memory_info)); - Qnn_MemDescriptor_t mem_descriptor{}; + Qnn_MemDescriptor_t mem_descriptor = QNN_MEM_DESCRIPTOR_INIT; mem_descriptor.memShape.dimSize = qnn_tensor_dims; mem_descriptor.memShape.numDim = qnn_tensor_rank; mem_descriptor.memShape.shapeConfig = nullptr; From 4ca3ea75fa2927f0ee0d78752d42021bdca813bc Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Thu, 9 Jan 2025 18:29:55 -0800 Subject: [PATCH 45/53] address comments --- .../qnn/builder/qnn_backend_manager.cc | 2 +- .../qnn/builder/qnn_backend_manager.h | 3 ++- .../core/providers/qnn/builder/qnn_model.cc | 18 +++++++++--------- .../core/providers/qnn/builder/qnn_utils.cc | 2 +- .../core/providers/qnn/builder/qnn_utils.h | 5 +++-- 5 files changed, 16 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 39d0dcef9f167..0455dd86d1033 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1466,7 +1466,7 @@ const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error) } } -const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) { +std::string_view QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) { return utils::GetQnnErrorMessage(qnn_interface_, error); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 3c8e37f12dc35..1dc8b35d1e6bb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -226,7 +227,7 @@ class QnnBackendManager : public std::enable_shared_from_this static const std::string GetEventTypeString(QnnProfile_EventType_t eventType); static const std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar); const char* QnnProfileErrorToString(QnnProfile_Error_t error); - const char* QnnErrorHandleToString(Qnn_ErrorHandle_t error); + std::string_view QnnErrorHandleToString(Qnn_ErrorHandle_t error); QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); #ifdef _WIN32 void LogQnnProfileEventAsTraceLogging( diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 9ca1b86de9e40..2af0971d74f11 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -185,12 +185,12 @@ Status QnnModel::SetupQnnInputOutput(const logging::Logger& logger) { return Status::OK(); } -static Status BindQnnTensorMemoryToOrtValue(const logging::Logger& logger, - QnnBackendManager& qnn_backend_manager, - const OrtMemoryInfo& ort_value_memory_info, - void* ort_value_data, uint32_t ort_value_data_size, - Qnn_ContextHandle_t qnn_context, - Qnn_Tensor_t& qnn_tensor) { +static Status BindQnnTensorMemoryToOrtValueMemory(const logging::Logger& logger, + QnnBackendManager& qnn_backend_manager, + const OrtMemoryInfo& ort_value_memory_info, + void* ort_value_data, uint32_t ort_value_data_size, + Qnn_ContextHandle_t qnn_context, + Qnn_Tensor_t& qnn_tensor) { // either set qnn_tensor memHandle or clientBuf const bool uses_shared_memory = ort_value_memory_info == HtpSharedMemoryAllocator::AssociatedMemoryInfo(); @@ -242,7 +242,7 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); - ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValueMemory( logger, *qnn_backend_manager_, *static_cast(ort_input_tensor.GetTensorMemoryInfo()), @@ -268,11 +268,11 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); - ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValue( + ORT_RETURN_IF_ERROR(BindQnnTensorMemoryToOrtValueMemory( logger, *qnn_backend_manager_, *static_cast(ort_output_tensor.GetTensorMemoryInfo()), - const_cast(ort_output_tensor.GetTensorRawData()), qnn_output_info.tensor_byte_size, + ort_output_tensor.GetTensorMutableRawData(), qnn_output_info.tensor_byte_size, graph_info_->GraphContext(), qnn_outputs.back())); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 444dbafe2d4cc..2052c492014d6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -578,7 +578,7 @@ Status Quantize(const double double_value, return Status::OK(); } -const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ErrorHandle_t qnn_error_handle) { +std::string_view GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, Qnn_ErrorHandle_t qnn_error_handle) { // From QNN SDK: The memory is statically owned and should not be freed by the caller. const char* error_msg = nullptr; if (qnn_interface.errorGetMessage(qnn_error_handle, &error_msg) == QNN_SUCCESS) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 08c20177821e6..28c5b91b7f1cf 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -111,8 +112,8 @@ Status Quantize(const double double_value, int& quant_value); // Gets error message associated with QNN error handle value. -const char* GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, - Qnn_ErrorHandle_t qnn_error_handle); +std::string_view GetQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, + Qnn_ErrorHandle_t qnn_error_handle); // Gets verbose error message associated with QNN error handle value. std::string GetVerboseQnnErrorMessage(const QNN_INTERFACE_VER_TYPE& qnn_interface, From 7a88c3f4bc3cfa2a004073ff7897b1779ef54425 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:21:53 -0800 Subject: [PATCH 46/53] rework context handle ownership --- .../qnn/builder/qnn_backend_manager.cc | 73 +++++++++++-------- .../qnn/builder/qnn_backend_manager.h | 30 ++++++-- 2 files changed, 67 insertions(+), 36 deletions(-) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 0455dd86d1033..a996a8e80f0eb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -558,7 +558,7 @@ Status QnnBackendManager::CreateContext() { ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result)); - ORT_RETURN_IF_ERROR(AddQnnContext(context)); + ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); context_created_ = true; return Status::OK(); @@ -569,17 +569,9 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } - // release context mem handles - context_mem_handles_.clear(); - - bool failed = false; - for (auto context : contexts_) { - Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr); - if (QNN_CONTEXT_NO_ERROR != result) { - failed = true; - } - } - ORT_RETURN_IF(failed, "Failed to release context."); + // release QNN context handles + contexts_.clear(); + context_map_.clear(); context_created_ = false; return Status::OK(); @@ -780,7 +772,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t &context, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); - ORT_RETURN_IF_ERROR(AddQnnContext(context)); + ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); if (1 == graph_count) { // in case the EPContext node is generated from script // the graph name from the context binary may not match the EPContext node name @@ -1700,49 +1692,68 @@ void* QnnBackendManager::LibFunction(void* handle, const char* symbol, std::stri #endif } -Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) { +Status QnnBackendManager::AddQnnContextHandle(Qnn_ContextHandle_t raw_context_handle) { ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set."); - auto mem_handle_manager = std::make_shared(GetQnnInterface(), context, *logger_); - const bool inserted = context_mem_handles_.try_emplace(context, std::move(mem_handle_manager)).second; - ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", context); + auto free_context_handle = [this, &logger = *logger_](Qnn_ContextHandle_t raw_context_handle) { + const auto free_result = qnn_interface_.contextFree(raw_context_handle, nullptr); + if (free_result != QNN_CONTEXT_NO_ERROR) { + LOGS(logger, ERROR) << "qnn_interface.contextFree() failed: " + << utils::GetVerboseQnnErrorMessage(qnn_interface_, free_result); + } + }; + + // take ownership of `raw_context_handle` + auto context_handle = UniqueQnnContextHandle(raw_context_handle, free_context_handle); + auto mem_handle_manager = std::make_unique(GetQnnInterface(), raw_context_handle, + *logger_); - contexts_.push_back(context); + auto context_handle_record = std::make_shared(); + context_handle_record->context_handle = std::move(context_handle); + context_handle_record->mem_handles = std::move(mem_handle_manager); + + const bool inserted = context_map_.try_emplace(raw_context_handle, std::move(context_handle_record)).second; + ORT_RETURN_IF_NOT(inserted, "QNN context was already added: ", raw_context_handle); + + contexts_.push_back(raw_context_handle); return Status::OK(); } -Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address, +Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context_handle, + void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t& mem_handle) { - const auto context_mem_handles_it = context_mem_handles_.find(context); - ORT_RETURN_IF_NOT(context_mem_handles_it != context_mem_handles_.end(), "QNN context not found: ", context); + const auto context_handle_record_it = context_map_.find(context_handle); + ORT_RETURN_IF_NOT(context_handle_record_it != context_map_.end(), "QNN context not found: ", context_handle); + + auto& context_handle_record = context_handle_record_it->second; + auto& context_mem_handle_manager = context_handle_record->mem_handles; - auto& context_mem_handle_manager = context_mem_handles_it->second; bool did_register{}; ORT_RETURN_IF_ERROR(context_mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor, mem_handle, did_register)); if (did_register) { - HtpSharedMemoryAllocator::AllocationCleanUpFn allocation_clean_up = + HtpSharedMemoryAllocator::AllocationCleanUpFn unregister_mem_handle = [&logger = *logger_, weak_backend_manager = weak_from_this(), - weak_context_mem_handle_manager = std::weak_ptr{context_mem_handle_manager}]( + weak_context_handle_record = std::weak_ptr{context_handle_record}]( void* shared_memory_address) { - // Get QnnBackendManager shared_ptr to ensure that: - // - QNN interface is still valid. - // - QNN context handle is still valid. This should be true as long as QNN contexts are not freed from - // anywhere other than the destructor. + // Lock QnnBackendManager shared_ptr to ensure that QNN interface is still valid. auto backend_manager = weak_backend_manager.lock(); if (!backend_manager) { return; } - auto context_mem_handle_manager = weak_context_mem_handle_manager.lock(); - if (!context_mem_handle_manager) { + // Lock QnnContextHandleRecord shared_ptr to ensure that QNN context handle is still valid. + auto context_handle_record = weak_context_handle_record.lock(); + if (!context_handle_record) { return; } + auto& context_mem_handle_manager = context_handle_record->mem_handles; + auto unregister_status = context_mem_handle_manager->Unregister(shared_memory_address); if (!unregister_status.IsOK()) { LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: " @@ -1751,7 +1762,7 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont }; ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address, - std::move(allocation_clean_up))); + std::move(unregister_mem_handle))); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 1dc8b35d1e6bb..685e03f17cdd3 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -57,7 +57,8 @@ class QnnBackendManager : public std::enable_shared_from_this return std::make_shared(config, PrivateConstructorTag{}); } - // Note: creation should be done via Create() + // Note: Creation should be done via Create(). This constructor is public so that it can be called from + // std::make_shared(). QnnBackendManager(const QnnBackendManagerConfig& config, PrivateConstructorTag) : backend_path_(config.backend_path), profiling_level_etw_(config.profiling_level_etw), @@ -240,7 +241,20 @@ class QnnBackendManager : public std::enable_shared_from_this const char* eventIdentifier); #endif - Status AddQnnContext(Qnn_ContextHandle_t context); + // Adds a new QNN context. + // Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance. + Status AddQnnContextHandle(Qnn_ContextHandle_t context_handle); + + private: + // assume Qnn_ContextHandle_t is a pointer and able to be wrapped with std::unique_ptr + static_assert(std::is_pointer_v); + using UniqueQnnContextHandle = + std::unique_ptr, std::function>; + + struct QnnContextHandleRecord { + UniqueQnnContextHandle context_handle; + std::unique_ptr mem_handles; + }; private: const std::string backend_path_; @@ -254,10 +268,16 @@ class QnnBackendManager : public std::enable_shared_from_this QnnBackend_Config_t** backend_config_ = nullptr; Qnn_LogHandle_t log_handle_ = nullptr; Qnn_DeviceHandle_t device_handle_ = nullptr; - std::vector contexts_; - // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a + + // Map of Qnn_ContextHandle_t to QnnContextHandleRecord. + // The QnnContextHandleRecord has ownership of the Qnn_ContextHandle_t. + // Note: Using shared_ptr so that we can refer to it with a weak_ptr from a // HtpSharedMemoryAllocator allocation cleanup callback. - std::unordered_map> context_mem_handles_; + std::unordered_map> context_map_; + + // Vector of Qnn_ContextHandle_t. The context handles are owned by context_map_. + std::vector contexts_; + ProfilingLevel profiling_level_etw_; ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; From f3730353609bdc133a57f8b14f608097bc1800c3 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 16:27:39 -0800 Subject: [PATCH 47/53] add / update tests --- .../core/providers/qnn/rpcmem_library.cc | 6 +- .../test/providers/qnn/qnn_basic_test.cc | 21 ++- onnxruntime/test/shared_lib/test_inference.cc | 162 ++++++++++++++++++ 3 files changed, 182 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 77a340ddfcea1..59e6cff925668 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -35,7 +35,11 @@ DynamicLibraryHandle LoadDynamicLibrary(const PathString& path, bool global_symb const auto& env = Env::Default(); void* library_handle = nullptr; - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(path, global_symbols, &library_handle)); + + const auto load_status = env.LoadDynamicLibrary(path, global_symbols, &library_handle); + if (!load_status.IsOK()) { + ORT_THROW("Failed to load ", ToUTF8String(path), ": ", load_status.ErrorMessage()); + } return DynamicLibraryHandle{library_handle, unload_library}; } diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index ed21ebbccc923..8dfae0fd5b0a4 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1100,11 +1100,6 @@ TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { } TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { -#if !defined(__ANDROID__) && !defined(_WIN32) - // TODO there's probably a better way to check that we are on a Qualcomm device - GTEST_SKIP() << "Test is only supported on a Qualcomm device."; -#endif - ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -1113,9 +1108,23 @@ TEST_F(QnnHTPBackendTests, UseHtpSharedMemoryAllocatorForInputs) { #endif provider_options["enable_htp_shared_memory_allocator"] = "1"; + std::unique_ptr qnn_ep; + try { + qnn_ep = QnnExecutionProviderWithOptions(provider_options); + } catch (const OnnxRuntimeException& e) { + // handle particular exception that indicates that the libcdsprpc.so / dll can't be loaded +#if defined(_WIN32) + constexpr const char* expected_error_message = "Failed to load libcdsprpc.dll"; +#else + constexpr const char* expected_error_message = "Failed to load libcdsprpc.so"; +#endif + ASSERT_THAT(e.what(), testing::HasSubstr(expected_error_message)); + GTEST_SKIP() << "HTP shared memory allocator is unavailable."; + } + AllocatorPtr htp_shared_memory_allocator{}; { - auto allocators = QnnExecutionProviderWithOptions(provider_options)->CreatePreferredAllocators(); + auto allocators = qnn_ep->CreatePreferredAllocators(); ASSERT_FALSE(allocators.empty()); auto& allocator = allocators[0]; ASSERT_EQ(allocator->Info(), qnn::HtpSharedMemoryAllocator::AssociatedMemoryInfo()); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e8c8c8db8d08f..ead2eadde8e80 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1934,6 +1934,47 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { } #endif +#if defined(USE_QNN) + +// Returns true if QNN EP was created and QNN HTP shared memory allocator is available, false otherwise. +static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model_path, Ort::Session& session) { +#if defined(_WIN32) + constexpr const char* backend_path = "QnnCpu.dll"; +#else + constexpr const char* backend_path = "libQnnCpu.so"; +#endif + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider("QNN", + {{"enable_htp_shared_memory_allocator", "1"}, + {"backend_path", backend_path}}); + + try { + session = Ort::Session{*ort_env, model_path, session_options}; + return true; + } catch (const Ort::Exception& e) { + // handle particular exception that indicates that the libcdsprpc.so / dll can't be loaded + std::string_view error_message = e.what(); + +#if defined(_WIN32) + std::string_view expected_error_message = "Failed to load libcdsprpc.dll"; +#else + std::string_view expected_error_message = "Failed to load libcdsprpc.so"; + #endif + + if (e.GetOrtErrorCode() == ORT_FAIL && + error_message.find(expected_error_message) != std::string_view::npos) { + session = Ort::Session{nullptr}; + return false; + } + + // propagate other exceptions + throw; + } +} + +#endif // defined(USE_QNN) + TEST(CApiTest, get_allocator_cpu) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -2001,6 +2042,32 @@ TEST(CApiTest, get_allocator_rocm) { } #endif +#if defined(USE_QNN) + +TEST(CApiTest, get_allocator_qnn_htp_shared) { + Ort::Session session{nullptr}; + + if (!CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(NAMED_AND_ANON_DIM_PARAM_URI, session)) { + GTEST_SKIP() << "HTP shared memory allocator is unavailable."; + } + + Ort::MemoryInfo info_qnn_htp_shared("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); + Ort::Allocator qnn_htp_shared_allocator(session, info_qnn_htp_shared); + + auto allocator_info = qnn_htp_shared_allocator.GetInfo(); + ASSERT_EQ(allocator_info, info_qnn_htp_shared); + + void* p = qnn_htp_shared_allocator.Alloc(1024); + ASSERT_NE(p, nullptr); + qnn_htp_shared_allocator.Free(p); + + auto mem_allocation = qnn_htp_shared_allocator.GetAllocation(1024); + ASSERT_NE(mem_allocation.get(), nullptr); + ASSERT_EQ(mem_allocation.size(), size_t{1024}); +} + +#endif // defined(USE_QNN) + TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -2178,6 +2245,101 @@ TEST(CApiTest, io_binding_cuda) { } #endif +#if defined(USE_QNN) + +TEST(CApiTest, io_binding_qnn_htp_shared) { + Ort::Session session{nullptr}; + if (!CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(MODEL_URI, session)) { + GTEST_SKIP() << "HTP shared memory allocator is unavailable."; + } + + Ort::MemoryInfo info_qnn_htp_shared("QnnHtpShared", OrtAllocatorType::OrtDeviceAllocator, 0, OrtMemTypeDefault); + + Ort::Allocator qnn_htp_shared_allocator(session, info_qnn_htp_shared); + auto allocator_info = qnn_htp_shared_allocator.GetInfo(); + ASSERT_EQ(info_qnn_htp_shared, allocator_info); + + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto input_data = qnn_htp_shared_allocator.GetAllocation(x_values.size() * sizeof(float)); + ASSERT_NE(input_data.get(), nullptr); + memcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size()); + + // Create an OrtValue tensor backed by data on QNN HTP shared memory + Ort::Value bound_x = Ort::Value::CreateTensor(info_qnn_htp_shared, reinterpret_cast(input_data.get()), x_values.size(), + x_shape.data(), x_shape.size()); + + const std::array expected_y_shape = {3, 2}; + const std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + auto output_data = qnn_htp_shared_allocator.GetAllocation(expected_y.size() * sizeof(float)); + ASSERT_NE(output_data.get(), nullptr); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_y = Ort::Value::CreateTensor(info_qnn_htp_shared, reinterpret_cast(output_data.get()), + expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); + + Ort::IoBinding binding(session); + binding.BindInput("X", bound_x); + binding.BindOutput("Y", bound_y); + + session.Run(Ort::RunOptions(), binding); + + // Check the values against the bound raw memory + { + gsl::span y{reinterpret_cast(output_data.get()), expected_y.size()}; + ASSERT_TRUE(std::equal(std::begin(y), std::end(y), std::begin(expected_y))); + } + + // Now compare values via GetOutputValues + { + std::vector output_values = binding.GetOutputValues(); + ASSERT_EQ(output_values.size(), 1U); + const Ort::Value& Y_value = output_values[0]; + ASSERT_TRUE(Y_value.IsTensor()); + Ort::TensorTypeAndShapeInfo type_info = Y_value.GetTensorTypeAndShapeInfo(); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, type_info.GetElementType()); + auto count = type_info.GetElementCount(); + ASSERT_EQ(expected_y.size(), count); + + gsl::span y{Y_value.GetTensorData(), count}; + ASSERT_TRUE(std::equal(std::begin(y), std::end(y), std::begin(expected_y))); + } + + { + std::vector output_names = binding.GetOutputNames(); + ASSERT_EQ(1U, output_names.size()); + ASSERT_EQ(output_names[0].compare("Y"), 0); + } + + // Now replace binding of Y with an on device binding instead of pre-allocated memory. + // This is when we can not allocate an OrtValue due to unknown dimensions + { + binding.BindOutput("Y", info_qnn_htp_shared); + session.Run(Ort::RunOptions(), binding); + } + + // Check the output value allocated based on the device binding. + { + std::vector output_values = binding.GetOutputValues(); + ASSERT_EQ(output_values.size(), 1U); + const Ort::Value& Y_value = output_values[0]; + ASSERT_TRUE(Y_value.IsTensor()); + Ort::TensorTypeAndShapeInfo type_info = Y_value.GetTensorTypeAndShapeInfo(); + ASSERT_EQ(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, type_info.GetElementType()); + auto count = type_info.GetElementCount(); + ASSERT_EQ(expected_y.size(), count); + + gsl::span y{Y_value.GetTensorData(), count}; + ASSERT_TRUE(std::equal(std::begin(y), std::end(y), std::begin(expected_y))); + } + + // Clean up + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); +} + +#endif // defined(USE_QNN) + #if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) || defined(USE_DML) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); From e86ff2eb1f1bde60bac5e1b1463bff6e222e78b6 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:34:04 -0800 Subject: [PATCH 48/53] add check for qnn tensor dynamic shape --- onnxruntime/core/providers/qnn/builder/qnn_def.cc | 14 ++++++++++++++ onnxruntime/core/providers/qnn/builder/qnn_def.h | 1 + .../core/providers/qnn/builder/qnn_model.cc | 3 +++ .../core/providers/qnn/builder/qnn_utils.cc | 12 ++++++++++++ onnxruntime/core/providers/qnn/builder/qnn_utils.h | 2 ++ 5 files changed, 32 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index 5af7f024716f1..1a58d0d417a0b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -394,6 +394,20 @@ const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor) ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); } +uint8_t* GetQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& qnn_tensor) { + if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) { + return nullptr; // not present in v1 + } + +#ifdef QNN_TENSOR_V2_INIT + if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) { + return qnn_tensor.v2.isDynamicDimensions; + } +#endif // QNN_TENSOR_V2_INIT + + ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version); +} + Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1, float& scale_diff, int32_t& offset_diff) { scale_diff = 0.0f; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index b3b6b392d7857..f0619eb218245 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -126,6 +126,7 @@ uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor); const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor); Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor); const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor); +uint8_t* GetQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& qnn_tensor); /** * Compares two sets of quantization parameters. Sets the parameters `scale_diff` and `offset_diff` diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 2af0971d74f11..5f8b7f35eea8b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -322,6 +322,9 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, qnn_tensor_infos.resize(tensor_count); for (auto& tensor_wrapper : tensor_wrappers) { + ORT_RETURN_IF(utils::QnnTensorHasDynamicShape(tensor_wrapper.GetQnnTensor()), + "QNN tensor (", tensor_wrapper.GetName(), ") has dynamic shape. This is not supported yet."); + const size_t length = utils::GetQnnTensorDataSizeInBytes(tensor_wrapper.GetTensorDims(), tensor_wrapper.GetTensorDataType()); const auto& tensor_name = tensor_wrapper.GetName(); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index 2052c492014d6..08d3120260cea 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -3,6 +3,7 @@ #include "core/providers/qnn/builder/qnn_utils.h" +#include #include #include #include @@ -71,6 +72,17 @@ size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{}); } +bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor) { + const uint8_t* is_dynamic_dimensions = GetQnnTensorIsDynamicDimensions(tensor); + if (is_dynamic_dimensions == nullptr) { + return false; + } + + const auto rank = GetQnnTensorRank(tensor); + return std::any_of(is_dynamic_dimensions, is_dynamic_dimensions + rank, + [](uint8_t is_dynamic_dimension) { return is_dynamic_dimension != 0; }); +} + std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) { switch (scalar.dataType) { case QNN_DATATYPE_INT_8: diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 28c5b91b7f1cf..950f349c5006f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -29,6 +29,8 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type); size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType_t element_data_type); +bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor); + // TODO: make these work with Wrappers? std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param); std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor); From 6fa33f0baa5d17eed41547efcc7fa2f46e73e4bc Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:19:20 -0800 Subject: [PATCH 49/53] Add comment about multi-threading considerations --- .../core/providers/qnn/builder/qnn_backend_manager.cc | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index a996a8e80f0eb..e91fda32510dd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1724,6 +1724,15 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont void* shared_memory_address, const Qnn_Tensor_t& qnn_tensor, Qnn_MemHandle_t& mem_handle) { + // Multi-threading situations to consider: + // 1) Shared memory allocation is being freed in another thread while we are processing `shared_memory_address`. + // This implies incorrect usage as the memory is being freed while it is still in use. Let's assume this won't + // happen. + // 2) The shared memory allocation clean up function is being run from another thread while the + // QnnContextHandleRecord or QnnBackendManager objects are being destroyed. + // Usage of weak_ptrs from the clean up function should ensure that those objects are only accessed while they are + // in scope. + const auto context_handle_record_it = context_map_.find(context_handle); ORT_RETURN_IF_NOT(context_handle_record_it != context_map_.end(), "QNN context not found: ", context_handle); From 4101cca61e60cd3a18633ac3703e51e83a235d82 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:25:50 -0800 Subject: [PATCH 50/53] fix test comment --- onnxruntime/test/shared_lib/test_inference.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index ead2eadde8e80..0dc19466c36d0 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2274,7 +2274,7 @@ TEST(CApiTest, io_binding_qnn_htp_shared) { auto output_data = qnn_htp_shared_allocator.GetAllocation(expected_y.size() * sizeof(float)); ASSERT_NE(output_data.get(), nullptr); - // Create an OrtValue tensor backed by data on CUDA memory + // Create an OrtValue tensor backed by data on QNN HTP shared memory Ort::Value bound_y = Ort::Value::CreateTensor(info_qnn_htp_shared, reinterpret_cast(output_data.get()), expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); From 14af7ad697d807e42a6a66be08c63bd81c36b6aa Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:31:59 -0800 Subject: [PATCH 51/53] fix formatting --- onnxruntime/test/shared_lib/test_inference.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 0dc19466c36d0..e427fc533998d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1960,7 +1960,7 @@ static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model std::string_view expected_error_message = "Failed to load libcdsprpc.dll"; #else std::string_view expected_error_message = "Failed to load libcdsprpc.so"; - #endif +#endif if (e.GetOrtErrorCode() == ORT_FAIL && error_message.find(expected_error_message) != std::string_view::npos) { From 2f5c93c967cce4eccca8a095d3cae45ccca90dc2 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:20:58 -0800 Subject: [PATCH 52/53] add ifdef to use htp backend if on arm64 or linux. --- onnxruntime/test/shared_lib/test_inference.cc | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index e427fc533998d..dab73d3824d3b 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -1938,10 +1938,16 @@ TEST(ReducedOpsBuildTest, test_excluded_ops) { // Returns true if QNN EP was created and QNN HTP shared memory allocator is available, false otherwise. static bool CreateSessionWithQnnEpAndQnnHtpSharedMemoryAllocator(PATH_TYPE model_path, Ort::Session& session) { +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + constexpr bool use_htp_backend = true; +#else + constexpr bool use_htp_backend = false; +#endif + #if defined(_WIN32) - constexpr const char* backend_path = "QnnCpu.dll"; + const char* backend_path = use_htp_backend ? "QnnHtp.dll" : "QnnCpu.dll"; #else - constexpr const char* backend_path = "libQnnCpu.so"; + const char* backend_path = use_htp_backend ? "libQnnHtp.so" : "libQnnCpu.so"; #endif Ort::SessionOptions session_options; From 7ca45523a3afc9adf1d19b6b0e514b95f09cf472 Mon Sep 17 00:00:00 2001 From: edgchen1 <18449977+edgchen1@users.noreply.github.com> Date: Mon, 13 Jan 2025 11:49:10 -0800 Subject: [PATCH 53/53] fix typo --- include/onnxruntime/core/session/onnxruntime_cxx_api.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0a57999246b06..123ef98901003 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2133,7 +2133,7 @@ struct KernelContext { // If input is optional and is not present, the method returns an empty ConstValue // which can be compared to nullptr. ConstValue GetInput(size_t index) const; - // If outout is optional and is not present, the method returns an empty UnownedValue + // If output is optional and is not present, the method returns an empty UnownedValue // which can be compared to nullptr. UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const; UnownedValue GetOutput(size_t index, const std::vector& dims) const;