Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QNN EP HTP shared memory allocator #23136

Merged
merged 61 commits into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
110a3bc
save work
edgchen1 Nov 5, 2024
0ba3a2f
save work
edgchen1 Nov 9, 2024
8436b14
add logging for setting QNN tensor memory, update comment
edgchen1 Nov 11, 2024
c9826f4
add option to enable HTP shared memory allocator to onnxruntime_perf_…
edgchen1 Nov 11, 2024
c07c35e
hack - try to cache mem handles in QnnModel
edgchen1 Nov 12, 2024
60dc837
Remove duplicate include.
edgchen1 Nov 13, 2024
24e072f
hack, continued - move cache out to SharedContext
edgchen1 Nov 14, 2024
e66cbef
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Nov 14, 2024
8c515da
move mem handle registration to allocator
edgchen1 Nov 15, 2024
18e2780
hook up some test code
edgchen1 Nov 15, 2024
09ddce5
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Nov 19, 2024
a65bb71
rename to RpcMemAllocator to HtpSharedMemoryAllocator
edgchen1 Nov 27, 2024
bfb135e
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 2, 2024
f179a0d
remove onnx protobuf dependency from allocator.h, add shared provider…
edgchen1 Dec 3, 2024
7645ef4
remove unused CPUAllocator::TensorAlloc declaration
edgchen1 Dec 5, 2024
1043732
Check for nullptr when trying to free
baijumeswani Dec 5, 2024
022f4bc
move mem handle management to QNN backend manager
edgchen1 Dec 10, 2024
c527dee
remove IAllocator::TensorAlloc()
edgchen1 Dec 10, 2024
e4f72b3
document IAllocator::Free
edgchen1 Dec 10, 2024
39ff901
remove IAllocator__TensorAlloc
edgchen1 Dec 10, 2024
1bed5a4
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 10, 2024
d70db84
fix android build warning
edgchen1 Dec 10, 2024
45ef883
remove shared mem handles from shared context
edgchen1 Dec 11, 2024
d2e7b3c
remove allocation clean up callback removal, use weak_ptrs in allocat…
edgchen1 Dec 16, 2024
c892c18
some clean up
edgchen1 Dec 17, 2024
b295eef
more clean up
edgchen1 Dec 17, 2024
13f5e30
add helper to get qnn error message
edgchen1 Dec 17, 2024
d5eace1
use make_shared for QnnBackendManager
edgchen1 Dec 17, 2024
bacbcdc
add test to qnn_basic_test.cc, document allocator parameter.
edgchen1 Dec 17, 2024
30cd9ed
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Dec 17, 2024
b29ab61
rename variables
edgchen1 Dec 18, 2024
67a54b8
revert changes to onnxruntime/test/providers/qnn/max_min_op_test.cc
edgchen1 Dec 18, 2024
c0569e2
fix formatting
edgchen1 Dec 19, 2024
dd45c84
skip test if not android and not windows
edgchen1 Dec 19, 2024
959d8df
update comment
edgchen1 Dec 19, 2024
ab48516
remove QnnBackendManager::ReleaseQnnContextMemHandles declaration, up…
edgchen1 Dec 19, 2024
4a3f6c3
add onnxruntime_c_api.h include to ortmemoryinfo.h
edgchen1 Jan 6, 2025
65ce4b1
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Jan 6, 2025
ff12541
rename GetQnnTensorDataSize to GetQnnTensorDataSizeInBytes
edgchen1 Jan 6, 2025
5e6e103
add QnnBackendManager::Create function to ensure shared_ptr usage
edgchen1 Jan 6, 2025
78e86cc
make some QnnBackendManager member functions private, update comment
edgchen1 Jan 6, 2025
e665a2b
document GetOrRegister functions
edgchen1 Jan 7, 2025
425023b
add enable_htp_shared_memory_allocator to available_keys
edgchen1 Jan 8, 2025
781a4a0
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Jan 9, 2025
4d29208
make DlError return const char*
edgchen1 Jan 9, 2025
568c9a7
Use ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE for SharedContext
edgchen1 Jan 9, 2025
8b95535
use safeint instead of manually checking against int max
edgchen1 Jan 9, 2025
515999c
add/update doc for enable_htp_shared_memory_allocator option
edgchen1 Jan 9, 2025
6986839
formatting
edgchen1 Jan 9, 2025
00b286b
add some comments about HtpSharedmemoryAllocator impl
edgchen1 Jan 9, 2025
88dec64
initialize with QNN_MEM_DESRIPTOR_INIT
edgchen1 Jan 10, 2025
4ca3ea7
address comments
edgchen1 Jan 10, 2025
7a88c3f
rework context handle ownership
edgchen1 Jan 10, 2025
f373035
add / update tests
edgchen1 Jan 11, 2025
e86ff2e
add check for qnn tensor dynamic shape
edgchen1 Jan 11, 2025
6fa33f0
Add comment about multi-threading considerations
edgchen1 Jan 11, 2025
4101cca
fix test comment
edgchen1 Jan 11, 2025
14af7ad
fix formatting
edgchen1 Jan 11, 2025
2f5c93c
add ifdef to use htp backend if on arm64 or linux.
edgchen1 Jan 13, 2025
b868a9f
Merge remote-tracking branch 'origin/main' into edgchen1/qnn_ep_rpcmem
edgchen1 Jan 13, 2025
7ca4552
fix typo
edgchen1 Jan 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/onnxruntime/core/framework/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ constexpr const char* OpenVINO_CPU = "OpenVINO_CPU";
constexpr const char* OpenVINO_GPU = "OpenVINO_GPU";
constexpr const char* OpenVINO_RT = "OpenVINO_RT";
constexpr const char* OpenVINO_RT_NPU = "OpenVINO_RT_NPU";
constexpr const char* QNN_HTP_SHARED = "QnnHtpShared";
constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer";
constexpr const char* WEBNN_TENSOR = "WebNN_Tensor";

