Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,32 @@ typedef enum OrtExecutionProviderDevicePolicy {
OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER,
} OrtExecutionProviderDevicePolicy;

/** \brief Delegate to allow providing custom OrtEpDevice selection logic
*
* This delegate is called by the EP selection code to allow the user to provide custom device selection logic.
* The user can use this to select OrtEpDevice instances from the list of available devices.
*
* \param ep_devices The list of available devices.
* \param num_devices The number of available devices.
* \param model_metadata The model metadata.
* \param runtime_metadata The runtime metadata. May be nullptr.
* \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices.
* \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array.
Currently the maximum is 8.
* \param num_ep_devices The number of selected devices.
*
* \return OrtStatus* Selection status. Return nullptr on success.
* Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
* ORT will release the OrtStatus* if not null.
*/
typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices,
_In_ size_t num_devices,
_In_ const OrtKeyValuePairs* model_metadata,
_In_opt_ const OrtKeyValuePairs* runtime_metadata,
_Inout_ const OrtEpDevice** selected,
_In_ size_t max_selected,
_Out_ size_t* num_selected);

/** \brief Algorithm to use for cuDNN Convolution Op
*/
typedef enum OrtCudnnConvAlgoSearch {
Expand Down Expand Up @@ -5073,7 +5099,8 @@ struct OrtApi {
ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env,
_Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices);

/** \brief Append execution provider to the session options by name.
/** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances
* to the session options.
*
* \param[in] session_options Session options to add execution provider to.
* \param[in] env Environment that execution providers were registered with.
Expand All @@ -5098,6 +5125,21 @@ struct OrtApi {
_In_reads_(num_op_options) const char* const* ep_option_vals,
size_t num_ep_options);

/** \brief Set the execution provider selection policy for the session.
*
* Allows users to specify a device selection policy for automatic execution provider (EP) selection,
* or provide a delegate callback for custom selection logic.
*
* \param[in] session_options The OrtSessionOptions instance.
* \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy).
* \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy.
*
* \since Version 1.22
*/
ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options,
_In_ OrtExecutionProviderDevicePolicy policy,
_In_opt_ EpSelectionDelegate* delegate);

