diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index d468894080b3d..0dfb7928e3adc 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -94,6 +94,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), + "", main_context_node.Name(), qnn_models, max_spill_fill_size); @@ -127,6 +128,18 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible."); } + std::string context_binary_path_str = context_binary_path.string(); +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + if (qnn_backend_manager->FileMappingIsEnabled()) { + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(nullptr, + 0, + context_binary_path_str, + main_context_node.Name(), + qnn_models, + max_spill_fill_size); + } +#endif + size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); @@ -144,6 +157,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, cache_file.close(); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), + context_binary_path_str, main_context_node.Name(), qnn_models, max_spill_fill_size); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 164e4c3157f62..ac6f6db1d0a55 100644 --- 
a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -29,6 +29,10 @@ #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "core/providers/qnn/builder/qnn_utils.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED +#include "core/providers/qnn/builder/qnn_windows_file_mapper.h" +#endif + // Flag to determine if Backend should do node validation for each opNode added #define DO_GRAPH_NODE_VALIDATIONS 1 @@ -770,22 +774,54 @@ Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t return Status::OK(); } +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED +// callback required for allocating file mapping resources +static Qnn_ErrorHandle_t DmaDataProvider(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "DmaProvider: notify_param is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + auto pair = reinterpret_cast*>(notify_param); + + if (pair->first == nullptr) { + LOGS_DEFAULT(ERROR) << "DmaProvider: file mapper is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + return pair->first->MapDmaData(request, response, pair->second); +} + +// callback required for releasing file mapping resources +static Qnn_ErrorHandle_t DmaDataRelease(Qnn_ContextBinaryDmaDataMem_t data_mem, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "DmaRelease: notify_param is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + auto pair = reinterpret_cast*>(notify_param); + + if (pair->first == nullptr) { + LOGS_DEFAULT(ERROR) << "DmaRelease: file mapper is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + }; + + return pair->first->ReleaseDmaData(data_mem, pair->second); +} +#endif // QNN_FILE_MAPPED_WEIGHTS_ENABLED + // callback required to add context handles to class list // when using contextCreateFromBinaryListAsync() 
-void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, - Qnn_GraphHandle_t graph, - const char* graphName, - QnnContext_createFromBinaryAsyncNotifyType_t notifyType, - void* notifyParam, - Qnn_ErrorHandle_t status) { +static void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, + Qnn_GraphHandle_t /* graph */, + const char* /* graph_name */, + QnnContext_createFromBinaryAsyncNotifyType_t /* notify_type */, + void* notify_param, + Qnn_ErrorHandle_t /* status */) { auto qnn_backend_manager = SharedContext::GetInstance().GetSharedQnnBackendManager(); if (context) { - qnn_backend_manager->ProcessContextFromBinListAsync(context, notifyParam); - } - - if (nullptr == graphName || graph || notifyType || status) { - // Avoid compilation unused var warning error + qnn_backend_manager->ProcessContextFromBinListAsync(context, notify_param); } } @@ -809,6 +845,41 @@ void QnnBackendManager::ProcessContextFromBinListAsync(Qnn_ContextHandle_t conte } } +Status QnnBackendManager::ReadContextBinIfValid(const std::string& context_bin_filepath, + BufferInfo_t& buffer_info, + bool read_file_contents) { + std::error_code ec; + ORT_RETURN_IF(!std::filesystem::exists(context_bin_filepath, ec), "Context binary does not exist: ", context_bin_filepath); + ORT_RETURN_IF(ec, "Failed to read context binary: ", context_bin_filepath, + ", error: ", ec.message()); + + auto file_size = std::filesystem::file_size(context_bin_filepath, ec); + ORT_RETURN_IF(ec, "Failed to retrieve size of context binary: ", context_bin_filepath, + ", error: ", ec.message()); + ORT_RETURN_IF(file_size == 0, "Context binary is empty: ", context_bin_filepath); + ORT_RETURN_IF(file_size > SIZE_MAX, "Context binary (", context_bin_filepath, ") file size (", file_size, + " bytes) exceeds maximum value of size_t for this platform (", SIZE_MAX, " bytes)."); + + size_t buffer_size = static_cast(file_size); + + std::unique_ptr buffer; + if (read_file_contents) { + std::ifstream 
cache_file(context_bin_filepath.c_str(), std::ifstream::binary); + ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to read context binary from: ", context_bin_filepath); + + buffer = std::make_unique(buffer_size); + ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); + + const auto& read_result = cache_file.read(buffer.get(), buffer_size); + ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + } + + buffer_info.data = std::move(buffer); + buffer_info.size = buffer_size; + + return Status::OK(); +} + Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map) { #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) QnnContext_Config_t context_config_resource_sharing = QNN_CONTEXT_CONFIG_INIT; @@ -845,6 +916,26 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord #endif nullptr}; +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + if (file_mapped_weights_enabled_ && file_mapper_) { + // Retry logic -- if context creation failed with file mapped weights, then retry with feature disabled + if (CreateContextFromListAsyncV2(configs, context_bin_map) != Status::OK()) { + LOGS(*logger_, WARNING) << "Failed to create context with file mapping enabled. 
Retrying with feature disabled."; + + file_mapped_weights_enabled_ = false; + // Destruction of file_mapper_ to prevent resource leaks + file_mapper_.reset(); + } else { + return Status::OK(); + } + } +#endif + return CreateContextFromListAsyncV1(configs, context_bin_map); +} + +Status QnnBackendManager::CreateContextFromListAsyncV1(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { std::vector context_params_list; std::vector context_paramsv1_list; std::vector context_params_ptr_list; @@ -856,20 +947,12 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord for (auto& it : context_bin_map) { auto context_bin_filepath = it.first; - std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); - ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to retrieve context binary from: ", context_bin_filepath); - - cache_file.seekg(0, cache_file.end); - size_t buffer_size = static_cast(cache_file.tellg()); - ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); + BufferInfo_t buffer_info; + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer_info, true)); - cache_file.seekg(0, cache_file.beg); - std::unique_ptr buffer = std::make_unique(buffer_size); - ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); - const auto& read_result = cache_file.read(buffer.get(), buffer_size); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + std::unique_ptr buffer = std::move(buffer_info.data); + size_t buffer_size = buffer_info.size; - cache_file.close(); QnnContext_ParamsV1_t context_params_v1 = {nullptr, buffer.get(), buffer_size, @@ -892,14 +975,82 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord configs, nullptr); - context_params_ptr_list.clear(); - context_paramsv1_list.clear(); - context_params_list.clear(); + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create 
context. Error: ", QnnErrorHandleToString(result), ", Code:", result); + return Status::OK(); +} + +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED +Status QnnBackendManager::CreateContextFromListAsyncV2(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { + std::vector context_params_list; + std::vector context_paramsv2_list; + std::vector context_callbacks_list; + std::vector context_params_ptr_list; + std::vector buffer_list; + + // QNN receives raw pointers into the containers below, so reserve them up front: an element must not move once its address has been handed out + context_params_list.reserve(context_bin_map.size()); + context_paramsv2_list.reserve(context_bin_map.size()); + context_callbacks_list.reserve(context_bin_map.size()); + context_params_ptr_list.reserve(context_bin_map.size() + 1); + file_mapping_notify_params_.reserve(file_mapping_notify_params_.size() + context_bin_map.size()); + + for (auto& it : context_bin_map) { + auto context_bin_filepath = it.first; + + BufferInfo_t buffer_info; + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer_info, false)); + + size_t buffer_size = buffer_info.size; + + void* buffer = nullptr; + void* file_mapping_handle = nullptr; + ORT_RETURN_IF_ERROR(file_mapper_->MapContextBin(context_bin_filepath, &file_mapping_handle)); + ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappingPointer(context_bin_filepath, &buffer)); + + auto& notify_param = file_mapping_notify_params_.emplace_back(file_mapper_.get(), + file_mapping_handle); + + Qnn_ContextBinaryCallback_t& context_file_map_callbacks = context_callbacks_list.emplace_back(); + context_file_map_callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER; + context_file_map_callbacks.dmaBufferCallback.version = QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1; + context_file_map_callbacks.dmaBufferCallback.v1.dataProvide = DmaDataProvider; + context_file_map_callbacks.dmaBufferCallback.v1.dataRelease = DmaDataRelease; + context_file_map_callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast(&notify_param); + + // Callbacks require QnnContext_ParamsV2_t which is new to QNN API 2.32 + QnnContext_ParamsV2_t& context_params_v2 = context_paramsv2_list.emplace_back(QnnContext_ParamsV2_t{nullptr, + buffer, + buffer_size, + nullptr, + ContextCreateAsyncCallback, + it.second.get(), + &context_file_map_callbacks}); + + QnnContext_Params_t 
context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_2, + {}}; + context_params.v2 = &context_params_v2; + + buffer_list.push_back(buffer); + context_params_list.push_back(std::move(context_params)); + context_params_ptr_list.push_back(&(context_params_list.back())); + } + context_params_ptr_list.push_back(nullptr); + auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, + device_handle_, + context_params_ptr_list.data(), + configs, + nullptr); + + for (auto& buffer : buffer_list) { + ORT_RETURN_IF_ERROR(file_mapper_->FreeContextBinMappingPointer(buffer)); + } buffer_list.clear(); ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result); return Status::OK(); } +#endif // QNN_FILE_MAPPED_WEIGHTS_ENABLED Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; @@ -1098,6 +1249,7 @@ Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer, } Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + const std::string& context_bin_filepath, std::string node_name, QnnModelLookupTable& qnn_models, int64_t max_spill_fill_size) { @@ -1106,6 +1258,28 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t nullptr == qnn_sys_interface_.systemContextFree; ORT_RETURN_IF(result, "Failed to get valid function pointer."); + void* bin_buffer = nullptr; +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + void* file_mapping_handle = nullptr; + if (file_mapped_weights_enabled_) { + ORT_RETURN_IF(!file_mapper_, "Attemping to use File Mapping feature but file_mapper_ is uninitialized"); + + BufferInfo_t buffer_info; + 
ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer_info, false)); + buffer_length = buffer_info.size; + + ORT_RETURN_IF(buffer_length == 0, "Context bin has a size of 0 bytes", context_bin_filepath); + ORT_RETURN_IF_ERROR(file_mapper_->MapContextBin(context_bin_filepath, &file_mapping_handle)); + ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappingPointer(context_bin_filepath, &bin_buffer)); + + } else { + ORT_RETURN_IF(buffer == nullptr, "Attempting to load QNN context from buffer but buffer is null"); + bin_buffer = static_cast(buffer); + } +#else + bin_buffer = static_cast(buffer); +#endif + QnnSystemContext_Handle_t sys_ctx_handle = nullptr; auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); @@ -1113,7 +1287,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; Qnn_ContextBinarySize_t binary_info_size{0}; rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - static_cast(buffer), + bin_buffer, buffer_length, &binary_info, &binary_info_size); @@ -1188,6 +1362,25 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, "Invalid function pointer for contextCreateFromBinary."); +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + Qnn_ContextBinaryCallback_t callbacks; + if (file_mapped_weights_enabled_ && file_mapper_) { + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinaryWithCallback, + "Invalid function pointer for contextCreateFromBinaryWithCallback."); + + auto& notify_param = file_mapping_notify_params_.emplace_back(file_mapper_.get(), + file_mapping_handle); + + callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER; + callbacks.dmaBufferCallback.version = QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1; + callbacks.dmaBufferCallback.v1.dataProvide = DmaDataProvider; + 
callbacks.dmaBufferCallback.v1.dataRelease = DmaDataRelease; + callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast(&notify_param); + } +#else + ORT_UNUSED_PARAMETER(context_bin_filepath); +#endif + qnn::profile::ProfilingInfo profiling_info; #ifdef QNN_SYSTEM_PROFILE_API_ENABLED if (ProfilingEnabled()) { @@ -1195,13 +1388,47 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } #endif - rt = qnn_interface_.contextCreateFromBinary(backend_handle_, - device_handle_, - context_configs, - static_cast(buffer), - buffer_length, - &context, - profile_backend_handle_); +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + std::unique_ptr backup_buffer; + if (file_mapped_weights_enabled_ && file_mapper_) { + rt = qnn_interface_.contextCreateFromBinaryWithCallback(backend_handle_, + device_handle_, + context_configs, + &callbacks, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_, + nullptr); + + ORT_RETURN_IF_ERROR(file_mapper_->FreeContextBinMappingPointer(bin_buffer)); + + if (rt != QNN_SUCCESS) { + LOGS(*logger_, WARNING) << "Failed to create context with file mapping enabled. 
Retrying with feature disabled."; + + file_mapped_weights_enabled_ = false; + // Destruction of file_mapper_ to prevent resource leaks + file_mapper_.reset(); + + // Read context bin from file since file mapping has failed + BufferInfo_t buffer_info; + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer_info, true)); + backup_buffer = std::move(buffer_info.data); + + bin_buffer = static_cast(backup_buffer.get()); + } + } +#endif // QNN_FILE_MAPPED_WEIGHTS_ENABLED + + if (!file_mapped_weights_enabled_ || rt != QNN_SUCCESS) { + rt = qnn_interface_.contextCreateFromBinary(backend_handle_, + device_handle_, + context_configs, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_); + } #ifdef QNN_SYSTEM_PROFILE_API_ENABLED if (ProfilingEnabled()) { @@ -1249,6 +1476,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool need_load_system_lib, bool share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, + bool enable_file_mapped_weights, std::unordered_map>>& context_bin_map) { std::lock_guard lock(logger_recursive_mutex_); if (backend_setup_completed_) { @@ -1280,6 +1508,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, return Status::OK(); } + file_mapped_weights_enabled_ = enable_file_mapped_weights; vtcm_backup_buffer_sharing_enabled_ = enable_vtcm_backup_buffer_sharing; Status status = Status::OK(); @@ -1335,6 +1564,12 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, LOGS(logger, VERBOSE) << "LoadOpPackage succeed."; } +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + if (file_mapped_weights_enabled_ && !file_mapper_ && GetQnnBackendType() == QnnBackendType::HTP) { + file_mapper_ = std::make_unique(logger); + } +#endif + bool enable_htp_weight_sharing = false; if (share_ep_contexts && !load_from_cached_context) { #if defined(__aarch64__) || defined(_M_ARM64) @@ -1529,7 +1764,6 @@ void QnnBackendManager::ReleaseResources() { } backend_setup_completed_ = false; - 
return; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index f1c6c19bb1311..7af505cdfaf0d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -32,6 +32,10 @@ #include "core/providers/qnn/builder/qnn_profile_serializer.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED +#include "core/providers/qnn/builder/qnn_file_mapping_callback_interface.h" +#endif + namespace onnxruntime { namespace qnn { @@ -154,6 +158,7 @@ class QnnBackendManager : public std::enable_shared_from_this std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + const std::string& context_bin_filepath, std::string node_name, std::unordered_map>& qnn_models, int64_t max_spill_fill_size); @@ -163,6 +168,7 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib, bool share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, + bool enable_file_mapped_weights, std::unordered_map>>& context_bin_map); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -246,7 +252,14 @@ class QnnBackendManager : public std::enable_shared_from_this bool ProfilingEnabled() { return profiling_enabled_; } #endif + bool FileMappingIsEnabled() { return file_mapped_weights_enabled_; } + private: + typedef struct BufferInfo { + std::unique_ptr data; + size_t size; + } BufferInfo_t; + Status LoadBackend(); Status InitializeBackend(); @@ -263,9 +276,23 @@ class QnnBackendManager : public std::enable_shared_from_this Status CreateContext(bool enable_htp_weight_sharing); + Status ReadContextBinIfValid(const std::string& context_bin_filepath, 
+ BufferInfo_t& buffer_info, + bool read_file_contents); + Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map); + Status CreateContextFromListAsyncV1(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); + +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + Status CreateContextFromListAsyncV2(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); +#endif + Status ReleaseContext(); // Sets the ORT logger and creates a corresponding QNN logger with the same log level. @@ -451,6 +478,15 @@ class QnnBackendManager : public std::enable_shared_from_this bool context_created_ = false; bool backend_setup_completed_ = false; bool vtcm_backup_buffer_sharing_enabled_ = false; + bool file_mapped_weights_enabled_ = false; + +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + std::unique_ptr file_mapper_ = nullptr; + // Notify params for file mapping must persist throughout lifetime of + // QnnBackendManager for release of DMA data callback on destruction + std::vector> file_mapping_notify_params_; +#endif + // NPU backend requires quantized model QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 625166f62d166..8ddd8e6d65c25 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -19,6 +19,12 @@ namespace qnn { #define QNN_SYSTEM_PROFILE_API_ENABLED #endif +#if defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64)) +#if QNN_API_VERSION_MAJOR > 2 || ((QNN_API_VERSION_MAJOR) == 2 && (QNN_API_VERSION_MINOR >= 32)) +#define QNN_FILE_MAPPED_WEIGHTS_ENABLED +#endif +#endif + // QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose // except the following flags for dlopen, others should be done only // when we really need them diff --git 
a/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_callback_interface.h b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_callback_interface.h new file mode 100644 index 0000000000000..4e7cf434873c7 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_callback_interface.h @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_def.h" + +namespace onnxruntime { +namespace qnn { + +// Interface through which QnnBackendManager loads context binaries via file mappings +// rather than reading the whole file into a heap buffer. +class FileMappingCallbackInterface { + public: + virtual ~FileMappingCallbackInterface() = default; + // Sets up the file mapping for bin_filepath and returns an opaque handle through notify_param. + // That handle is what MapDmaData/ReleaseDmaData later receive as their notify_param. + virtual Status MapContextBin(const std::string& bin_filepath, + void** notify_param) = 0; + // Tears down the file mapping that MapContextBin set up for bin_filepath. + virtual Status ReleaseContextBin(const std::string& bin_filepath) = 0; + + // Returns, through mapping_ptr, a pointer to the mapped contents of bin_filepath (requires a prior MapContextBin). + virtual Status GetContextBinMappingPointer(const std::string& bin_filepath, void** mapping_ptr) = 0; + + // Releases a pointer previously returned by GetContextBinMappingPointer. + virtual Status FreeContextBinMappingPointer(void* bin_mapping_pointer) = 0; + + // Targets of the QNN dmaBufferCallback dataProvide/dataRelease hooks: map or release one requested region of the binary. + virtual Qnn_ErrorHandle_t MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* notify_param) = 0; + virtual Qnn_ErrorHandle_t ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, + void* notify_param) = 0; + + // NOTE(review): presumably the raw-data counterparts of MapDmaData/ReleaseDmaData; no caller is visible in this change. + virtual Qnn_ErrorHandle_t MapRawData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryRawDataResponse_t* response, + void* notify_param) = 0; + virtual Qnn_ErrorHandle_t ReleaseRawData(Qnn_ContextBinaryRawDataMem_t data_mem, + void* notify_param) = 0; +}; + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc new file mode 100644 index 0000000000000..c985f5569cce1 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc @@ -0,0 +1,372 @@ +// Copyright (c) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_windows_file_mapper.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + +#include + +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/rpcmem_library.h" + +namespace onnxruntime { +namespace qnn { + +WindowsFileMapper::WindowsFileMapper(const logging::Logger& logger) : logger_(&logger) { +} + +// Close all handles and registered buffers +// Use LOGS_DEFAULT here as this function will be called during destruction of QnnBackendManager. +// At time of destruction, usage of logger_ will not be available and will result in a seg fault +WindowsFileMapper::~WindowsFileMapper() { + // Ideally, there should be nothing to clean up at this point + // but free any resources anyway if applicable + if (!mapping_handle_to_info_map_.empty() || !context_bin_map_view_pointers_.empty()) { + LOGS_DEFAULT(WARNING) << "File mapping resources still exist. Attempting to free all resources."; + } + + context_bin_to_mapping_handle_map_.clear(); + + for (auto& mapview_ptr : context_bin_map_view_pointers_) { + CleanUpDataMapping(mapview_ptr, nullptr, 0); + } + + for (auto& kv : mapping_handle_to_info_map_) { + HANDLE file_mapping_handle = kv.first; + auto& mapping_info = kv.second; + + CleanUpDataMappings(mapping_info.mapped_data); + CloseHandles(mapping_info.file_handle, file_mapping_handle); + } + mapping_handle_to_info_map_.clear(); +} + +Status WindowsFileMapper::MapContextBin(const std::string& bin_filepath, + void** notify_param) { + LOGS(*logger_, INFO) << "Creating context bin file mapping for " + << bin_filepath; + + ORT_RETURN_IF(bin_filepath.empty(), "Context bin file path is empty"); + + std::lock_guard lock(map_mutex_); + + auto file_mapping_it = context_bin_to_mapping_handle_map_.find(bin_filepath); + if (file_mapping_it != context_bin_to_mapping_handle_map_.end()) { + LOGS(*logger_, INFO) << "Context bin file mapping already exists 
for " + << bin_filepath; + *notify_param = reinterpret_cast(file_mapping_it->second); + return Status::OK(); + } + + HANDLE file_handle = CreateFileA(bin_filepath.c_str(), + GENERIC_READ, + FILE_SHARE_READ, + NULL, + OPEN_EXISTING, + FILE_ATTRIBUTE_NORMAL, + NULL); + ORT_RETURN_IF(file_handle == INVALID_HANDLE_VALUE, + "Failed to create file handle with error code ", + GetLastError(), " for context bin", bin_filepath); + + LOGS(*logger_, VERBOSE) << "Created file handle (" << file_handle << ") for context bin: " + << bin_filepath; + + HANDLE file_mapping_handle = CreateFileMappingA(file_handle, NULL, PAGE_READONLY, 0x00, 0x00, NULL); + if (file_mapping_handle == nullptr) { + CloseHandles(file_handle, nullptr); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create file mapping with error code ", + GetLastError(), " for context bin ", bin_filepath); + } + + LOGS(*logger_, INFO) << "Created file mapping with handle (" << file_mapping_handle << ") for context bin:" + << bin_filepath; + + auto inserted = context_bin_to_mapping_handle_map_.insert({bin_filepath, file_mapping_handle}); + if (!inserted.second) { + CloseHandles(file_handle, file_mapping_handle); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to add file handle mapping for context bin: ", + bin_filepath); + } + mapping_handle_to_info_map_.insert({file_mapping_handle, {file_handle, {}}}); + + *notify_param = reinterpret_cast(file_mapping_handle); + return Status::OK(); +} + +Status WindowsFileMapper::ReleaseContextBin(const std::string& bin_filepath) { + LOGS(*logger_, INFO) << "Removing context bin file mapping for " + << bin_filepath; + std::lock_guard lock(map_mutex_); + auto status = Status::OK(); + + auto bin_map_it = context_bin_to_mapping_handle_map_.find(bin_filepath); + + if (bin_map_it == context_bin_to_mapping_handle_map_.end()) { + LOGS(*logger_, VERBOSE) << "File handle does not exist for " << bin_filepath; + return status; + } + + HANDLE file_mapping_handle = 
bin_map_it->second; + + HANDLE file_handle = nullptr; + auto it = mapping_handle_to_info_map_.find(file_mapping_handle); + if (it == mapping_handle_to_info_map_.end()) { + status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "File mapping information does not exist for file mapping handle: ", + file_mapping_handle, ", context bin: ", bin_filepath); + } else { + MappingInfo_t& mapping_info = it->second; + file_handle = mapping_info.file_handle; + auto& mapped_data = mapping_info.mapped_data; + + if (!mapped_data.empty()) { + LOGS(*logger_, WARNING) << "Attempting to remove context bin: " << bin_filepath + << ", but data regions still need to be unmapped. " + << "Proceeding with unmapping."; + CleanUpDataMappings(mapped_data); + } + } + + // Will ignore handles that are null + CloseHandles(file_handle, file_mapping_handle); + ORT_UNUSED_PARAMETER(mapping_handle_to_info_map_.erase(file_mapping_handle)); + ORT_UNUSED_PARAMETER(context_bin_to_mapping_handle_map_.erase(bin_filepath)); + + return status; +} + +Status WindowsFileMapper::GetContextBinMappingPointer(const std::string& bin_filepath, void** mapping_ptr) { + LOGS(*logger_, INFO) << "Creating mapping pointer for " << bin_filepath; + + std::lock_guard lock(map_mutex_); + auto it = context_bin_to_mapping_handle_map_.find(bin_filepath); + + ORT_RETURN_IF(it == context_bin_to_mapping_handle_map_.end(), + "Failed to create mapping pointer: File mapping does not exist for ", + bin_filepath); + + HANDLE& file_mapping_handle = it->second; + + LPVOID mapview_ptr = MapViewOfFile(file_mapping_handle, + FILE_MAP_READ, + 0, 0, 0); + + ORT_RETURN_IF(mapview_ptr == nullptr, "Failed to create mapping pointer for ", bin_filepath); + + ORT_UNUSED_PARAMETER(context_bin_map_view_pointers_.insert(mapview_ptr)); + *mapping_ptr = mapview_ptr; + LOGS(*logger_, INFO) << "Created mapping pointer (" << mapview_ptr << ") for " << bin_filepath; + return Status::OK(); +} + +Status WindowsFileMapper::FreeContextBinMappingPointer(LPVOID 
bin_mapping_pointer) { + LOGS(*logger_, INFO) << "Releasing mapping pointer " << bin_mapping_pointer; + + std::lock_guard lock(map_mutex_); + auto it = context_bin_map_view_pointers_.find(bin_mapping_pointer); + + ORT_RETURN_IF(it == context_bin_map_view_pointers_.end(), "Mapping pointer ", + bin_mapping_pointer, " cannot be found and is invalid"); + + ORT_RETURN_IF(!UnmapViewOfFile(bin_mapping_pointer), "Failed to free mapping pointer ", bin_mapping_pointer, + " with error code ", GetLastError()); + ORT_UNUSED_PARAMETER(context_bin_map_view_pointers_.erase(bin_mapping_pointer)); + return Status::OK(); +} + +Qnn_ErrorHandle_t WindowsFileMapper::MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* notify_param) { + if (notify_param == nullptr) { + LOGS(*logger_, ERROR) << "Attempting to map DMA data for null mapping handle"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + std::lock_guard lock(map_mutex_); + HANDLE file_mapping_handle = reinterpret_cast(notify_param); + LOGS(*logger_, INFO) << "Mapping DMA data for request: mapping handle(" + << file_mapping_handle << "), offset(" << request.offset + << "), size(" << request.size << "), isBackendMappingNeeded(" + << request.isBackendMappingNeeded << ")"; + + auto buffer_size = request.size; + if (buffer_size == 0 || !request.isBackendMappingNeeded) { + LOGS(*logger_, ERROR) << "Mapping request size must be > 0 with backend mapping required"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + auto it = mapping_handle_to_info_map_.find(file_mapping_handle); + if (it == mapping_handle_to_info_map_.end()) { + LOGS(*logger_, ERROR) << "File mapping info not found for mapping handle: " << file_mapping_handle; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + MappingInfo_t& mapping_info = it->second; + + // Align to nearest granularity boundary + SYSTEM_INFO sys_info; + GetSystemInfo(&sys_info); + Qnn_ContextBinarySize_t granularity = 
sys_info.dwAllocationGranularity; + SIZE_T aligned_offset = request.offset & ~(granularity - 1); + SIZE_T delta = request.offset - aligned_offset; + + LPVOID aligned_data_ptr = MapViewOfFile(file_mapping_handle, + FILE_MAP_READ, + (aligned_offset >> 32), + (aligned_offset & 0xFFFFFFFF), + (buffer_size + delta)); + + if (aligned_data_ptr == nullptr) { + LOGS(*logger_, ERROR) << "Failed to map DMA data with error code " << GetLastError() + << " for file mapping handle " << file_mapping_handle; + return QNN_COMMON_ERROR_SYSTEM; + } + + LPVOID unaligned_data_ptr = static_cast(aligned_data_ptr) + delta; + LOGS(*logger_, INFO) << "Created DMA data mapping with: address(" << aligned_data_ptr + << "), aligned offset(" << aligned_offset << "), delta(" << delta + << "), unaligned address(" << unaligned_data_ptr << ")"; + + rpcmem_lib_.Api().register_buf(unaligned_data_ptr, buffer_size, NULL, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_lib_.Api().to_fd(unaligned_data_ptr); + if (fd == -1) { + LOGS(*logger_, ERROR) << "Failed to register DMA data mapping to RPCMEM"; + if (!UnmapViewOfFile(aligned_data_ptr)) { + LOGS(*logger_, ERROR) << "Failed to unmap DMA data with error code " << GetLastError() + << " with address : " << aligned_data_ptr; + } + return QNN_COMMON_ERROR_SYSTEM; + } + + mapping_info.mapped_data.insert({unaligned_data_ptr, {aligned_data_ptr, buffer_size}}); + response->dmaBuffer.fd = fd; + response->dmaBuffer.data = reinterpret_cast(unaligned_data_ptr); + response->dataStartOffset = 0; + response->alignedSize = buffer_size; + + return QNN_SUCCESS; +} + +// Use LOGS_DEFAULT here as this function will be called during destruction of QnnBackendManager +// At time of destruction. 
Usage of logger_ will not be available and will result in a seg fault +Qnn_ErrorHandle_t WindowsFileMapper::ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, + void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "Attempting to release DMA data for null mapping handle"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + std::lock_guard lock(map_mutex_); + HANDLE file_mapping_handle = static_cast(notify_param); + LOGS_DEFAULT(INFO) << "Releasing DMA data mapping for mapping handle(" << file_mapping_handle + << "), address(" << data_mem.dmaBuffer.data << "), size: (" + << data_mem.memSize << ")"; + + if (data_mem.dmaBuffer.data == nullptr || data_mem.memSize == 0) { + LOGS_DEFAULT(ERROR) << "Mapping release request address must not be null and size must be > 0"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + auto mapping_info_it = mapping_handle_to_info_map_.find(file_mapping_handle); + if (mapping_info_it == mapping_handle_to_info_map_.end()) { + LOGS_DEFAULT(ERROR) << "File mapping info not found for mapping handle: " << file_mapping_handle; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + MappingInfo_t& mapping_info = mapping_info_it->second; + + LPVOID unaligned_data_ptr = reinterpret_cast(data_mem.dmaBuffer.data); + auto& mapped_data = mapping_info.mapped_data; + auto mapped_data_it = mapped_data.find(unaligned_data_ptr); + + if (mapped_data_it == mapped_data.end()) { + LOGS_DEFAULT(ERROR) << "Failed to find DMA data mapping for address: " << unaligned_data_ptr; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + LPVOID aligned_data_ptr = mapped_data_it->second.aligned_data_ptr; + + CleanUpDataMapping(unaligned_data_ptr, aligned_data_ptr, data_mem.memSize); + if (!mapped_data.erase(unaligned_data_ptr)) { + LOGS_DEFAULT(WARNING) << "Possible leak: failed to remove unordered_map entry for DMA data address: " + << unaligned_data_ptr; + } + + if (mapped_data.empty()) { + CloseHandles(mapping_info.file_handle, 
file_mapping_handle); + ORT_UNUSED_PARAMETER(mapping_handle_to_info_map_.erase(mapping_info_it)); + } + + return QNN_SUCCESS; +} + +Qnn_ErrorHandle_t WindowsFileMapper::MapRawData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryRawDataResponse_t* response, + void* notify_param) { + ORT_UNUSED_PARAMETER(request); + ORT_UNUSED_PARAMETER(response); + ORT_UNUSED_PARAMETER(notify_param); + + LOGS(*logger_, ERROR) << "File mapping for raw binary data is unsupported on Windows"; + return QNN_CONTEXT_ERROR_UNSUPPORTED_FEATURE; +} + +// Use LOGS_DEFAULT for all clean up functions below, as they are called while QnnBackendManager +// is being destroyed. At that point logger_ is no longer available, and using it would result +// in a seg fault +Qnn_ErrorHandle_t WindowsFileMapper::ReleaseRawData(Qnn_ContextBinaryRawDataMem_t data_mem, + void* notify_param) { + ORT_UNUSED_PARAMETER(data_mem); + ORT_UNUSED_PARAMETER(notify_param); + + LOGS_DEFAULT(ERROR) << "File mapping for raw binary data is unsupported on Windows"; + return QNN_CONTEXT_ERROR_UNSUPPORTED_FEATURE; +} + +void WindowsFileMapper::CleanUpDataMapping(LPVOID unaligned_data_ptr, LPVOID aligned_data_ptr, + size_t buffer_size) { + if (unaligned_data_ptr) { + // Set file descriptor to -1 to signal deregistration + rpcmem_lib_.Api().register_buf(unaligned_data_ptr, buffer_size, -1, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_lib_.Api().to_fd(unaligned_data_ptr); + if (fd != -1) { + LOGS_DEFAULT(ERROR) << "Failed to deregister buffer from RPCMEM: " << unaligned_data_ptr; + } + } + + if (aligned_data_ptr && !UnmapViewOfFile(aligned_data_ptr)) { + LOGS_DEFAULT(ERROR) << "Failed to unmap view of pointer: " << aligned_data_ptr + << ", error code: " << GetLastError(); + } +} + +void WindowsFileMapper::CleanUpDataMappings(const std::unordered_map& mapped_data) { + // Key is unaligned data pointer + for (const auto& kv : mapped_data) { + auto mapped_data_info = 
kv.second; + // Will handle null ptrs + CleanUpDataMapping(kv.first, mapped_data_info.aligned_data_ptr, + mapped_data_info.buffer_size); + } +} + +void WindowsFileMapper::CloseHandles(HANDLE file_handle, HANDLE file_mapping_handle) { + if (file_mapping_handle && !CloseHandle(file_mapping_handle)) { + LOGS_DEFAULT(ERROR) << "Failed to close file mapping handle: " << file_mapping_handle + << ", error code: " << GetLastError(); + } + if (file_handle && !CloseHandle(file_handle)) { + LOGS_DEFAULT(ERROR) << "Failed to close file handle: " << file_handle + << ", error code: " << GetLastError(); + } +} + +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_ENABLED diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h new file mode 100644 index 0000000000000..143946a443cd7 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/qnn/builder/qnn_file_mapping_callback_interface.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_ENABLED + +#include +#include + +#include + +#include "core/providers/qnn/rpcmem_library.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class WindowsFileMapper : public FileMappingCallbackInterface { + public: + explicit WindowsFileMapper(const logging::Logger& logger); + ~WindowsFileMapper() override; + Status MapContextBin(const std::string& bin_filepath, + void** notify_param) override; + Status ReleaseContextBin(const std::string& model_name) override; + + Status GetContextBinMappingPointer(const std::string& bin_filepath, void** mapping_ptr) override; + + Status FreeContextBinMappingPointer(LPVOID bin_mapping_pointer) override; + + Qnn_ErrorHandle_t MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* notify_param) override; + Qnn_ErrorHandle_t ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, + void* notify_param) override; + + Qnn_ErrorHandle_t MapRawData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryRawDataResponse_t* response, + void* notify_param) override; + Qnn_ErrorHandle_t ReleaseRawData(Qnn_ContextBinaryRawDataMem_t data_mem, + void* notify_param) override; + + private: + typedef struct MappedDataInfo { + LPVOID aligned_data_ptr = nullptr; + size_t buffer_size = 0; + } MappedDataInfo_t; + + typedef struct MappingInfo { + HANDLE file_handle; + + // Maps unaligned data pointers to aligned data pointers + std::unordered_map mapped_data; + } MappingInfo_t; + + void CleanUpDataMapping(LPVOID unaligned_data_ptr, LPVOID aligned_data_ptr, + size_t buffer_size); + void CleanUpDataMappings(const std::unordered_map& mapped_data); + void CloseHandles(HANDLE file_handle, HANDLE file_mapping_handle); + + std::mutex map_mutex_; // Applies to both unordered maps + std::unordered_map context_bin_to_mapping_handle_map_; + 
std::unordered_map mapping_handle_to_info_map_; + std::unordered_set context_bin_map_view_pointers_; + + const logging::Logger* logger_; + + RpcMemLibrary rpcmem_lib_; +}; + +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_ENABLED diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 737216b81139c..7f3a194886718 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -475,6 +475,21 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio #endif } + static const std::string DISABLE_FILE_MAPPED_WEIGHTS = "disable_file_mapped_weights"; + auto disable_file_mapped_weights_pos = provider_options_map.find(DISABLE_FILE_MAPPED_WEIGHTS); + if (disable_file_mapped_weights_pos != provider_options_map.end()) { + if ("1" == disable_file_mapped_weights_pos->second) { + enable_file_mapped_weights_ = false; + } + LOGS_DEFAULT(VERBOSE) << "User specified disable_file_mapped_weights: " << enable_file_mapped_weights_; + } + +#ifndef QNN_FILE_MAPPED_WEIGHTS_ENABLED + enable_file_mapped_weights_ = false; + LOGS_DEFAULT(WARNING) << "File mapped weights feature is only available on Windows arm64 devices for QNN API versions >= 2.32. 
" + << "Feature will be disabled by default"; +#endif + static const std::string QNN_DEVICE_ID = "device_id"; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { @@ -926,7 +941,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } std::unordered_map>> context_bin_map; - if (enable_vtcm_backup_buffer_sharing_) { + if (enable_vtcm_backup_buffer_sharing_ || enable_file_mapped_weights_) { std::unordered_set ep_ctx_nodes; GetMainEPCtxNodes(graph_viewer, ep_ctx_nodes, logger); @@ -939,7 +954,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer NodeAttrHelper node_helper(*ep_ctx_node); std::string context_bin_filepath(parent_path.string()); context_bin_filepath.append("/").append(node_helper.Get(qnn::EP_CACHE_CONTEXT, "")); - if (context_bin_map.find(context_bin_filepath) == context_bin_map.end()) { context_bin_map.emplace(context_bin_filepath, std::make_unique>()); // Push context bin filepath for lookup between sessions @@ -956,6 +970,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer context_cache_enabled_ && enable_spill_fill_buffer_, share_ep_contexts_, enable_vtcm_backup_buffer_sharing_, + enable_file_mapped_weights_, context_bin_map); context_bin_map.clear(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index dd301d7915935..05a73304806ec 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -119,6 +119,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; + bool enable_file_mapped_weights_ = true; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif diff 
--git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 20918f8bc6de1..f89a15157ddf4 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -165,6 +165,8 @@ RpcMemApi CreateApi(void* library_handle) { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_to_fd", (void**)&api.to_fd)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "remote_register_buf_attr2", (void**)&api.register_buf)); + return api; } diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h index 2746e147373bb..0f4b5b5391f59 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.h +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -24,6 +24,9 @@ constexpr uint32_t RPCMEM_DEFAULT_FLAGS = 1; constexpr int RPCMEM_HEAP_ID_SYSTEM = 25; +constexpr int RPCMEM_ATTR_IMPORT_BUFFER = 256; +constexpr int RPCMEM_ATTR_READ_ONLY = 512; + /** * Allocate a zero-copy buffer for size upto 2 GB with the FastRPC framework. * Buffers larger than 2 GB must be allocated with rpcmem_alloc2 @@ -46,6 +49,17 @@ using FreeFnPtr = void (*)(void* po); */ using ToFdFnPtr = int (*)(void* po); +/** + * Registers and maps a CPU buffer to RPC memory space + * @param[in] buff Data pointer for a CPU-allocated buffer + * @param[in] size Size of the buffer in bytes + * @param[in] fd File descriptor for a CPU-allocated buffer + * Note: Can be NULL if N/A or -1 to signal deregistration + * @param[in] attr Specified attributes for the buffer + * @return None (the function is declared as returning void) + */ +using RegisterBufFnPtr = void (*)(void* buff, size_t size, int fd, int attr); + } // namespace rpcmem // RPCMEM API function pointers. 
@@ -53,6 +67,7 @@ struct RpcMemApi { rpcmem::AllocFnPtr alloc; rpcmem::FreeFnPtr free; rpcmem::ToFdFnPtr to_fd; + rpcmem::RegisterBufFnPtr register_buf; }; // Loads and provides access to the RPCMEM API functions from a dynamically loaded library. diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 3468e2e55c7b6..16bf08415fe33 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -352,7 +352,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "qnn_saver_path", "htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer", "enable_htp_shared_memory_allocator", "dump_json_qnn_graph", - "json_qnn_graph_dir"}); + "json_qnn_graph_dir", "disable_file_mapped_weights"}); for (const auto& provider_option : provider_options) { const std::string& key = provider_option.first; const std::string& value = provider_option.second; @@ -404,7 +404,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } } else if (key == "htp_arch") { - std::set supported_htp_archs = {"0", "68", "69", "73", "75"}; + std::set supported_htp_archs = {"0", "68", "69", "73", "75", "81"}; if (supported_htp_archs.find(value) == supported_htp_archs.end()) { std::ostringstream str_stream; std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), @@ -416,7 +416,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer" || key == "enable_htp_shared_memory_allocator" || - key == "dump_json_qnn_graph") { + key == "dump_json_qnn_graph" || + key == "disable_file_mapped_weights") { std::set supported_options = {"0", "1"}; if 
(supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a2f1b9b56538b..77d13742413b5 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -1897,6 +1897,114 @@ TEST_F(QnnHTPBackendTests, VTCMBackupBufferSharing) { std::remove(qnn_ctx_binary_file_name1.c_str()); } +TEST_F(QnnHTPBackendTests, FileMapping_Off) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + provider_options["disable_file_mapped_weights"] = "1"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + // cleanup in case some failure test doesn't remove them + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + + DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); + + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); + 
EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + // Both *_ctx.onnx files point to the same .bin file + EXPECT_TRUE(qnn_ctx_binary_file_name1 == qnn_ctx_binary_file_name2); + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + EXPECT_TRUE(file_size_1 > 0); + + // only load and run the session on real device +#if defined(__aarch64__) || defined(_M_ARM64) + Ort::SessionOptions so1; + so1.SetLogId("so1"); + so1.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so1.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so2; + + // Test CreateFromBinaryListAsync path + provider_options["enable_vtcm_backup_buffer_sharing"] = "1"; + so2.SetLogId("so2"); + so2.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so2.AppendExecutionProvider("QNN", provider_options); + + EXPECT_TRUE(2 == ctx_model_paths.size()); +#ifdef _WIN32 + std::wstring ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::wstring ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#else + std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#endif + Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1); + Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2); + + std::vector input_names; + std::vector output_names; + GetModelInputNames(ctx_model_paths[1], input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{2, 3}; + std::vector input_value(2 * 3, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + 
input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_outputs1 = session1.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + auto ort_outputs2 = session2.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); +#endif + + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + std::remove(qnn_ctx_binary_file_name1.c_str()); +} + // For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled // Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) {