diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
index d22edaf33eb1c..fe51f924310ba 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc
@@ -839,6 +839,24 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord
   return Status::OK();
 }
 
+Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) {
+  QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT;
+  ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config));
+
+  // QNN expects a null-terminated array of config pointers.
+  const QnnContext_Config_t* configs[] = {&context_priority_config, nullptr};
+  for (const auto& context_handle : contexts_) {
+    auto result = qnn_interface_.contextSetConfig(context_handle, configs);
+    ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to set context priority for context handle: ", context_handle);
+  }
+
+  return Status::OK();
+}
+
+Status QnnBackendManager::ResetContextPriority() {
+  return SetContextPriority(context_priority_);
+}
+
 Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) {
   if (true == context_created_) {
     LOGS_DEFAULT(INFO) << "Context created already.";
diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
index 3e68df3024565..84454350973c7 100644
--- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
+++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h
@@ -219,6 +219,12 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
   // For each node name, a mapping to the context handle will be created
   void ProcessContextFromBinListAsync(Qnn_ContextHandle_t handle, void* notifyParam);
 
+  // Sets the context priority of all owned QNN contexts to the given value, if valid
+  Status SetContextPriority(ContextPriority context_priority);
+
+  // Resets the context priority to the session default as defined by context_priority_
+  Status ResetContextPriority();
+
  private:
   Status LoadBackend();
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
index 236447cc95c3d..c5f5795273dcb 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc
@@ -1545,4 +1545,43 @@ OrtDevice QNNExecutionProvider::GetOrtDeviceByMemType(OrtMemType /* em_type */)
   return default_device_;
 }
 
+// Applies key/value dynamic options to a running session. Currently only
+// "ep.dynamic.workload_type" is supported: "Efficient" lowers the QNN context
+// priority, "Default" restores the session's configured priority.
+Status QNNExecutionProvider::SetEpDynamicOptions(gsl::span<const char* const> keys,
+                                                 gsl::span<const char* const> values) {
+  if (keys.size() != values.size()) {
+    LOGS_DEFAULT(ERROR) << "SetEpDynamicOptions: number of keys (" << keys.size()
+                        << ") does not equal number of values (" << values.size() << ").";
+    // Fail fast instead of silently processing only the shorter span.
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Mismatched number of keys and values.");
+  }
+  auto key_it = keys.begin();
+  auto value_it = values.begin();
+
+  while (key_it != keys.end() && value_it != values.end()) {
+    std::string key(*key_it);
+    std::string value(*value_it);
+
+    if (key == kOrtEpDynamicOptionsWorkloadType) {
+      if (value == "Default") {
+        ORT_RETURN_IF_ERROR(qnn_backend_manager_->ResetContextPriority());
+      } else if (value == "Efficient") {
+        ORT_RETURN_IF_ERROR(qnn_backend_manager_->SetContextPriority(qnn::ContextPriority::LOW));
+      } else {
+        LOGS_DEFAULT(ERROR) << "Invalid EP Workload Type: " << value;
+        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid EP Workload Type.");
+      }
+    } else {
+      LOGS_DEFAULT(ERROR) << "EP Dynamic Option \"" << key << "\" is not currently supported.";
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported EP Dynamic Option");
+    }
+
+    ++key_it;
+    ++value_it;
+  }
+
+  return Status::OK();
+}
+
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
index 06f9726ae96cf..a54589aa3e3ca 100644
--- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h
+++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h
@@ -57,6 +57,10 @@ class QNNExecutionProvider : public IExecutionProvider {
 
   OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
 
+  // Applies dynamic options (e.g. "ep.dynamic.workload_type") to a running session.
+  Status SetEpDynamicOptions(gsl::span<const char* const> keys,
+                             gsl::span<const char* const> values) override;
+
  private:
   std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewer,
                                                     const std::unordered_map<const Node*, const NodeUnit*>& node_unit_map,
diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
index 4febfe7ba836d..3335c242112ab 100644
--- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
+++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc
@@ -1649,7 +1649,6 @@ static void DumpModelWithSharedCtx(ProviderOptions provider_options,
   Ort::Session session2(*ort_env, ToPathString(onnx_model_path2).c_str(), so);
 }
 
-#if defined(__aarch64__) || defined(_M_ARM64)
 static void GetModelInputNames(const std::string& model_path,
                                std::vector<std::string>& input_names,
                                std::vector<std::string>& output_names,
@@ -1669,7 +1668,6 @@ static void GetModelInputNames(const std::string& model_path,
     output_names.push_back(output->Name());
   }
 }
-#endif
 
 // 1. Create 2 QDQ models
 // 2. Initialize 2 Ort sessions which share the same QNN EP from these 2 QDQ models
@@ -1994,6 +1992,73 @@ TEST_F(QnnHTPBackendTests, LoadFromArrayWithQnnEpContextGenPathValidation) {
   });
   }
 }
+
+TEST_F(QnnHTPBackendTests, QnnEpDynamicOptions) {
+  ProviderOptions provider_options;
+  provider_options["backend_type"] = "htp";
+  provider_options["offload_graph_io_quantization"] = "0";
+
+  Ort::SessionOptions so;
+  so.AppendExecutionProvider("QNN", provider_options);
+  so.SetLogSeverityLevel(ORT_LOGGING_LEVEL_VERBOSE);
+
+  Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so);
+
+  std::vector<std::string> input_names;
+  std::vector<std::string> output_names;
+  GetModelInputNames("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx", input_names, output_names,
+                     DefaultLoggingManager().DefaultLogger());
+
+  // Run sessions
+  // prepare input
+  std::vector<int64_t> input_dim{3, 4};
+  std::vector<float> input_value(3 * 4, 0.0f);
+  Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
+  std::vector<Ort::Value> ort_inputs;
+  std::vector<const char*> input_names_c;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(),
+                                                 input_dim.data(), input_dim.size());
+    ort_inputs.push_back(std::move(input_tensor));
+    input_names_c.push_back(input_names[i].c_str());
+  }
+  std::vector<const char*> output_names_c;
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    output_names_c.push_back(output_names[i].c_str());
+  }
+
+  auto ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(),
+                                output_names_c.data(), 1);
+
+  const char* const workload_type[] = {"ep.dynamic.workload_type"};
+  const char* const efficient_type[] = {"Efficient"};
+  const char* const default_type[] = {"Default"};
+
+  // Test Efficient & Default options
+  session.SetEpDynamicOptions(workload_type, efficient_type, 1);
+  ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(),
+                           output_names_c.data(), 1);
+
+  session.SetEpDynamicOptions(workload_type, default_type, 1);
+  ort_output = session.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(),
+                           output_names_c.data(), 1);
+
+  // Test invalid EP dynamic option and invalid workload type
+  const char* const dne[] = {"DNE"};
+  try {
+    session.SetEpDynamicOptions(workload_type, dne, 1);
+    FAIL() << "Expected exception to be thrown for workload type DNE but was set successfully";
+  } catch (const std::exception& e) {
+    EXPECT_STREQ("Invalid EP Workload Type.", e.what());
+  }
+
+  try {
+    session.SetEpDynamicOptions(dne, efficient_type, 1);
+    FAIL() << "Expected exception to be thrown for dynamic option DNE but was set successfully";
+  } catch (const std::exception& e) {
+    EXPECT_STREQ("Unsupported EP Dynamic Option", e.what());
+  }
+}
 #endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
 
 }  // namespace test