From 027e7d535fce6622b780677c51438641f6b0727e Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Sat, 3 May 2025 16:48:51 +1000 Subject: [PATCH] Selection policy. Device discovery updates. Bug fixes. (#24625) ### Description Add initial selection policy implementations. Update device discovery - get vendor and vendor id for CPU from cpuid_info - trim metadata to known useful fields - NPU detection via dxcore only Bug fixes/updates from PRs for C# and python bindings Add some tests for selection policy - TODO: Add more tests ### Motivation and Context Desire to boil oceans. --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../core/session/onnxruntime_c_api.h | 44 ++- .../core/session/onnxruntime_cxx_api.h | 12 +- .../core/session/onnxruntime_cxx_inline.h | 7 + onnxruntime/core/common/cpuid_info.cc | 10 + onnxruntime/core/common/cpuid_info.h | 7 + .../core/framework/graph_partitioner.cc | 8 +- onnxruntime/core/framework/session_options.h | 15 + .../core/platform/windows/device_discovery.cc | 85 +++-- .../core/session/abi_key_value_pairs.h | 2 + .../core/session/abi_session_options.cc | 11 + .../core/session/ep_library_internal.cc | 8 +- onnxruntime/core/session/inference_session.h | 4 +- .../core/session/model_compilation_options.cc | 4 +- onnxruntime/core/session/onnxruntime_c_api.cc | 3 +- onnxruntime/core/session/ort_apis.h | 4 + .../core/session/provider_policy_context.cc | 347 ++++++++++++++++++ .../core/session/provider_policy_context.h | 79 ++++ onnxruntime/core/session/utils.cc | 17 +- .../test/autoep/test_autoep_selection.cc | 122 +++++- 19 files changed, 724 insertions(+), 65 deletions(-) create mode 100644 onnxruntime/core/session/provider_policy_context.cc create mode 100644 onnxruntime/core/session/provider_policy_context.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 9478a7ee3e77f..b7f3df8c0b2d7 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -425,6 +425,32 @@ typedef enum OrtExecutionProviderDevicePolicy { OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, } OrtExecutionProviderDevicePolicy; +/** \brief Delegate to allow providing custom OrtEpDevice selection logic + * + * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. + * The user can use this to select OrtEpDevice instances from the list of available devices. + * + * \param ep_devices The list of available devices. + * \param num_devices The number of available devices. + * \param model_metadata The model metadata. + * \param runtime_metadata The runtime metadata. May be nullptr. + * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. + * \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array. + Currently the maximum is 8. + * \param num_ep_devices The number of selected devices. + * + * \return OrtStatus* Selection status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected); + /** \brief Algorithm to use for cuDNN Convolution Op */ typedef enum OrtCudnnConvAlgoSearch { @@ -5073,7 +5099,8 @@ struct OrtApi { ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env, _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); - /** \brief Append execution provider to the session options by name. + /** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances + * to the session options. * * \param[in] session_options Session options to add execution provider to. * \param[in] env Environment that execution providers were registered with. @@ -5098,6 +5125,21 @@ struct OrtApi { _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); + /** \brief Set the execution provider selection policy for the session. + * + * Allows users to specify a device selection policy for automatic execution provider (EP) selection, + * or provide a delegate callback for custom selection logic. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). + * \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy. + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate); + /** \brief Get the hardware device type. * * \param[in] device The OrtHardwareDevice instance to query. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 0ecc27c59dc28..6c175c606b4a1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1085,19 +1085,27 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options); - ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options); /// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK. SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name, const std::unordered_map& provider_options = {}); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const KeyValuePairs& ep_options); + /// Append EPs that have been registered previously with the OrtEnv. + /// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2 SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector& ep_devices, const std::unordered_map& ep_options); + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy + SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, + EpSelectionDelegate* delegate = nullptr); + SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 48b3b80cced55..1fdb8f16d9600 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1149,6 +1149,13 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( return *this; } +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, + EpSelectionDelegate* delegate) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate)); + return *this; +} + template inline SessionOptionsImpl& SessionOptionsImpl::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn)); diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 97766028cfe12..ee7782e3c8763 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -106,6 +106,7 @@ void CPUIDInfo::X86Init() { GetCPUID(0, data); vendor_ = GetX86Vendor(data); + vendor_id_ = GetVendorId(vendor_); int num_IDs = data[0]; if (num_IDs >= 1) { @@ -151,6 +152,14 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) { #endif // defined(CPUIDINFO_ARCH_X86) +uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { + if (vendor == "GenuineIntel") return 0x8086; + if (vendor == "GenuineAMD") return 0x1022; + if (vendor.find("Qualcomm") == 0) return 'Q' << 24 | 'C' << 16 | 'O' << 8 | 'M'; + if (vendor.find("NV") == 0) return 0x10DE; + return 0; +} + #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -204,6 +213,7 @@ void CPUIDInfo::ArmLinuxInit() { void CPUIDInfo::ArmWindowsInit() { // Get the ARM vendor string from the registry vendor_ = GetArmWindowsVendor(); + vendor_id_ = GetVendorId(vendor_); // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 4d6e7e8b9105e..b820fa2ab1af7 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -19,6 +19,10 @@ class CPUIDInfo { return vendor_; } + uint32_t GetCPUVendorId() const { + return vendor_id_; + } + bool HasAMX_BF16() const { return has_amx_bf16_; } bool HasAVX() const { return has_avx_; } bool HasAVX2() const { return has_avx2_; } @@ -123,6 +127,9 @@ class CPUIDInfo { bool has_arm_neon_bf16_{false}; std::string vendor_; + uint32_t vendor_id_; + + uint32_t GetVendorId(const std::string& vendor); #if defined(CPUIDINFO_ARCH_X86) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e3e54be3f7c21..8ed5eeaa8d44f 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -806,7 +806,13 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes, "Compiled model does not contain any EPContext nodes. " "Check that the session EPs support compilation and can execute at least one model subgraph."); - return Status::OK(); + + LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. " + "Either the session EPs do not support compilation or " + "no subgraphs were able to be compiled."; + + // we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not + // EPContext nodes. } auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 94ff2bb55a055..8f8a3d6634a7e 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -90,6 +90,15 @@ struct EpContextModelGenerationOptions { size_t output_external_initializer_size_threshold = 0; }; +struct EpSelectionPolicy { + // flag to detect that a policy was set by the user. + // need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered + // and no selection policy was explicitly provided. + bool enable{false}; + OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT; + EpSelectionDelegate* delegate{}; +}; + /** * Configuration information for a session. */ @@ -222,6 +231,11 @@ struct SessionOptions { // copied internally and the flag needs to be accessible across all copies. std::shared_ptr load_cancellation_flag = std::make_shared(false); + // Policy to guide Execution Provider selection + EpSelectionPolicy ep_selection_policy = {false, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_DEFAULT, + nullptr}; + // Options for generating compile EPContext models were previously stored in session_option.configs as // string key/value pairs. To support more advanced options, such as setting input/output buffers, we // now have to store EPContext options in a struct of type EpContextModelGenerationOptions. @@ -253,6 +267,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_ << " use_per_session_threads:" << session_options.use_per_session_threads << " thread_pool_allow_spinning:" << session_options.thread_pool_allow_spinning << " use_deterministic_compute:" << session_options.use_deterministic_compute + << " ep_selection_policy:" << session_options.ep_selection_policy.policy << " config_options: { " << session_options.config_options << " }" //<< " initializers_to_share_map:" << session_options.initializers_to_share_map #if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS) diff --git a/onnxruntime/core/platform/windows/device_discovery.cc b/onnxruntime/core/platform/windows/device_discovery.cc index 88fbec37c8075..5a5b5041a5912 100644 --- a/onnxruntime/core/platform/windows/device_discovery.cc +++ b/onnxruntime/core/platform/windows/device_discovery.cc @@ -119,7 +119,13 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde uint32_t vendor_id = get_id(buffer, L"VEN_"); uint32_t device_id = get_id(buffer, L"DEV_"); - // won't always have a vendor id from an ACPI entry. need at least a device id to identify the hardware + + // Processor ID should come from CPUID mapping. + if (vendor_id == 0 && guid == GUID_DEVCLASS_PROCESSOR) { + vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId(); + } + + // Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose. if (vendor_id == 0 && device_id == 0) { continue; } @@ -138,8 +144,8 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde entry = &device_info[key]; entry->vendor_id = vendor_id; entry->device_id = device_id; - // put the first hardware id string in the metadata. ignore the other lines. - entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); + // put the first hardware id string in the metadata. ignore the other lines. not sure if this is of value. + // entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer))); } else { // need valid ids continue; @@ -156,14 +162,14 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde (PBYTE)buffer, sizeof(buffer), &size)) { std::wstring desc{buffer}; - // Should we require the NPU to be found by DXCore or do we want to allow this vague matching? - // Probably depends on whether we always attempt to run DXCore or not. - const auto possible_npu = [](const std::wstring& desc) { - return (desc.find(L"NPU") != std::wstring::npos || - desc.find(L"Neural") != std::wstring::npos || - desc.find(L"AI Engine") != std::wstring::npos || - desc.find(L"VPU") != std::wstring::npos); - }; + // For now, require dxcore to identify an NPU. + // If we want to try and infer it from the description this _may_ work but is untested. + // const auto possible_npu = [](const std::wstring& desc) { + // return (desc.find(L"NPU") != std::wstring::npos || + // desc.find(L"Neural") != std::wstring::npos || + // desc.find(L"AI Engine") != std::wstring::npos || + // desc.find(L"VPU") != std::wstring::npos); + // }; // use description if no friendly name if (entry->description.empty()) { @@ -171,7 +177,7 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde } uint64_t npu_key = GetDeviceKey(*entry); - bool is_npu = npus.count(npu_key) > 0 || possible_npu(desc); + bool is_npu = npus.count(npu_key) > 0; // rely on dxcore to determine if something is an NPU if (guid == GUID_DEVCLASS_DISPLAY) { entry->type = OrtHardwareDeviceType_GPU; @@ -194,22 +200,17 @@ std::unordered_map GetDeviceInfoSetupApi(const std::unorde continue; } - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, - (PBYTE)buffer, sizeof(buffer), &size)) { - entry->vendor = std::wstring(buffer, wcslen(buffer)); + if (entry->type == OrtHardwareDeviceType_CPU) { + // get 12 byte string from CPUID. easier for a user to match this if they are explicitly picking a device. + std::string_view cpuid_vendor = CPUIDInfo::GetCPUIDInfo().GetCPUVendor(); + entry->vendor = std::wstring(cpuid_vendor.begin(), cpuid_vendor.end()); } - // Add the UI number if GPU. Helpful if user has integrated and discrete GPUs - if (entry->type == OrtHardwareDeviceType_GPU) { - DWORD ui_number = 0; - if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_UI_NUMBER, nullptr, - (PBYTE)&ui_number, sizeof(ui_number), &size)) { - // use value read. - } else { - // infer it as 0 if not set. + if (entry->vendor.empty()) { + if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr, + (PBYTE)buffer, sizeof(buffer), &size)) { + entry->vendor = std::wstring(buffer, wcslen(buffer)); } - - entry->metadata.emplace(L"SPDRP_UI_NUMBER", std::to_wstring(ui_number)); } } @@ -252,9 +253,7 @@ std::unordered_map GetDeviceInfoD3D12() { info.description = std::wstring(desc.Description); info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i); - info.metadata[L"VideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; - info.metadata[L"SharedSystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB"; + info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB"; } // iterate by high-performance GPU preference to add that info @@ -272,7 +271,7 @@ std::unordered_map GetDeviceInfoD3D12() { auto it = device_info.find(key); if (it != device_info.end()) { DeviceInfo& info = it->second; - info.metadata[L"HighPerformanceIndex"] = std::to_wstring(i); + info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i); } } @@ -405,25 +404,40 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor } } - std::wstring_convert > converter; // wstring to string - const auto device_to_ortdevice = [&converter]( + // our log output to std::wclog breaks with UTF8 chars that are not supported by the current code page. + // e.g. (TM) symbol. that stops ALL logging working on at least arm64. + // safest way to avoid that is to keep it to single byte chars. + // process the OrtHardwareDevice values this way so it can be safely logged. + // only the 'description' metadata is likely to be affected and that is mainly for diagnostic purposes. + const auto to_safe_string = [](const std::wstring& wstr) -> std::string { + std::string str(wstr.size(), ' '); + std::transform(wstr.begin(), wstr.end(), str.begin(), [](wchar_t wchar) { + if (wchar >= 0 && wchar <= 127) { + return static_cast(wchar); + } + return ' '; + }); + return str; + }; + + const auto device_to_ortdevice = [&to_safe_string]( DeviceInfo& device, std::unordered_map* extra_metadata = nullptr) { - OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, converter.to_bytes(device.vendor)}; + OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)}; if (!device.description.empty()) { - ortdevice.metadata.Add("Description", converter.to_bytes(device.description)); + ortdevice.metadata.Add("Description", to_safe_string(device.description)); } for (auto& [key, value] : device.metadata) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } if (extra_metadata) { // add any extra metadata from the dxcore info for (auto& [key, value] : *extra_metadata) { if (device.metadata.find(key) == device.metadata.end()) { - ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value)); + ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value)); } } } @@ -431,6 +445,7 @@ std::unordered_set DeviceDiscovery::DiscoverDevicesForPlatfor std::ostringstream oss; oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id << ", device_id:0x" << ortdevice.device_id + << ", vendor:" << ortdevice.vendor << ", type:" << std::dec << static_cast(ortdevice.type) << ", metadata: ["; for (auto& [key, value] : ortdevice.metadata.entries) { diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 3242be817881a..150575b3a9efc 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -57,6 +57,8 @@ struct OrtKeyValuePairs { keys.erase(key_iter); values.erase(values.begin() + idx); } + + entries.erase(iter); } } diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 0b116c2fa64f6..b1c0467da642e 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -366,6 +366,17 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = policy; + options->value.ep_selection_policy.delegate = delegate; + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, _In_ bool is_cancel) { API_IMPL_BEGIN diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index c515195c7e6bf..aa032f24f13c0 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -183,14 +183,14 @@ std::vector> EpLibraryInternal::CreateInterna // CPU EP internal_eps.push_back(CreateCpuEp()); -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - #if defined(USE_WEBGPU) internal_eps.push_back(CreateWebGpuEp()); #endif +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + return internal_eps; } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 2a395050636ba..1956654f4538b 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -348,8 +348,8 @@ class InferenceSession { /** * Initializes a previously loaded ONNX model. Initialization includes but is not - * limited to graph transformations, construction of kernels, etc. - * This method assumes that a method has been loaded previously. + * limited to graph transformations, construction of kernels, EP policy decisions, etc. + * This method assumes that a model has been loaded previously. * This API is thread-safe. * @return OK if success */ diff --git a/onnxruntime/core/session/model_compilation_options.cc b/onnxruntime/core/session/model_compilation_options.cc index 80ef18de5cfa3..c4a7c5262d03d 100644 --- a/onnxruntime/core/session/model_compilation_options.cc +++ b/onnxruntime/core/session/model_compilation_options.cc @@ -19,7 +19,9 @@ ModelCompilationOptions::ModelCompilationOptions(const OrtEnv& env, const OrtSes session_options_.value.ep_context_gen_options = session_options.value.GetEpContextGenerationOptions(); session_options_.value.ep_context_gen_options.enable = true; session_options_.value.ep_context_gen_options.overwrite_existing_output_file = true; - session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = true; + // defaulting to false to support wider usage. will log WARNING if compiling model with no context nodes. + // TODO: Add ability for user to explicitly set this. + session_options_.value.ep_context_gen_options.error_if_no_compiled_nodes = false; // Shouldn't fail because the key/value strings are below the maximum string length limits in ConfigOptions. ORT_ENFORCE(session_options_.value.config_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1").IsOK()); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 7e8d770871867..317676c90bc4f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3012,6 +3012,7 @@ static constexpr OrtApi ort_api_1_to_22 = { &OrtApis::UnregisterExecutionProviderLibrary, &OrtApis::GetEpDevices, &OrtApis::SessionOptionsAppendExecutionProvider_V2, + &OrtApis::SessionOptionsSetEpSelectionPolicy, &OrtApis::HardwareDevice_Type, &OrtApis::HardwareDevice_VendorId, @@ -3061,7 +3062,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 315, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.22.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0033eb0d604f2..7928f9b822cf0 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -575,6 +575,10 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt _In_reads_(num_op_options) const char* const* ep_option_vals, size_t num_ep_options); +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options, + _In_ OrtExecutionProviderDevicePolicy policy, + _In_opt_ EpSelectionDelegate* delegate); + // OrtHardwareDevice accessors. ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device); ORT_API(uint32_t, HardwareDevice_VendorId, _In_ const OrtHardwareDevice* device); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc new file mode 100644 index 0000000000000..565891fe2cdfb --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -0,0 +1,347 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/provider_policy_context.h" + +#include + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/ep_factory_internal.h" +#include "core/session/inference_session.h" +#include "core/session/inference_session_utils.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +namespace { +bool MatchesEpVendor(const OrtEpDevice* d) { + // TODO: Would be better to match on Id. Should the EP add that in EP metadata? + return d->device->vendor == d->ep_vendor; +} + +bool IsDiscreteDevice(const OrtEpDevice* d) { + if (d->device->type != OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + const auto& entries = d->device->metadata.entries; + if (auto it = entries.find("Discrete"); it != entries.end()) { + return it->second == "1"; + } + + return false; +} + +bool IsDefaultCpuEp(const OrtEpDevice* d) { + return d->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU && + d->ep_vendor == "Microsoft"; +} + +// Sort devices. NPU -> GPU -> CPU +// Within in type, vendor owned, not. +// Default CPU EP is last +std::vector OrderDevices(const std::vector& devices) { + std::vector sorted_devices(devices.begin(), devices.end()); + std::sort(sorted_devices.begin(), sorted_devices.end(), [](const OrtEpDevice* a, const OrtEpDevice* b) { + auto aDeviceType = a->device->type; + auto bDeviceType = b->device->type; + if (aDeviceType != bDeviceType) { + // NPU -> GPU -> CPU + // std::sort is ascending order, so NPU < GPU < CPU + + // one is an NPU + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + return false; + } + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return true; + } else if (bDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + return false; + } + + // this shouldn't be reachable as it would imply both are CPU + ORT_THROW("Unexpected combination of devices"); + } + + // both devices are the same + + if (aDeviceType == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + bool aDiscrete = IsDiscreteDevice(a); + bool bDiscrete = IsDiscreteDevice(b); + if (aDiscrete != bDiscrete) { + return aDiscrete == true; // prefer discrete + } + + // both discrete or both integrated + } + + // prefer device matching platform vendor + bool aVendor = MatchesEpVendor(a); + bool bVendor = MatchesEpVendor(b); + if (aVendor != bVendor) { + return aVendor == true; // prefer the device that matches the EP vendor + } + + // default CPU EP last + bool aIsDefaultCpuEp = IsDefaultCpuEp(a); + bool bIsDefaultCpuEp = IsDefaultCpuEp(b); + if (!aIsDefaultCpuEp && !bIsDefaultCpuEp) { + // neither are default CPU EP. both do/don't match vendor. + // TODO: implement tie-breaker for this scenario. arbitrarily prefer the first for now + return true; + } + + // one is the default CPU EP + return aIsDefaultCpuEp == false; // prefer the one that is not the default CPU EP + }); + + return sorted_devices; +} +} // namespace + +// Select execution providers based on the device policy and available devices and add to session +Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, + InferenceSession& sess) { + ORT_ENFORCE(options.value.ep_selection_policy.delegate == nullptr, + "EP selection delegate support is not implemented yet."); + + // Get the list of devices from the environment and order them. + // Ordered by preference within each type. NPU -> GPU -> NPU + // TODO: Should environment.cc do the ordering? + const auto& execution_devices = OrderDevices(env.GetOrtEpDevices()); + + // The list of devices selected by policies + std::vector devices_selected; + + // Run the delegate if it was passed in lieu of any other policy + if (options.value.ep_selection_policy.delegate) { + auto policy_fn = options.value.ep_selection_policy.delegate; + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); + std::array selected_devices{nullptr}; + + size_t num_selected = 0; + auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(), + nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected); + + // return or fall-through for both these cases + // going with explicit failure for now so it's obvious to user what is happening + if (status != nullptr) { + std::string delegate_error_msg = OrtApis::GetErrorMessage(status); // copy + OrtApis::ReleaseStatus(status); + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate failed: ", delegate_error_msg); + } + + if (num_selected == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); + } + } else { + // Create the selector for the chosen policy + std::unique_ptr selector; + switch (options.value.ep_selection_policy.policy) { + case OrtExecutionProviderDevicePolicy_DEFAULT: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_CPU: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_NPU: + case OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY: + case OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER: + selector = std::make_unique(); + break; + case OrtExecutionProviderDevicePolicy_PREFER_GPU: + case OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE: + selector = std::make_unique(); + break; + } + + // Execute policy + + selector->SelectProvidersForDevices(execution_devices, devices_selected); + } + + // Fail if we did not find any device matches + if (devices_selected.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "No execution providers selected. Please check the device policy and available devices."); + } + + // Configure the session options for the devices. This updates the SessionOptions in the InferenceSession with any + // EP options that have not been overridden by the user. + ORT_RETURN_IF_ERROR(AddEpDefaultOptionsToSession(sess, devices_selected)); + + // Create OrtSessionOptions for the CreateEp call. + // Once the InferenceSession is created, its SessionOptions is the source of truth and contains all the values from + // the user provided OrtSessionOptions. We do a copy for simplicity. The OrtSessionOptions instance goes away + // once we exit this function so an EP implementation should not use OrtSessionOptions after it returns from + // CreateEp. + auto& session_options = sess.GetMutableSessionOptions(); + OrtSessionOptions ort_so; + ort_so.value = session_options; + const auto& session_logger = sess.GetLogger(); + const OrtLogger& api_session_logger = *session_logger->ToExternal(); + + // Remove the ORT CPU EP if configured to do so + bool disable_ort_cpu_ep = ort_so.value.config_options.GetConfigEntry(kOrtSessionOptionsDisableCPUEPFallback) == "1"; + if (disable_ort_cpu_ep) { + RemoveOrtCpuDevice(devices_selected); + } + + // Fold the EPs into a single structure per factory + std::vector eps_selected; + FoldSelectedDevices(devices_selected, eps_selected); + + // Iterate through the selected EPs and create them + for (size_t idx = 0; idx < eps_selected.size(); ++idx) { + std::unique_ptr ep = nullptr; + ORT_RETURN_IF_ERROR(CreateExecutionProvider(env, ort_so, api_session_logger, eps_selected[idx], ep)); + if (ep != nullptr) { + ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::FoldSelectedDevices(std::vector devices_selected, + std::vector& eps_selected) { + while (devices_selected.size() > 0) { + const auto ep_name = std::string(devices_selected[0]->ep_name); + SelectionInfo info; + info.ep_factory = devices_selected[0]->ep_factory; + + do { + auto iter = std::find_if(devices_selected.begin(), devices_selected.end(), [&ep_name](const OrtEpDevice* d) { + return d->ep_name == ep_name; + }); + + if (iter != devices_selected.end()) { + info.devices.push_back((*iter)->device); + info.ep_metadata.push_back(&(*iter)->ep_metadata); + devices_selected.erase(iter); + } else { + break; + } + } while (true); + + eps_selected.push_back(info); + } +} + +Status ProviderPolicyContext::CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, + const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep) { + EpFactoryInternal* internal_factory = env.GetEpFactoryInternal(info.ep_factory); + + if (internal_factory) { + // this is a factory we created and registered internally for internal and provider bridge EPs + OrtStatus* status = internal_factory->CreateIExecutionProvider(info.devices.data(), info.ep_metadata.data(), + info.devices.size(), &options, &logger, + &ep); + if (status != nullptr) { + return ToStatus(status); + } + } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + // but first we need OrtEp and the OrtEpApi to be implemented. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + // OrtEp* api_ep = nullptr; + //// add the ep_options to session options but leave any existing entries (user provided overrides) untouched. + // auto status = info.ep_factory->CreateEp(info.ep_factory, info.devices.data(), info.ep_metadata.data(), + // info.devices.size(), &options, &logger, + // &api_ep); + // if (status != nullptr) { + // return ToStatus(status); + // } + } + + return Status::OK(); +} + +Status ProviderPolicyContext::AddEpDefaultOptionsToSession(InferenceSession& sess, + std::vector devices) { + auto& config_options = sess.GetMutableSessionOptions().config_options; + for (auto device : devices) { + const std::string ep_options_prefix = OrtSessionOptions::GetProviderOptionPrefix(device->ep_name.c_str()); + for (const auto& [key, value] : device->ep_options.entries) { + const std::string option_key = ep_options_prefix + key; + // preserve user-provided options as they override any defaults the EP factory specified earlier + if (config_options.configurations.find(option_key) == config_options.configurations.end()) { + // use AddConfigEntry for the error checking it does + ORT_RETURN_IF_ERROR(config_options.AddConfigEntry(option_key.c_str(), value.c_str())); + } + } + } + + return Status::OK(); +} + +void ProviderPolicyContext::RemoveOrtCpuDevice(std::vector& devices) { + // Remove the Microsoft CPU EP. always last if available. + if (IsDefaultCpuEp(devices.back())) { + devices.pop_back(); + } +} + +void DefaultEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Default policy is prefer CPU + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferCpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first CPU device from sorted devices + auto first_cpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU; + }); + + ORT_ENFORCE(first_cpu != sorted_devices.end(), "No CPU based execution providers were found."); + selected.push_back(*first_cpu); + + // add ORT CPU EP as the final option to ensure maximum coverage of opsets and operators + if (!IsDefaultCpuEp(*first_cpu) && IsDefaultCpuEp(sorted_devices.back())) { + selected.push_back(sorted_devices.back()); + } +} + +void PreferNpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first NPU if there is one. + if (sorted_devices.front()->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_NPU) { + selected.push_back(sorted_devices.front()); + } + + // CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} + +void PreferGpuEpPolicy::SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected) { + // Select the first GPU device + auto first_gpu = std::find_if(sorted_devices.begin(), sorted_devices.end(), + [](const OrtEpDevice* device) { + return device->device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU; + }); + + if (first_gpu != sorted_devices.end()) { + selected.push_back(*first_gpu); + } + + // Add a CPU fallback + PreferCpuEpPolicy().SelectProvidersForDevices(sorted_devices, selected); +} +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/provider_policy_context.h b/onnxruntime/core/session/provider_policy_context.h new file mode 100644 index 0000000000000..185f9523312ba --- /dev/null +++ b/onnxruntime/core/session/provider_policy_context.h @@ -0,0 +1,79 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) + +#include "core/session/abi_session_options_impl.h" +#include "core/session/environment.h" +#include "core/session/onnxruntime_c_api.h" // For OrtExecutionProviderDevicePolicy + +namespace onnxruntime { + +struct SelectionInfo { + OrtEpFactory* ep_factory; + std::vector devices; + std::vector ep_metadata; +}; + +class IEpPolicySelector { + public: + /// + /// Select the OrtEpDevice instances to use. + /// Selection is in priority order. Highest priority first. + /// + /// Ordered devices. + /// Type order is NPU -> GPU -> CPU + /// Within a type: Discrete -> Integrated if GPU, EP vendor matches device vendor, vendor does not match + /// ORT CPU EP is always last if available. + /// + /// + virtual void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) = 0; + + virtual ~IEpPolicySelector() = default; +}; + +class ProviderPolicyContext { + public: + ProviderPolicyContext() = default; + + Status SelectEpsForSession(const Environment& env, const OrtSessionOptions& options, InferenceSession& sess); + Status AddEpDefaultOptionsToSession(InferenceSession& sess, std::vector devices); + void RemoveOrtCpuDevice(std::vector& devices); + Status CreateExecutionProvider(const Environment& env, OrtSessionOptions& options, const OrtLogger& logger, + SelectionInfo& info, std::unique_ptr& ep); + void FoldSelectedDevices(std::vector devices_selected, // copy + std::vector& eps_selected); + + private: +}; + +class DefaultEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferCpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferNpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +class PreferGpuEpPolicy : public IEpPolicySelector { + public: + void SelectProvidersForDevices(const std::vector& sorted_devices, + std::vector& selected_devices) override; +}; + +} // namespace onnxruntime + +#endif // !ORT_MINIMAL_BUILD diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index adb019fdde86d..d17514e54a945 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -18,6 +18,7 @@ #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/provider_policy_context.h" using namespace onnxruntime; #if !defined(ORT_MINIMAL_BUILD) @@ -71,6 +72,11 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return ToStatus(status); } } else { + // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, + // and we would add that IExecutionProvider to the InferenceSession. + ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + + /* OrtEp* api_ep = nullptr; auto status = ep_device->ep_factory->CreateEp( ep_device->ep_factory, devices.data(), ep_metadata.data(), devices.size(), @@ -79,10 +85,7 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con if (status != nullptr) { return ToStatus(status); } - - // in the real setup we need an IExecutionProvider wrapper implementation that uses the OrtEp internally, - // and we would add that IExecutionProvider to the InferenceSession. - ORT_NOT_IMPLEMENTED("IExecutionProvider that wraps OrtEp has not been implemented."); + */ } ORT_RETURN_IF_ERROR(sess.RegisterExecutionProvider(std::move(ep))); @@ -175,6 +178,12 @@ OrtStatus* CreateSessionAndLoadModel(_In_ const OrtSessionOptions* options, if (auto_select_ep_name) { ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env->GetEnvironment(), *sess, *auto_select_ep_name)); } + + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection + if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env->GetEnvironment(), *options, *sess)); + } #endif // !defined(ORT_MINIMAL_BUILD) #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index 619f0a4bcda33..04b1b2ea0bdc4 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -64,7 +64,10 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod const std::vector& expected_dims_y, const std::vector& expected_values_y, bool auto_select = true, // auto select vs SessionOptionsAppendExecutionProvider_V2 + // manual select using functor const std::function&)>& select_devices = nullptr, + // auto select using policy + std::optional policy = std::nullopt, bool test_session_creation_only = false) { Ort::SessionOptions session_options; @@ -74,16 +77,20 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } if (auto_select) { - // manually specify EP to select for now - session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); - - // add the provider options to the session options with the required prefix - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); - std::vector keys, values; - ep_options.GetKeyValuePairs(keys, values); - for (size_t i = 0, end = keys.size(); i < end; ++i) { - // add the default value with prefix - session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + if (policy) { + session_options.SetEpSelectionPolicy(*policy); + } else { + // manually specify EP to select + session_options.AddConfigEntry("test.ep_to_select", ep_to_select.c_str()); + + // add the provider options to the session options with the required prefix + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_to_select.c_str()); + std::vector keys, values; + ep_options.GetKeyValuePairs(keys, values); + for (size_t i = 0, end = keys.size(); i < end; ++i) { + // add the default value with prefix + session_options.AddConfigEntry((option_prefix + keys[i]).c_str(), values[i]); + } } } else { std::vector devices; @@ -188,7 +195,7 @@ TEST(AutoEpSelection, DmlEP) { devices.push_back(ep_device); } else { // if this is available, 0 == best performance - auto* perf_index = c_api->GetKeyValue(kvps, "HighPerformanceIndex"); + auto* perf_index = c_api->GetKeyValue(kvps, "DxgiHighPerformanceIndex"); if (perf_index && strcmp(perf_index, "0") == 0) { devices[0] = ep_device; // replace as this is the higher performance device } @@ -213,20 +220,27 @@ TEST(AutoEpSelection, WebGpuEP) { TEST(AutoEpSelection, MiscApiTests) { const OrtApi* c_api = &Ort::GetApi(); - // nullptr and empty input to OrtKeyValuePairs + // nullptr and empty input to OrtKeyValuePairs. also test RemoveKeyValuePair { OrtKeyValuePairs* kvps = nullptr; c_api->CreateKeyValuePairs(&kvps); c_api->AddKeyValuePair(kvps, "key1", nullptr); // should be ignored c_api->AddKeyValuePair(kvps, nullptr, "value1"); // should be ignored c_api->RemoveKeyValuePair(kvps, nullptr); // should be ignored - - c_api->AddKeyValuePair(kvps, "", "value2"); // empty key should be ignored + c_api->AddKeyValuePair(kvps, "", "value2"); // should be ignored ASSERT_EQ(c_api->GetKeyValue(kvps, ""), nullptr); + c_api->AddKeyValuePair(kvps, "key1", "value1"); c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is allowed ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + c_api->RemoveKeyValuePair(kvps, "key1"); + const char* const* keys = nullptr; + const char* const* values = nullptr; + size_t num_entries = 0; + c_api->GetKeyValuePairs(kvps, &keys, &values, &num_entries); + ASSERT_EQ(num_entries, 1); + c_api->ReleaseKeyValuePairs(kvps); } @@ -259,6 +273,86 @@ TEST(AutoEpSelection, MiscApiTests) { } } +TEST(AutoEpSelection, PreferCpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_CPU); +} + +// this should fallback to CPU if no GPU +TEST(AutoEpSelection, PreferGpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_GPU); +} + +// this should fallback to CPU if no NPU +TEST(AutoEpSelection, PreferNpu) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU); +} + namespace { struct ExamplePluginInfo { const std::filesystem::path library_path =