/** \brief Get the hardware device type.
*
* \param[in] device The OrtHardwareDevice instance to query.
Expand Down
12 changes: 10 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1085,19 +1085,27 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
SessionOptionsImpl& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptionsImpl& AppendExecutionProvider_TensorRT_V2(const OrtTensorRTProviderOptionsV2& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT
SessionOptionsImpl& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider_CANN
SessionOptionsImpl& AppendExecutionProvider_CANN(const OrtCANNProviderOptions& provider_options);
///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider_Dnnl
SessionOptionsImpl& AppendExecutionProvider_Dnnl(const OrtDnnlProviderOptions& provider_options);
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider. Currently supports QNN, SNPE and XNNPACK.
SessionOptionsImpl& AppendExecutionProvider(const std::string& provider_name,
const std::unordered_map<std::string, std::string>& provider_options = {});

/// Append EPs that have been registered previously with the OrtEnv.
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2
SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
const KeyValuePairs& ep_options);
/// Append EPs that have been registered previously with the OrtEnv.
/// Wraps OrtApi::SessionOptionsAppendExecutionProvider_V2
SessionOptionsImpl& AppendExecutionProvider_V2(Env& env, const std::vector<ConstEpDevice>& ep_devices,
const std::unordered_map<std::string, std::string>& ep_options);

/// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy
SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
EpSelectionDelegate* delegate = nullptr);

SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
SessionOptionsImpl& SetCustomJoinThreadFn(OrtCustomJoinThreadFn ort_custom_join_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomJoinThreadFn
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1149,6 +1149,13 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
EpSelectionDelegate* delegate) {
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) {
ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(this->p_, ort_custom_create_thread_fn));
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/common/cpuid_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ void CPUIDInfo::X86Init() {
GetCPUID(0, data);

vendor_ = GetX86Vendor(data);
vendor_id_ = GetVendorId(vendor_);

int num_IDs = data[0];
if (num_IDs >= 1) {
Expand Down Expand Up @@ -151,6 +152,14 @@ std::string CPUIDInfo::GetX86Vendor(int32_t* data) {

#endif // defined(CPUIDINFO_ARCH_X86)

uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) {
if (vendor == "GenuineIntel") return 0x8086;
if (vendor == "GenuineAMD") return 0x1022;
if (vendor.find("Qualcomm") == 0) return 'Q' << 24 | 'C' << 16 | 'O' << 8 | 'M';
if (vendor.find("NV") == 0) return 0x10DE;
return 0;
}

#if defined(CPUIDINFO_ARCH_ARM)

#if defined(__linux__)
Expand Down Expand Up @@ -204,6 +213,7 @@ void CPUIDInfo::ArmLinuxInit() {
void CPUIDInfo::ArmWindowsInit() {
// Get the ARM vendor string from the registry
vendor_ = GetArmWindowsVendor();
vendor_id_ = GetVendorId(vendor_);

// Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry
// There should be one per CPU
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/common/cpuid_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ class CPUIDInfo {
return vendor_;
}

uint32_t GetCPUVendorId() const {
return vendor_id_;
}

bool HasAMX_BF16() const { return has_amx_bf16_; }
bool HasAVX() const { return has_avx_; }
bool HasAVX2() const { return has_avx2_; }
Expand Down Expand Up @@ -123,6 +127,9 @@ class CPUIDInfo {
bool has_arm_neon_bf16_{false};

std::string vendor_;
uint32_t vendor_id_;

uint32_t GetVendorId(const std::string& vendor);

#if defined(CPUIDINFO_ARCH_X86)

Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,13 @@ static Status CreateEpContextModel(const ExecutionProviders& execution_providers
ORT_RETURN_IF(ep_context_gen_options.error_if_no_compiled_nodes,
"Compiled model does not contain any EPContext nodes. "
"Check that the session EPs support compilation and can execute at least one model subgraph.");
return Status::OK();

LOGS(logger, WARNING) << "Compiled model does not contain any EPContext nodes. "
"Either the session EPs do not support compilation or "
"no subgraphs were able to be compiled.";

// we continue on to generate the compiled model which may benefit from L1 optimizations even if there are not
// EPContext nodes.
}

auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,15 @@ struct EpContextModelGenerationOptions {
size_t output_external_initializer_size_threshold = 0;
};

struct EpSelectionPolicy {
// flag to detect that a policy was set by the user.
// need to preserve current behavior of defaulting to CPU EP if no EPs are explicitly registered
// and no selection policy was explicitly provided.
bool enable{false};
OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT;
EpSelectionDelegate* delegate{};
};

/**
* Configuration information for a session.
*/
Expand Down Expand Up @@ -222,6 +231,11 @@ struct SessionOptions {
// copied internally and the flag needs to be accessible across all copies.
std::shared_ptr<std::atomic_bool> load_cancellation_flag = std::make_shared<std::atomic_bool>(false);

// Policy to guide Execution Provider selection
EpSelectionPolicy ep_selection_policy = {false,
OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_DEFAULT,
nullptr};

// Options for generating compile EPContext models were previously stored in session_option.configs as
// string key/value pairs. To support more advanced options, such as setting input/output buffers, we
// now have to store EPContext options in a struct of type EpContextModelGenerationOptions.
Expand Down Expand Up @@ -253,6 +267,7 @@ inline std::ostream& operator<<(std::ostream& os, const SessionOptions& session_
<< " use_per_session_threads:" << session_options.use_per_session_threads
<< " thread_pool_allow_spinning:" << session_options.thread_pool_allow_spinning
<< " use_deterministic_compute:" << session_options.use_deterministic_compute
<< " ep_selection_policy:" << session_options.ep_selection_policy.policy
<< " config_options: { " << session_options.config_options << " }"
//<< " initializers_to_share_map:" << session_options.initializers_to_share_map
#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_EXTERNAL_INITIALIZERS)
Expand Down
85 changes: 50 additions & 35 deletions onnxruntime/core/platform/windows/device_discovery.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,13 @@

uint32_t vendor_id = get_id(buffer, L"VEN_");
uint32_t device_id = get_id(buffer, L"DEV_");
// won't always have a vendor id from an ACPI entry. need at least a device id to identify the hardware

// Processor ID should come from CPUID mapping.
if (vendor_id == 0 && guid == GUID_DEVCLASS_PROCESSOR) {
vendor_id = CPUIDInfo::GetCPUIDInfo().GetCPUVendorId();
}

// Won't always have a vendor id from an ACPI entry. ACPI is not defined for this purpose.
if (vendor_id == 0 && device_id == 0) {
continue;
}
Expand All @@ -138,8 +144,8 @@
entry = &device_info[key];
entry->vendor_id = vendor_id;
entry->device_id = device_id;
// put the first hardware id string in the metadata. ignore the other lines.
entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer)));
// put the first hardware id string in the metadata. ignore the other lines. not sure if this is of value.
// entry->metadata.emplace(L"SPDRP_HARDWAREID", std::wstring(buffer, wcslen(buffer)));
} else {
// need valid ids
continue;
Expand All @@ -156,22 +162,22 @@
(PBYTE)buffer, sizeof(buffer), &size)) {
std::wstring desc{buffer};

// Should we require the NPU to be found by DXCore or do we want to allow this vague matching?
// Probably depends on whether we always attempt to run DXCore or not.
const auto possible_npu = [](const std::wstring& desc) {
return (desc.find(L"NPU") != std::wstring::npos ||
desc.find(L"Neural") != std::wstring::npos ||
desc.find(L"AI Engine") != std::wstring::npos ||
desc.find(L"VPU") != std::wstring::npos);
};
// For now, require dxcore to identify an NPU.
// If we want to try and infer it from the description this _may_ work but is untested.
// const auto possible_npu = [](const std::wstring& desc) {
// return (desc.find(L"NPU") != std::wstring::npos ||
// desc.find(L"Neural") != std::wstring::npos ||
// desc.find(L"AI Engine") != std::wstring::npos ||
// desc.find(L"VPU") != std::wstring::npos);
// };

// use description if no friendly name
if (entry->description.empty()) {
entry->description = desc;
}

uint64_t npu_key = GetDeviceKey(*entry);
bool is_npu = npus.count(npu_key) > 0 || possible_npu(desc);
bool is_npu = npus.count(npu_key) > 0; // rely on dxcore to determine if something is an NPU

if (guid == GUID_DEVCLASS_DISPLAY) {
entry->type = OrtHardwareDeviceType_GPU;
Expand All @@ -194,22 +200,17 @@
continue;
}

if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr,
(PBYTE)buffer, sizeof(buffer), &size)) {
entry->vendor = std::wstring(buffer, wcslen(buffer));
if (entry->type == OrtHardwareDeviceType_CPU) {
// get 12 byte string from CPUID. easier for a user to match this if they are explicitly picking a device.
std::string_view cpuid_vendor = CPUIDInfo::GetCPUIDInfo().GetCPUVendor();
entry->vendor = std::wstring(cpuid_vendor.begin(), cpuid_vendor.end());
}

// Add the UI number if GPU. Helpful if user has integrated and discrete GPUs
if (entry->type == OrtHardwareDeviceType_GPU) {
DWORD ui_number = 0;
if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_UI_NUMBER, nullptr,
(PBYTE)&ui_number, sizeof(ui_number), &size)) {
// use value read.
} else {
// infer it as 0 if not set.
if (entry->vendor.empty()) {
if (SetupDiGetDeviceRegistryPropertyW(devInfo, &devData, SPDRP_MFG, nullptr,
(PBYTE)buffer, sizeof(buffer), &size)) {
entry->vendor = std::wstring(buffer, wcslen(buffer));
}

entry->metadata.emplace(L"SPDRP_UI_NUMBER", std::to_wstring(ui_number));
}
}