Expand Down Expand Up @@ -81,6 +82,10 @@ class IAllocator {
*/
virtual void* Alloc(size_t size) = 0;

/**
* Free memory at p.
* If p is nullptr, do nothing.
*/
virtual void Free(void* p) = 0;

// Reserve() is an interface exposed for an implementation of IAllocator
Expand Down
1 change: 1 addition & 0 deletions include/onnxruntime/core/framework/ortdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ struct OrtDevice {
static const MemoryType CUDA_PINNED = 1;
static const MemoryType HIP_PINNED = 2;
static const MemoryType CANN_PINNED = 3;
static const MemoryType QNN_HTP_SHARED = 4;
};

constexpr OrtDevice(DeviceType device_type_, MemoryType memory_type_, DeviceId device_id_)
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/framework/ortmemoryinfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <string_view>

#include "core/common/hash_combine.h"
#include "core/framework/ortdevice.h"
#include "core/session/onnxruntime_c_api.h" // for OrtMemType, OrtAllocatorType
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved

struct OrtMemoryInfo {
OrtMemoryInfo() = default; // to allow default construction of Tensor
Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2130,10 +2130,10 @@ struct KernelContext {
explicit KernelContext(OrtKernelContext* context);
size_t GetInputCount() const;
size_t GetOutputCount() const;
// If input is optional and is not present, the method returns en empty ConstValue
// If input is optional and is not present, the method returns an empty ConstValue
// which can be compared to nullptr.
ConstValue GetInput(size_t index) const;
// If outout is optional and is not present, the method returns en empty UnownedValue
// If output is optional and is not present, the method returns an empty UnownedValue
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
// which can be compared to nullptr.
UnownedValue GetOutput(size_t index, const int64_t* dim_values, size_t dim_count) const;
UnownedValue GetOutput(size_t index, const std::vector<int64_t>& dims) const;
Expand Down
11 changes: 9 additions & 2 deletions onnxruntime/core/framework/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,18 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA
mem_type1);
} else if (strcmp(name1, onnxruntime::CUDA_PINNED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::CUDA_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
onnxruntime::CUDA_PINNED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::CUDA_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else if (strcmp(name1, onnxruntime::HIP_PINNED) == 0) {
*out = new OrtMemoryInfo(
onnxruntime::HIP_PINNED, type, OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
onnxruntime::HIP_PINNED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::HIP_PINNED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else if (strcmp(name1, onnxruntime::QNN_HTP_SHARED) == 0) {
*out = new OrtMemoryInfo(
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
onnxruntime::QNN_HTP_SHARED, type,
OrtDevice(OrtDevice::CPU, OrtDevice::MemType::QNN_HTP_SHARED, static_cast<OrtDevice::DeviceId>(id1)),
id1, mem_type1);
} else {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Specified device is not supported.");
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ SessionState::SessionState(Graph& graph,
for (auto& ep : execution_providers_) {
auto allocators = ep->CreatePreferredAllocators();
for (auto& alloc : allocators) {
allocators_->insert({alloc->Info().device, alloc}); // DONT overwrite existing key
allocators_->insert({alloc->Info().device, alloc}); // DON'T overwrite existing key
}
}
}
Expand Down
86 changes: 76 additions & 10 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,22 @@
#include <fstream>
#include <string>
#include "QnnOpDef.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "CPU/QnnCpuCommon.h"
// TODO: does not exist for Windows yet
// #include "GPU/QnnGpuCommon.h"
#include "DSP/QnnDspCommon.h"
#include "HTP/QnnHtpCommon.h"
#include "HTP/QnnHtpContext.h"
#include "HTP/QnnHtpPerfInfrastructure.h"
#include "HTP/QnnHtpSystemContext.h"
#include "Saver/QnnSaver.h"
#include <gsl/gsl>
#include "core/framework/endian_utils.h"
#include "core/common/logging/capture.h"
#include "core/providers/qnn/qnn_allocator.h"
#include "core/providers/qnn/builder/onnx_ctx_model_helper.h"
#include "core/providers/qnn/builder/qnn_configs_helper.h"
#include "core/providers/qnn/builder/qnn_utils.h"

#ifdef _WIN32
#include <winmeta.h>
Expand All @@ -46,6 +48,14 @@
return qnn_interface->systemApiVersion;
}

// Returns a description of the most recent dynamic-loading error.
//
// - On Windows (`_WIN32`) no dlerror() equivalent is used here, so an empty string is returned.
// - On every other platform the result of ::dlerror() is forwarded unchanged. Per POSIX this may be a
//   null pointer when no error has occurred since the last call, so callers should not assume non-null.
//
// The return type is `const char*`: a string literal cannot be converted to a mutable `char*` in
// C++11 and later, and callers must not modify the returned message in either branch.
static const char* DlError() {
#ifdef _WIN32
  return "";
#else
  return ::dlerror();
#endif
}

template <typename F, class T>
Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path,
const char* interface_provider_name,
Expand Down Expand Up @@ -545,10 +555,11 @@
device_handle_,
context_configs,
&context);
contexts_.push_back(context);

ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result));

ORT_RETURN_IF_ERROR(AddQnnContext(context));

context_created_ = true;
return Status::OK();
}
Expand All @@ -558,6 +569,9 @@
return Status::OK();
}

