diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 620c13b8641b5..c543414ca13a9 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -353,6 +353,7 @@ public struct OrtApi public IntPtr SessionOptionsAppendExecutionProvider_V2; public IntPtr SessionOptionsSetEpSelectionPolicy; + public IntPtr SessionOptionsSetEpSelectionPolicyDelegate; public IntPtr HardwareDevice_Type; public IntPtr HardwareDevice_VendorId; @@ -692,6 +693,11 @@ static NativeMethods() (DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer( api_.SessionOptionsSetEpSelectionPolicy, typeof(DSessionOptionsSetEpSelectionPolicy)); + + OrtSessionOptionsSetEpSelectionPolicyDelegate = + (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer( + api_.SessionOptionsSetEpSelectionPolicyDelegate, + typeof(DSessionOptionsSetEpSelectionPolicyDelegate)); } internal class NativeLib @@ -2278,28 +2284,49 @@ out IntPtr lora_adapter #region Auto EP API related // // OrtKeyValuePairs + + /// + /// Create an OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps); + /// + /// Add/replace a key-value pair in the OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key, byte[] /* const char* */ value); + /// + /// Get the value for the provided key. + /// + /// Value. Returns IntPtr.Zero if key was not found. [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key); + /// + /// Get all the key-value pairs in the OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps, out IntPtr /* const char* const** */ keys, out IntPtr /* const char* const** */ values, out UIntPtr /* size_t* */ numEntries); + /// + /// Remove a key-value pair from the OrtKeyValuePairs instance. + /// Ignores keys that are not present. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ key); + /// + /// Release the OrtKeyValuePairs instance. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps); @@ -2370,12 +2397,27 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, // // Auto Selection EP registration and selection customization + + /// + /// Register an execution provider library. + /// The library must implement CreateEpFactories and ReleaseEpFactory. + /// + /// Environment to add the EP library to. + /// Name to register the library under. + /// Absolute path to the library. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary( IntPtr /* OrtEnv* */ env, byte[] /* const char* */ registration_name, byte[] /* const ORTCHAR_T* */ path); + /// + /// Unregister an execution provider library. + /// + /// The environment to unregister the library from. + /// The name the library was registered under. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary( IntPtr /* OrtEnv* */ env, @@ -2384,6 +2426,11 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; + /// + /// Get the OrtEpDevices that are available. + /// These are all the possible execution provider and device pairs. + /// + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices( IntPtr /* const OrtEnv* */ env, @@ -2392,6 +2439,20 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtGetEpDevices OrtGetEpDevices; + /// + /// Add execution provider devices to the session options. + /// Priority is based on the order of the OrtEpDevice instances. Highest priority first. + /// All OrtEpDevice instances in ep_devices must be for the same execution provider. + /// e.g. selecting OpenVINO for GPU and NPU would have an OrtEpDevice for GPU and NPU. + /// + /// SessionOptions to add to. + /// Environment that the OrtEpDevice instances came from by calling GetEpDevices + /// One or more OrtEpDevice instances. + /// Number of OrtEpDevice instances. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// User overrides for execution provider options. May be IntPtr.Zero. + /// Number of user overrides for execution provider options. + /// [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2( IntPtr /* OrtSessionOptions* */ sess_options, @@ -2404,6 +2465,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2; + /// + /// Delegate to do custom execution provider selection. + /// + /// Available OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Metadata from the ONNX model. + /// Runtime metadata. May be IntPtr.Zero. + /// OrtEpDevices that were selected. Pre-allocated array for delegate to update. + /// Maximum number of OrtEpDevices that can be selected. + /// Number of OrtEpDevices that were selected. + /// State that was provided in when the delegate was registered. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr DOrtEpSelectionDelegate( IntPtr /* OrtEpDevice** */ epDevices, @@ -2412,16 +2485,36 @@ public delegate IntPtr DOrtEpSelectionDelegate( IntPtr /* OrtKeyValuePairs* */ runtimeMetadata, IntPtr /* OrtEpDevice** */ selected, uint maxSelected, - out UIntPtr numSelected + out UIntPtr numSelected, + IntPtr /* void* */ state ); + /// + /// Set the execution provider selection policy. + /// + /// SessionOptions to set the policy for. + /// Selection policy. + /// OrtStatus* [UnmanagedFunctionPointer(CallingConvention.Winapi)] public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicy( IntPtr /* OrtSessionOptions* */ session_options, - int /* OrtExecutionProviderDevicePolicy */ policy, - IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate); + int /* OrtExecutionProviderDevicePolicy */ policy); public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy; + /// + /// Set the execution provider selection policy delegate. + /// + /// SessionOptions to set the policy for. + /// Selection policy delegate. + /// State that is passed through to the selection delegate. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicyDelegate( + IntPtr /* OrtSessionOptions* */ session_options, + IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate, + IntPtr /* void* */ state); + public static DSessionOptionsSetEpSelectionPolicyDelegate OrtSessionOptionsSetEpSelectionPolicyDelegate; + #endregion #region Misc API diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs index e3947d900214e..0318e08519128 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs @@ -10,18 +10,18 @@ namespace Microsoft.ML.OnnxRuntime /// Represents the combination of an execution provider and a hardware device /// that the execution provider can utilize. /// - public class OrtEpDevice : SafeHandle + public class OrtEpDevice { /// /// Construct an OrtEpDevice from an existing native OrtEpDevice instance. /// /// Native OrtEpDevice handle. internal OrtEpDevice(IntPtr epDeviceHandle) - : base(epDeviceHandle, ownsHandle: false) { + _handle = epDeviceHandle; } - internal IntPtr Handle => handle; + internal IntPtr Handle => _handle; /// /// The name of the execution provider. @@ -30,7 +30,7 @@ public string EpName { get { - IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(handle); + IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr); } } @@ -42,7 +42,7 @@ public string EpVendor { get { - IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(handle); + IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); } } @@ -54,7 +54,7 @@ public OrtKeyValuePairs EpMetadata { get { - return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(_handle)); } } @@ -65,7 +65,7 @@ public OrtKeyValuePairs EpOptions { get { - return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(_handle)); } } @@ -76,23 +76,11 @@ public OrtHardwareDevice HardwareDevice { get { - IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(handle); + IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(_handle); return new OrtHardwareDevice(devicePtr); } } - /// - /// Indicates whether the native handle is invalid. - /// - public override bool IsInvalid => handle == IntPtr.Zero; - - /// - /// No-op. OrtEpDevice is always read-only as the instance is owned by native ORT. - /// - /// True - protected override bool ReleaseHandle() - { - return true; - } + private readonly IntPtr _handle; } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs index 8e7caae90ff79..af7115a92285e 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs @@ -21,16 +21,16 @@ public enum OrtHardwareDeviceType /// /// Represents a hardware device that is available on the current system. /// - public class OrtHardwareDevice : SafeHandle + public class OrtHardwareDevice { /// /// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance. /// /// Native OrtHardwareDevice handle. - internal OrtHardwareDevice(IntPtr deviceHandle) - : base(deviceHandle, ownsHandle: false) + internal OrtHardwareDevice(IntPtr deviceHandle) { + _handle = deviceHandle; } /// @@ -40,7 +40,7 @@ public OrtHardwareDeviceType Type { get { - return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(handle); + return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(_handle); } } @@ -54,7 +54,7 @@ public uint VendorId { get { - return NativeMethods.OrtHardwareDevice_VendorId(handle); + return NativeMethods.OrtHardwareDevice_VendorId(_handle); } } @@ -65,7 +65,7 @@ public string Vendor { get { - IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(handle); + IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(_handle); return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr); } } @@ -82,7 +82,7 @@ public uint DeviceId { get { - return NativeMethods.OrtHardwareDevice_DeviceId(handle); + return NativeMethods.OrtHardwareDevice_DeviceId(_handle); } } @@ -95,22 +95,10 @@ public OrtKeyValuePairs Metadata { get { - return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(handle)); + return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(_handle)); } } - /// - /// Indicates whether the native handle is invalid. - /// - public override bool IsInvalid => handle == IntPtr.Zero; - - /// - /// No-op. OrtHardwareDevice is always read-only as the instance is owned by native ORT. - /// - /// True - protected override bool ReleaseHandle() - { - return true; - } + private readonly IntPtr _handle; } } \ No newline at end of file diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs index de6189e105f78..d60bf75ccbd7c 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs @@ -636,9 +636,32 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue) public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy) { NativeApiStatus.VerifySuccess( - NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy, IntPtr.Zero)); + NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy)); } + /// + /// Set the execution provider selection policy if using automatic execution provider selection. + /// Execution providers must be registered with the OrtEnv to be available for selection. + /// + /// Delegate that implements the custom selection policy. + public void SetEpSelectionPolicyDelegate(EpSelectionDelegate selectionDelegate = null) + { + _epSelectionPolicyConnector = new EpSelectionPolicyConnector(selectionDelegate); + _epSelectionPolicyDelegate = new NativeMethods.DOrtEpSelectionDelegate( + EpSelectionPolicyConnector.EpSelectionPolicyWrapper); + + // make sure these stay alive. not sure if this is necessary when they're class members though + _epSelectionPolicyConnectorHandle = GCHandle.Alloc(_epSelectionPolicyConnector); + _epSelectionPolicyDelegateHandle = GCHandle.Alloc(_epSelectionPolicyDelegate); + + IntPtr funcPtr = Marshal.GetFunctionPointerForDelegate(_epSelectionPolicyDelegate); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtSessionOptionsSetEpSelectionPolicyDelegate( + handle, + funcPtr, + GCHandle.ToIntPtr(_epSelectionPolicyConnectorHandle))); + } #endregion internal IntPtr Handle @@ -914,7 +937,98 @@ public void SetLoadCancellationFlag(bool value) { NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value)); } + #endregion + + #region Selection Policy Delegate helpers + /// + /// Delegate to select execution provider devices from a list of available devices. + /// + /// OrtEpDevices to select from. + /// Model metadata. + /// Runtime metadata. + /// Maximum number of devices that can be selected. + /// Selected devices. Ordered by priority. Highest priority first. + public delegate List EpSelectionDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections); + + /// + /// Class to bridge the C# and native worlds for the EP selection policy delegate + /// + internal class EpSelectionPolicyConnector + { + private readonly EpSelectionDelegate _csharpDelegate; + + internal EpSelectionPolicyConnector(EpSelectionDelegate selectionDelegate) + { + _csharpDelegate = selectionDelegate; + } + + /// + /// Delegate to convert between the C and C# worlds + /// + /// OrtEpDevices to select from. + /// Number of OrtEpDevices. + /// Model metadata. + /// Runtime metadata. + /// Pre-allocated OrtEpDevice buffer to update with selected devices. + /// Number of entries in selectedOut. + /// Number of OrtEpDevies that were selected. + /// Opaque state. + /// nullptr for OrtStatus* to indicate success. + /// Currently we don't have a way to create an OrtStatus instance from the C# bindings. + /// Can add if we need to return an explicit error message. + /// + public static IntPtr EpSelectionPolicyWrapper(IntPtr /* OrtEpDevice** */ epDevicesIn, + uint numDevices, + IntPtr /* OrtKeyValuePairs* */ modelMetadataIn, + IntPtr /* OrtKeyValuePairs* */ runtimeMetadataIn, + IntPtr /* OrtEpDevice** */ selectedOut, + uint maxSelected, + out UIntPtr numSelected, + IntPtr state) + { + Span epDevicesIntPtrs; + Span selectedDevicesIntPtrs; + EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target; + + unsafe + { + void* ptr = epDevicesIn.ToPointer(); + epDevicesIntPtrs = new Span(ptr, checked((int)numDevices)); + } + + List epDevices = new List(); + for (int i = 0; i < numDevices; i++) + { + + epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i])); + } + + OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn); + OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn); + var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected); + + numSelected = (UIntPtr)selected.Count; + + unsafe + { + void* ptr = selectedOut.ToPointer(); + selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected); + } + + int idx = 0; + foreach (var epDevice in selected) + { + selectedDevicesIntPtrs[idx] = epDevice.Handle; + idx++; + } + + return IntPtr.Zero; + } + } #endregion #region Private Methods @@ -1000,8 +1114,43 @@ protected override bool ReleaseHandle() { NativeMethods.OrtReleaseSessionOptions(handle); handle = IntPtr.Zero; + + if (_epSelectionPolicyConnectorHandle.IsAllocated) + { + _epSelectionPolicyConnectorHandle.Free(); + _epSelectionPolicyConnector = null; + } + + if (_epSelectionPolicyDelegateHandle.IsAllocated) + { + _epSelectionPolicyDelegateHandle.Free(); + _epSelectionPolicyDelegate = null; + } + + return true; } #endregion + + /// + /// Helper class to connect C and C# usage of the EP selection policy delegate. + /// + EpSelectionPolicyConnector _epSelectionPolicyConnector = null; + + /// + /// Handle to the EP selection policy connector that is passed to the C API as state for the + /// EP selection policy delegate. + /// + GCHandle _epSelectionPolicyConnectorHandle = default; + + /// + /// Delegate instance that is provided to the C API. + /// + NativeMethods.DOrtEpSelectionDelegate _epSelectionPolicyDelegate = null; + + /// + /// Handle to the EP selection policy delegate that is passed to the C API. + /// + GCHandle _epSelectionPolicyDelegateHandle = default; } } diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs index 1aa4db15d275c..d95a649bd95c5 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -93,7 +93,7 @@ public void AppendToSessionOptionsV2() { var runTest = (Func> getEpOptions) => { - SessionOptions sessionOptions = new SessionOptions(); + using SessionOptions sessionOptions = new SessionOptions(); sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; var epDevices = ortEnvInstance.GetEpDevices(); @@ -138,7 +138,7 @@ public void AppendToSessionOptionsV2() [Fact] public void SetEpSelectionPolicy() { - SessionOptions sessionOptions = new SessionOptions(); + using SessionOptions sessionOptions = new SessionOptions(); sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; var epDevices = ortEnvInstance.GetEpDevices(); @@ -150,7 +150,50 @@ public void SetEpSelectionPolicy() var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); // session should load successfully - using (var session = new InferenceSession(model)) + using (var session = new InferenceSession(model, sessionOptions)) + { + Assert.NotNull(session); + } + } + + private static List SelectionPolicyDelegate(IReadOnlyList epDevices, + OrtKeyValuePairs modelMetadata, + OrtKeyValuePairs runtimeMetadata, + uint maxSelections) + { + Assert.NotEmpty(modelMetadata.Entries); + Assert.True(epDevices.Count > 0); + + // select first device and last (if there are more than one). + var selected = new List(); + + selected.Add(epDevices[0]); + + // add ORT CPU EP which is always last. + if (maxSelections > 2 && epDevices.Count > 1) + { + selected.Add(epDevices.Last()); + } + + return selected; + } + + [Fact] + public void SetEpSelectionPolicyDelegate() + { + using SessionOptions sessionOptions = new SessionOptions(); + sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE; + + var epDevices = ortEnvInstance.GetEpDevices(); + Assert.NotEmpty(epDevices); + + // doesn't matter what the value is. should fallback to ORT CPU EP + sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegate); + + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); + + // session should load successfully + using (var session = new InferenceSession(model, sessionOptions)) { Assert.NotNull(session); } diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index cef5eab9a505e..6c7d910b4963b 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -438,18 +438,20 @@ typedef enum OrtExecutionProviderDevicePolicy { * \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array. Currently the maximum is 8. * \param num_ep_devices The number of selected devices. + * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. * * \return OrtStatus* Selection status. Return nullptr on success. * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. * ORT will release the OrtStatus* if not null. */ -typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, - _In_ size_t num_devices, - _In_ const OrtKeyValuePairs* model_metadata, - _In_opt_ const OrtKeyValuePairs* runtime_metadata, - _Inout_ const OrtEpDevice** selected, - _In_ size_t max_selected, - _Out_ size_t* num_selected); +typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state); /** \brief Algorithm to use for cuDNN Convolution Op */ @@ -5127,18 +5129,30 @@ struct OrtApi { /** \brief Set the execution provider selection policy for the session. * - * Allows users to specify a device selection policy for automatic execution provider (EP) selection, - * or provide a delegate callback for custom selection logic. + * Allows users to specify a device selection policy for automatic execution provider (EP) selection. + * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead. * * \param[in] session_options The OrtSessionOptions instance. * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). - * \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy. * * \since Version 1.22 */ ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate); + _In_ OrtExecutionProviderDevicePolicy policy); + + /** \brief Set the execution provider selection policy delegate for the session. + * + * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection. + * + * \param[in] session_options The OrtSessionOptions instance. + * \param[in] delegate Delegate callback for custom selection. + * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required. + * + * \since Version 1.22 + */ + ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* delegate_state); /** \brief Get the hardware device type. * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 6c175c606b4a1..bc6f381bb82a0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1103,8 +1103,10 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl { const std::unordered_map& ep_options); /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy - SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, - EpSelectionDelegate* delegate = nullptr); + SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy); + + /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicyDelegate + SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr); SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 1fdb8f16d9600..94ad2118fa4d6 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1150,9 +1150,14 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2( } template -inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy, - EpSelectionDelegate* delegate) { - ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate)); +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy)); + return *this; +} + +template +inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) { + ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state)); return *this; } diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index 8f8a3d6634a7e..89a43c4f71ee6 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -96,7 +96,8 @@ struct EpSelectionPolicy { // and no selection policy was explicitly provided. bool enable{false}; OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT; - EpSelectionDelegate* delegate{}; + EpSelectionDelegate delegate{}; + void* state{nullptr}; // state for the delegate }; /** diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index b1c0467da642e..c205e05baadb9 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -367,12 +367,24 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions* } ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate) { + _In_ OrtExecutionProviderDevicePolicy policy) { API_IMPL_BEGIN options->value.ep_selection_policy.enable = true; options->value.ep_selection_policy.policy = policy; + options->value.ep_selection_policy.delegate = nullptr; + options->value.ep_selection_policy.state = nullptr; + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* options, + _In_opt_ EpSelectionDelegate delegate, + _In_opt_ void* state) { + API_IMPL_BEGIN + options->value.ep_selection_policy.enable = true; + options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT; options->value.ep_selection_policy.delegate = delegate; + options->value.ep_selection_policy.state = state; return nullptr; API_IMPL_END } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 8ec7312cc6354..df70856a64e99 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3266,6 +3266,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod // save model metadata model_metadata_.producer_name = model.ProducerName(); + model_metadata_.producer_version = model.ProducerVersion(); model_metadata_.description = model.DocString(); model_metadata_.graph_description = model.GraphDocString(); model_metadata_.domain = model.Domain(); @@ -3430,6 +3431,10 @@ const Model& InferenceSession::GetModel() const { return *model_; } +const Environment& InferenceSession::GetEnvironment() const { + return environment_; +} + SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) { ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK()); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ba9812a59fec3..51350390a0456 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -80,6 +80,7 @@ struct ModelMetadata { ModelMetadata& operator=(const ModelMetadata&) = delete; std::string producer_name; + std::string producer_version; std::string graph_name; std::string domain; std::string description; @@ -603,6 +604,7 @@ class InferenceSession { #endif const Model& GetModel() const; + const Environment& GetEnvironment() const; protected: #if !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 304966605c9cf..d03b98a9c1eb5 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2980,6 +2980,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetEpDevices, &OrtApis::SessionOptionsAppendExecutionProvider_V2, &OrtApis::SessionOptionsSetEpSelectionPolicy, + &OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, &OrtApis::HardwareDevice_Type, &OrtApis::HardwareDevice_VendorId, @@ -3029,7 +3030,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 7be518a39480f..47d1a543b5a31 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -577,8 +577,11 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt size_t num_ep_options); ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options, - _In_ OrtExecutionProviderDevicePolicy policy, - _In_opt_ EpSelectionDelegate* delegate); + _In_ OrtExecutionProviderDevicePolicy policy); + +ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* sess_options, + _In_ EpSelectionDelegate delegate, + _In_opt_ void* state); // OrtHardwareDevice accessors. ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device); diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 4ce13fe36ea86..f706bd05d8494 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -94,8 +94,8 @@ std::vector OrderDevices(const std::vectorep_name < b->ep_name; } // one is the default CPU EP @@ -104,31 +104,57 @@ std::vector OrderDevices(const std::vector GPU -> NPU // TODO: Should environment.cc do the ordering? - const auto& execution_devices = OrderDevices(env.GetOrtEpDevices()); + std::vector execution_devices = OrderDevices(env.GetOrtEpDevices()); // The list of devices selected by policies std::vector devices_selected; // Run the delegate if it was passed in lieu of any other policy if (options.value.ep_selection_policy.delegate) { - auto policy_fn = options.value.ep_selection_policy.delegate; + auto model_metadata = GetModelMetadata(sess); + OrtKeyValuePairs runtime_metadata; // TODO: where should this come from? + std::vector delegate_devices(execution_devices.begin(), execution_devices.end()); std::array selected_devices{nullptr}; - size_t num_selected = 0; - auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(), - nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected); + + EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate; + auto* status = delegate(delegate_devices.data(), delegate_devices.size(), + &model_metadata, &runtime_metadata, + selected_devices.data(), selected_devices.size(), &num_selected, + options.value.ep_selection_policy.state); // return or fall-through for both these cases // going with explicit failure for now so it's obvious to user what is happening @@ -142,6 +168,12 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const if (num_selected == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything."); } + + // Copy the selected devices to the output vector + devices_selected.reserve(num_selected); + for (size_t i = 0; i < num_selected; ++i) { + devices_selected.push_back(selected_devices[i]); + } } else { // Create the selector for the chosen policy std::unique_ptr selector; diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index c05394039d8c7..8ca4ef6af1f44 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -176,20 +176,6 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op env); } -#if !defined(ORT_MINIMAL_BUILD) - // TEMPORARY for testing. Manually specify the EP to select. - auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); - if (auto_select_ep_name) { - ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env, *sess, *auto_select_ep_name)); - } - - // if there are no providers registered, and there's an ep selection policy set, do auto ep selection - if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) { - ProviderPolicyContext context; - ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env, *options, *sess)); - } -#endif // !defined(ORT_MINIMAL_BUILD) - #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Add custom domains if (options && !options->custom_op_domains_.empty()) { @@ -231,22 +217,38 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, ORT_ENFORCE(session_logger != nullptr, "Session logger is invalid, but should have been initialized during session construction."); - // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of - // byte addressable memory - std::vector> provider_list; - if (options) { + const bool has_provider_factories = options != nullptr && !options->provider_factories.empty(); + + if (has_provider_factories) { + std::vector> provider_list; for (auto& factory : options->provider_factories) { auto provider = factory->CreateProvider(*options, *session_logger->ToExternal()); provider_list.push_back(std::move(provider)); } + + // register the providers + for (auto& provider : provider_list) { + if (provider) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + } + } } +#if !defined(ORT_MINIMAL_BUILD) + else { + // TEMPORARY for testing. Manually specify the EP to select. + auto auto_select_ep_name = sess.GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select"); + if (auto_select_ep_name) { + ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(sess.GetEnvironment(), sess, *auto_select_ep_name)); + } - // register the providers - for (auto& provider : provider_list) { - if (provider) { - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider))); + // if there are no providers registered, and there's an ep selection policy set, do auto ep selection. + // note: the model has already been loaded so model metadata should be available to the policy delegate callback. + if (options != nullptr && options->value.ep_selection_policy.enable) { + ProviderPolicyContext context; + ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(sess.GetEnvironment(), *options, sess)); } } +#endif // !defined(ORT_MINIMAL_BUILD) if (prepacked_weights_container != nullptr) { ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer( diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index 04b1b2ea0bdc4..cea1299adc26f 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -68,6 +68,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod const std::function&)>& select_devices = nullptr, // auto select using policy std::optional policy = std::nullopt, + std::optional delegate = std::nullopt, bool test_session_creation_only = false) { Ort::SessionOptions session_options; @@ -77,7 +78,9 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod } if (auto_select) { - if (policy) { + if (delegate) { + session_options.SetEpSelectionPolicy(*delegate, nullptr); + } else if (policy) { session_options.SetEpSelectionPolicy(*policy); } else { // manually specify EP to select @@ -353,6 +356,150 @@ TEST(AutoEpSelection, PreferNpu) { OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU); } +static OrtStatus* ORT_API_CALL PolicyDelegate(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + if (max_selected <= 2) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Expected to be able to select 2 devices."); + } + + if (model_metadata->entries.empty()) { + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Model metadata was empty."); + } + + selected[0] = ep_devices[0]; + *num_selected = 1; + if (num_devices > 1) { + // CPU EP is always last. + selected[1] = ep_devices[num_devices - 1]; + *num_selected = 2; + } + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateSelectNone(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return nullptr; +} + +static OrtStatus* ORT_API_CALL PolicyDelegateReturnError(_In_ const OrtEpDevice** /*ep_devices*/, + _In_ size_t /*num_devices*/, + _In_ const OrtKeyValuePairs* /*model_metadata*/, + _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/, + _Inout_ const OrtEpDevice** /*selected*/, + _In_ size_t /*max_selected*/, + _Out_ size_t* num_selected, + _In_ void* /*state*/) { + *num_selected = 0; + + return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Selection error."); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegate) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegate); +} + +// test providing a delegate +TEST(AutoEpSelection, PolicyDelegateSelectsNothing) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateSelectNone, + /*test_session_creation_only*/ true), + Ort::Exception); +} + +TEST(AutoEpSelection, PolicyDelegateReturnsError) { + std::vector> inputs(1); + auto& input = inputs.back(); + input.name = "X"; + input.dims = {3, 2}; + input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {3, 2}; + std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + const Ort::KeyValuePairs provider_options; + + ASSERT_THROW( + TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), + "", // don't need EP name + std::nullopt, + provider_options, + inputs, + "Y", + expected_dims_y, + expected_values_y, + /* auto_select */ true, + /*select_devices*/ nullptr, + std::nullopt, + PolicyDelegateReturnError, + /*test_session_creation_only*/ true), + Ort::Exception); +} + namespace { struct ExamplePluginInfo { const std::filesystem::path library_path =