diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 1cb3bf5c1e461..6fdd5357bf4dc 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -1092,6 +1092,24 @@ Status QnnBackendManager::CreateContextFromListAsyncWithCallback(const QnnContex void* buffer; ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &buffer)); + auto sys_ctx_handle = GetSystemContextHandle(); + ORT_RETURN_IF(sys_ctx_handle == nullptr, "System context handle is null."); + + uint32_t graph_count = 0; + Qnn_Version_t blob_version; + QnnSystemContext_GraphInfo_t* graphs_info = nullptr; + ORT_RETURN_IF_ERROR(GetGraphInfoAndBinVersion(sys_ctx_handle.get(), + buffer, + static_cast(buffer_size), + blob_version, + graph_count, + &graphs_info)); + + // Return ORT failure to continue to retry logic in CreateContextVtcmBackupBufferSharingEnabled() + ORT_RETURN_IF(!MinVersionMet(blob_version, {3, 3, 3}), "Context binary of ", context_bin_filepath, " is v", + blob_version.major, ".", blob_version.minor, ".", blob_version.patch, + ". File mapping is only supported for versions >= 3.3.3"); + auto notify_param_ptr = std::make_unique(buffer, buffer_size, this); Qnn_ContextBinaryCallback_t context_file_map_callbacks; @@ -1277,42 +1295,27 @@ Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer, max_spill_fill_buffer_size = 0; // spill fill starts from 2.28 #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) - bool result = nullptr == qnn_sys_interface_.systemContextCreate || - nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || - nullptr == qnn_sys_interface_.systemContextFree; - ORT_RETURN_IF(result, "Failed to get valid function pointer."); - - QnnSystemContext_Handle_t sys_ctx_handle = nullptr; - auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); - - const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; - Qnn_ContextBinarySize_t binary_info_size{0}; - rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - static_cast(buffer), - buffer_length, - &binary_info, - &binary_info_size); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); + auto sys_ctx_handle = GetSystemContextHandle(); + ORT_RETURN_IF(sys_ctx_handle == nullptr, "System context handle is null."); - // binary_info life cycle is here - // Binary info to graph info - // retrieve Qnn graph info from binary info - ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); uint32_t graph_count = 0; QnnSystemContext_GraphInfo_t* graphs_info = nullptr; - if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { - graph_count = binary_info->contextBinaryInfoV3.numGraphs; - graphs_info = binary_info->contextBinaryInfoV3.graphs; - } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { - graph_count = binary_info->contextBinaryInfoV2.numGraphs; - graphs_info = binary_info->contextBinaryInfoV2.graphs; - } else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { - graph_count = binary_info->contextBinaryInfoV1.numGraphs; - graphs_info = binary_info->contextBinaryInfoV1.graphs; - } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); - } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_Version_t blob_version; + ORT_RETURN_IF_ERROR(GetGraphInfoAndBinVersion(sys_ctx_handle.get(), + static_cast(buffer), + static_cast(buffer_length), + blob_version, + graph_count, + &graphs_info)); + ORT_UNUSED_PARAMETER(blob_version); +#else + ORT_RETURN_IF_ERROR(GetGraphInfoAndBinVersion(sys_ctx_handle.get(), + static_cast(buffer), + static_cast(buffer_length), + graph_count, + &graphs_info)); +#endif for (uint32_t i = 0; i < graph_count; ++i) { if (graphs_info[i].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { @@ -1344,16 +1347,12 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t std::string node_name, QnnModelLookupTable& qnn_models, int64_t max_spill_fill_size) { - bool result = nullptr == qnn_sys_interface_.systemContextCreate || - nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || - nullptr == qnn_sys_interface_.systemContextFree; - ORT_RETURN_IF(result, "Failed to get valid function pointer."); - void* bin_buffer = nullptr; + bool use_file_mapping = file_mapped_weights_enabled_; #ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE // A nonzero buffer length implies an embedded context - if (file_mapped_weights_enabled_ && buffer_length == 0) { - ORT_RETURN_IF(!file_mapper_, "Attemping to use File Mapping feature but file_mapper_ is uninitialized"); + if (use_file_mapping && buffer_length == 0) { + ORT_RETURN_IF(!file_mapper_, "Attempting to use File Mapping feature but file_mapper_ is uninitialized"); ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_length)); @@ -1361,6 +1360,11 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &bin_buffer)); } else { + if (use_file_mapping) { + use_file_mapping = false; + LOGS(*logger_, WARNING) << "Node " << node_name << " is using an embedded cache." + << " Disabling file mapping for this node."; + } ORT_RETURN_IF(buffer == nullptr, "Attempting to load QNN context from buffer but buffer is null"); bin_buffer = static_cast(buffer); } @@ -1368,44 +1372,35 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t bin_buffer = static_cast(buffer); #endif - QnnSystemContext_Handle_t sys_ctx_handle = nullptr; - auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); - - const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; - Qnn_ContextBinarySize_t binary_info_size{0}; - rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - bin_buffer, - buffer_length, - &binary_info, - &binary_info_size); - ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); + auto sys_ctx_handle = GetSystemContextHandle(); + ORT_RETURN_IF(sys_ctx_handle == nullptr, "System context handle is null."); - // binary_info life cycle is here - // Binary info to graph info - // retrieve Qnn graph info from binary info - ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); uint32_t graph_count = 0; QnnSystemContext_GraphInfo_t* graphs_info = nullptr; - if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { - graph_count = binary_info->contextBinaryInfoV1.numGraphs; - graphs_info = binary_info->contextBinaryInfoV1.graphs; - } -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 15) // starts from 2.22 - else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { - graph_count = binary_info->contextBinaryInfoV2.numGraphs; - graphs_info = binary_info->contextBinaryInfoV2.graphs; - } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_Version_t blob_version; + ORT_RETURN_IF_ERROR(GetGraphInfoAndBinVersion(sys_ctx_handle.get(), + bin_buffer, + static_cast(buffer_length), + blob_version, + graph_count, + &graphs_info)); +#else + ORT_RETURN_IF_ERROR(GetGraphInfoAndBinVersion(sys_ctx_handle.get(), + bin_buffer, + static_cast(buffer_length), + graph_count, + &graphs_info)); #endif -#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // starts from 2.28 - else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { - graph_count = binary_info->contextBinaryInfoV3.numGraphs; - graphs_info = binary_info->contextBinaryInfoV3.graphs; + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (use_file_mapping && !MinVersionMet(blob_version, {3, 3, 3})) { + LOGS(*logger_, WARNING) << "Context binary of " << node_name << " is v" + << blob_version.major << "." << blob_version.minor << "." << blob_version.patch + << ". File mapping is only supported for versions >= 3.3.3. Disabling file mapping for this node."; + use_file_mapping = false; } #endif - else { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); - } ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count; @@ -1452,7 +1447,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t #ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE Qnn_ContextBinaryCallback_t callbacks; - if (file_mapped_weights_enabled_ && file_mapper_) { + if (use_file_mapping && file_mapper_) { ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinaryWithCallback, "Invalid function pointer for contextCreateFromBinaryWithCallback."); @@ -1477,9 +1472,10 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } #endif + Qnn_ErrorHandle_t rt = QNN_SUCCESS; #ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE std::vector backup_buffer; - if (file_mapped_weights_enabled_ && file_mapper_) { + if (use_file_mapping && file_mapper_) { rt = qnn_interface_.contextCreateFromBinaryWithCallback(backend_handle_, device_handle_, context_configs, @@ -1501,9 +1497,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t bin_buffer = static_cast(backup_buffer.data()); } } -#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE - - if (!file_mapped_weights_enabled_ || rt != QNN_SUCCESS) { +#endif + if (!use_file_mapping || rt != QNN_SUCCESS) { rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, context_configs, @@ -1544,10 +1539,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } } - qnn_sys_interface_.systemContextFree(sys_ctx_handle); - sys_ctx_handle = nullptr; context_created_ = true; - LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed."; return Status::OK(); } @@ -2338,5 +2330,88 @@ Status QnnBackendManager::GetOrRegisterContextMemHandle(Qnn_ContextHandle_t cont return Status::OK(); } +std::unique_ptr> QnnBackendManager::GetSystemContextHandle() { + if (nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextFree) { + LOGS(*logger_, ERROR) << "Failed to get valid function pointers for system context handle creation and destruction."; + return nullptr; + } + + QnnSystemContext_Handle_t sys_ctx_handle = nullptr; + auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); + if (QNN_SUCCESS != rt) { + LOGS(*logger_, ERROR) << "Failed to create system handle."; + return nullptr; + } + + auto sys_ctx_handle_deleter = [&qnn_sys_interface = qnn_sys_interface_](void* handle) { + if (qnn_sys_interface.systemContextFree) { + qnn_sys_interface.systemContextFree(reinterpret_cast(handle)); + } else { + LOGS_DEFAULT(ERROR) << "qnn_sys_interface.systemContextFree is null. Unable to free system context handle"; + } + }; + + std::unique_ptr> sys_ctx_handle_uptr(sys_ctx_handle, sys_ctx_handle_deleter); + return sys_ctx_handle_uptr; +} + +Status QnnBackendManager::GetGraphInfoAndBinVersion(QnnSystemContext_Handle_t sys_ctx_handle, + void* buffer, + Qnn_ContextBinarySize_t buffer_length, +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_Version_t& blob_version, +#endif + uint32_t& graph_count, + QnnSystemContext_GraphInfo_t** graphs_info) { + bool result = nullptr == qnn_sys_interface_.systemContextGetBinaryInfo; + ORT_RETURN_IF(result, "Failed to get valid function pointer to retrieve binary info from context binary."); + ORT_RETURN_IF(sys_ctx_handle == nullptr, "System context handle is null."); + + // The lifetime of binary_info's contents is tied to the lifetime of + // the obj pointed to by sys_ctx_handle (owned by caller) + const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; + Qnn_ContextBinarySize_t binary_info_size{0}; + auto rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, + buffer, + buffer_length, + &binary_info, + &binary_info_size); + + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info."); + ORT_RETURN_IF(nullptr == binary_info, "Qnn cached binary info is nullptr."); + + // Extract graph info and context bin version from binary_info + if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + graph_count = binary_info->contextBinaryInfoV1.numGraphs; + *graphs_info = binary_info->contextBinaryInfoV1.graphs; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + blob_version = binary_info->contextBinaryInfoV1.contextBlobVersion; +#endif + } +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 15) // starts from 2.22 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + graph_count = binary_info->contextBinaryInfoV2.numGraphs; + *graphs_info = binary_info->contextBinaryInfoV2.graphs; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + blob_version = binary_info->contextBinaryInfoV2.contextBlobVersion; +#endif + } +#endif +#if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 21) // starts from 2.28 + else if (binary_info->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + graph_count = binary_info->contextBinaryInfoV3.numGraphs; + *graphs_info = binary_info->contextBinaryInfoV3.graphs; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + blob_version = binary_info->contextBinaryInfoV3.contextBlobVersion; +#endif + } +#endif + else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported context binary info version."); + } + + return Status::OK(); +} + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index fe4ec0b7018a5..f5d47806b6765 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -11,6 +11,7 @@ #include #endif +#include #include #include #include @@ -460,7 +461,28 @@ class QnnBackendManager : public std::enable_shared_from_this return Status::OK(); } - private: + std::unique_ptr> GetSystemContextHandle(); + + Status GetGraphInfoAndBinVersion(QnnSystemContext_Handle_t sys_ctx_handle, + void* buffer, + Qnn_ContextBinarySize_t buffer_length, +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_Version_t& blob_version, +#endif + uint32_t& graph_count, + QnnSystemContext_GraphInfo_t** graphs_info); + + // Checks if act_ver is >= min_ver. An act_ver of 0.0.0 is considered invalid. + static bool MinVersionMet(const Qnn_Version_t& act_ver, const Qnn_Version_t& min_ver) { + if (act_ver.major == 0 && act_ver.minor == 0 && act_ver.patch == 0) { + return false; + } + + return act_ver.major > min_ver.major || + (act_ver.major == min_ver.major && act_ver.minor > min_ver.minor) || + (act_ver.major == min_ver.major && act_ver.minor == min_ver.minor && act_ver.patch >= min_ver.patch); + } + const std::string backend_path_; std::recursive_mutex logger_recursive_mutex_; const logging::Logger* logger_ = nullptr;