// release context mem handles
context_mem_handles_.clear();

bool failed = false;
for (auto context : contexts_) {
Qnn_ErrorHandle_t result = qnn_interface_.contextFree(context, nullptr);
Expand Down Expand Up @@ -766,7 +780,7 @@
&context,
profile_backend_handle_);
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt);
contexts_.push_back(context);
ORT_RETURN_IF_ERROR(AddQnnContext(context));
if (1 == graph_count) {
// in case the EPContext node is generated from script
// the graph name from the context binary may not match the EPContext node name
Expand Down Expand Up @@ -1425,12 +1439,7 @@
}

const char* QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) {
// From QNN SDK: The memory is statically owned and should not be freed by the caller.
const char* error_msg = nullptr;
if (QNN_SUCCESS == qnn_interface_.errorGetMessage(error, &error_msg)) {
return error_msg;
}
return "Unknown";
return utils::GetQnnErrorMessage(qnn_interface_, error);
}

const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) {
Expand Down Expand Up @@ -1663,5 +1672,62 @@
#endif
}

// Starts tracking a QNN context handle in this backend manager.
//
// A QnnContextMemHandleManager is created for `context` (using the QNN interface and the current
// logger) and stored in `context_mem_handles_`; the handle itself is then appended to `contexts_`.
//
// Returns an error status if `logger_` has not been set or if `context` is already being tracked.
// In both failure cases `contexts_` is left unmodified.
Status QnnBackendManager::AddQnnContext(Qnn_ContextHandle_t context) {
  ORT_RETURN_IF(logger_ == nullptr, "logger_ should be set.");

  auto handle_manager = std::make_shared<QnnContextMemHandleManager>(GetQnnInterface(), context, *logger_);

  // try_emplace does not replace an existing entry, so `second` reports whether this context is new.
  const auto emplace_result = context_mem_handles_.try_emplace(context, std::move(handle_manager));
  ORT_RETURN_IF_NOT(emplace_result.second, "QNN context was already added: ", context);

  contexts_.push_back(context);
  return Status::OK();
}