Expand Down Expand Up @@ -252,9 +253,7 @@
info.description = std::wstring(desc.Description);

info.metadata[L"DxgiAdapterNumber"] = std::to_wstring(i);
info.metadata[L"VideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB";
info.metadata[L"SystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB";
info.metadata[L"SharedSystemMemory"] = std::to_wstring(desc.DedicatedSystemMemory / (1024 * 1024)) + L" MB";
info.metadata[L"DxgiVideoMemory"] = std::to_wstring(desc.DedicatedVideoMemory / (1024 * 1024)) + L" MB";
}

// iterate by high-performance GPU preference to add that info
Expand All @@ -272,7 +271,7 @@
auto it = device_info.find(key);
if (it != device_info.end()) {
DeviceInfo& info = it->second;
info.metadata[L"HighPerformanceIndex"] = std::to_wstring(i);
info.metadata[L"DxgiHighPerformanceIndex"] = std::to_wstring(i);
}
}

Expand Down Expand Up @@ -405,32 +404,48 @@
}
}

std::wstring_convert<std::codecvt_utf8<wchar_t> > converter; // wstring to string
const auto device_to_ortdevice = [&converter](
// our log output to std::wclog breaks with UTF8 chars that are not supported by the current code page.
// e.g. (TM) symbol. that stops ALL logging working on at least arm64.
// safest way to avoid that is to keep it to single byte chars.
// process the OrtHardwareDevice values this way so it can be safely logged.
// only the 'description' metadata is likely to be affected and that is mainly for diagnostic purposes.
const auto to_safe_string = [](const std::wstring& wstr) -> std::string {
std::string str(wstr.size(), ' ');
std::transform(wstr.begin(), wstr.end(), str.begin(), [](wchar_t wchar) {

Check warning on line 414 in onnxruntime/core/platform/windows/device_discovery.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <algorithm> for transform [build/include_what_you_use] [4] Raw Output: onnxruntime/core/platform/windows/device_discovery.cc:414: Add #include <algorithm> for transform [build/include_what_you_use] [4]
if (wchar >= 0 && wchar <= 127) {
return static_cast<char>(wchar);
}
return ' ';
});
return str;
};

const auto device_to_ortdevice = [&to_safe_string](
DeviceInfo& device,
std::unordered_map<std::wstring, std::wstring>* extra_metadata = nullptr) {
OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, converter.to_bytes(device.vendor)};
OrtHardwareDevice ortdevice{device.type, device.vendor_id, device.device_id, to_safe_string(device.vendor)};

if (!device.description.empty()) {
ortdevice.metadata.Add("Description", converter.to_bytes(device.description));
ortdevice.metadata.Add("Description", to_safe_string(device.description));
}

for (auto& [key, value] : device.metadata) {
ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value));
ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value));
}

if (extra_metadata) {
// add any extra metadata from the dxcore info
for (auto& [key, value] : *extra_metadata) {
if (device.metadata.find(key) == device.metadata.end()) {
ortdevice.metadata.Add(converter.to_bytes(key), converter.to_bytes(value));
ortdevice.metadata.Add(to_safe_string(key), to_safe_string(value));
}
}
}

std::ostringstream oss;
oss << "Adding OrtHardwareDevice {vendor_id:0x" << std::hex << ortdevice.vendor_id
<< ", device_id:0x" << ortdevice.device_id
<< ", vendor:" << ortdevice.vendor
<< ", type:" << std::dec << static_cast<int>(ortdevice.type)
<< ", metadata: [";
for (auto& [key, value] : ortdevice.metadata.entries) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/abi_key_value_pairs.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ struct OrtKeyValuePairs {
keys.erase(key_iter);
values.erase(values.begin() + idx);
}

entries.erase(iter);
}
}

Expand Down
Loading
Loading