// Gets the QNN mem handle for `shared_memory_address` within `context`, registering one if needed.
//
// The mem handle manager for `context` is looked up in `context_mem_handles_` and asked to
// get-or-register a handle for the given address and tensor; the result is written to `mem_handle`.
// If that call performed a new registration, a clean-up callback is attached to the shared memory
// allocation through HtpSharedMemoryAllocator::AddAllocationCleanUp so that the handle gets
// unregistered when the callback runs (presumably when the allocation is freed - confirm against
// HtpSharedMemoryAllocator).
//
// Returns an error status if `context` is unknown, or if get-or-register or attaching the callback fails.
Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address,
                                                        const Qnn_Tensor_t& qnn_tensor,
                                                        Qnn_MemHandle_t& mem_handle) {
  const auto manager_it = context_mem_handles_.find(context);
  ORT_RETURN_IF_NOT(manager_it != context_mem_handles_.end(), "QNN context not found: ", context);

  auto& mem_handle_manager = manager_it->second;
  bool newly_registered = false;
  ORT_RETURN_IF_ERROR(mem_handle_manager->GetOrRegister(shared_memory_address, qnn_tensor,
                                                        mem_handle, newly_registered));

  // An already-known address has its clean-up callback from the earlier registration.
  if (!newly_registered) {
    return Status::OK();
  }

  // Only weak references are captured so that the callback neither extends the lifetime of the backend
  // manager nor of the per-context mem handle manager.
  HtpSharedMemoryAllocator::AllocationCleanUpFn unregister_on_clean_up =
      [&logger = *logger_,
       weak_self = weak_from_this(),
       weak_manager = std::weak_ptr{mem_handle_manager}](void* address) {
        // Holding the QnnBackendManager alive for the duration of the call ensures that:
        // - the QNN interface is still valid.
        // - the QNN context handle is still valid. This should hold as long as QNN contexts are only
        //   freed from the destructor.
        const auto self = weak_self.lock();
        if (!self) {
          return;
        }

        const auto manager = weak_manager.lock();
        if (!manager) {
          return;
        }

        const auto status = manager->Unregister(address);
        if (!status.IsOK()) {
          LOGS(logger, ERROR) << "Failed to unregister shared memory mem handle for address: "
                              << address << ", error: " << status.ErrorMessage();
        }
      };

  ORT_RETURN_IF_ERROR(HtpSharedMemoryAllocator::AddAllocationCleanUp(shared_memory_address,
                                                                      std::move(unregister_on_clean_up)));

  return Status::OK();
}

} // namespace qnn
} // namespace onnxruntime
116 changes: 65 additions & 51 deletions onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,70 +24,55 @@
#include "core/common/status.h"
#include "core/common/logging/logging.h"
#include "core/common/path_string.h"
#include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h"
#include "core/providers/qnn/builder/qnn_def.h"

namespace onnxruntime {
namespace qnn {

class QnnModel;

class QnnBackendManager {
// Configuration values for QnnBackendManager creation.
//
// Each field is copied into the QnnBackendManager member of the same name (with a trailing underscore)
// by the QnnBackendManager constructor.
//
// Every non-string member carries a value-initializer so that a default-constructed config never holds
// indeterminate values. The struct is still an aggregate, so existing brace initialization that lists
// the fields in declaration order keeps working.
struct QnnBackendManagerConfig {
  std::string backend_path;
  ProfilingLevel profiling_level_etw{};
  ProfilingLevel profiling_level{};
  std::string profiling_file_path;
  ContextPriority context_priority{};
  std::string qnn_saver_path;
  uint32_t device_id{};
  QnnHtpDevice_Arch_t htp_arch{};
  uint32_t soc_model{};
  bool enable_htp_weight_sharing{};
};

class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager> {
private:
// private tag to pass to constructor to ensure that constructor cannot be directly called externally
struct PrivateConstructorTag {};

public:
QnnBackendManager(std::string&& backend_path,
ProfilingLevel profiling_level_etw,
ProfilingLevel profiling_level,
std::string&& profiling_file_path,
ContextPriority context_priority,
std::string&& qnn_saver_path,
uint32_t device_id,
QnnHtpDevice_Arch_t htp_arch,
uint32_t soc_model,
bool enable_htp_weight_sharing)
: backend_path_(backend_path),
profiling_level_etw_(profiling_level_etw),
profiling_level_(profiling_level),
profiling_file_path_(profiling_file_path),
context_priority_(context_priority),
qnn_saver_path_(qnn_saver_path),
device_id_(device_id),
htp_arch_(htp_arch),
soc_model_(soc_model),
enable_htp_weight_sharing_(enable_htp_weight_sharing) {
static std::shared_ptr<QnnBackendManager> Create(const QnnBackendManagerConfig& config) {
return std::make_shared<QnnBackendManager>(config, PrivateConstructorTag{});
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);

~QnnBackendManager();
char* DlError() {
#ifdef _WIN32
return "";
#else
return ::dlerror();
#endif
// Note: creation should be done via Create()
QnnBackendManager(const QnnBackendManagerConfig& config, PrivateConstructorTag)
: backend_path_(config.backend_path),
edgchen1 marked this conversation as resolved.
Show resolved Hide resolved
profiling_level_etw_(config.profiling_level_etw),
profiling_level_(config.profiling_level),
profiling_file_path_(config.profiling_file_path),
context_priority_(config.context_priority),
qnn_saver_path_(config.qnn_saver_path),
device_id_(config.device_id),
htp_arch_(config.htp_arch),
soc_model_(config.soc_model),
enable_htp_weight_sharing_(config.enable_htp_weight_sharing) {
}

Status LoadBackend();

Status InitializeBackend();

Status CreateDevice();

Status ReleaseDevice();

Status ShutdownBackend();

Status InitializeProfiling();

Status ReleaseProfilehandle();

Status CreateContext();

Status ReleaseContext();

Status ResetContext() {
ORT_RETURN_IF_ERROR(ReleaseContext());
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager);

return CreateContext();
}
~QnnBackendManager();

std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);

Expand Down Expand Up @@ -148,7 +133,31 @@ class QnnBackendManager {
uint64_t buffer_length,
uint64_t& max_spill_fill_buffer_size);

// Gets an existing QNN mem handle or registers a new one.
// `mem_handle` is set to the QNN mem handle.
Status GetOrRegisterContextMemHandle(Qnn_ContextHandle_t context, void* shared_memory_address,
const Qnn_Tensor_t& qnn_tensor,
Qnn_MemHandle_t& mem_handle);

private:
Status LoadBackend();

Status InitializeBackend();

Status CreateDevice();

Status ReleaseDevice();

Status ShutdownBackend();

Status InitializeProfiling();

Status ReleaseProfilehandle();

Status CreateContext();

Status ReleaseContext();

// Sets the ORT logger and creates a corresponding QNN logger with the same log level.
// NOTE: caller must lock the `logger_mutex_` before calling this function.
Status InitializeQnnLog(const logging::Logger& logger);
Expand Down Expand Up @@ -230,6 +239,8 @@ class QnnBackendManager {
const char* eventIdentifier);
#endif

Status AddQnnContext(Qnn_ContextHandle_t context);

private:
const std::string backend_path_;
std::mutex logger_mutex_;
Expand All @@ -243,6 +254,9 @@ class QnnBackendManager {
Qnn_LogHandle_t log_handle_ = nullptr;
Qnn_DeviceHandle_t device_handle_ = nullptr;
std::vector<Qnn_ContextHandle_t> contexts_;
// Note: Using shared_ptr<QnnContextMemHandleManager> so that we can refer to it with a weak_ptr from a
// HtpSharedMemoryAllocator allocation cleanup callback.
std::unordered_map<Qnn_ContextHandle_t, std::shared_ptr<QnnContextMemHandleManager>> context_mem_handles_;
ProfilingLevel profiling_level_etw_;
ProfilingLevel profiling_level_;
ProfilingLevel profiling_level_merge_;
Expand Down
Loading
Loading