diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
index 620c13b8641b5..c543414ca13a9 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs
@@ -353,6 +353,7 @@ public struct OrtApi
public IntPtr SessionOptionsAppendExecutionProvider_V2;
public IntPtr SessionOptionsSetEpSelectionPolicy;
+ public IntPtr SessionOptionsSetEpSelectionPolicyDelegate;
public IntPtr HardwareDevice_Type;
public IntPtr HardwareDevice_VendorId;
@@ -692,6 +693,11 @@ static NativeMethods()
(DSessionOptionsSetEpSelectionPolicy)Marshal.GetDelegateForFunctionPointer(
api_.SessionOptionsSetEpSelectionPolicy,
typeof(DSessionOptionsSetEpSelectionPolicy));
+
+ OrtSessionOptionsSetEpSelectionPolicyDelegate =
+ (DSessionOptionsSetEpSelectionPolicyDelegate)Marshal.GetDelegateForFunctionPointer(
+ api_.SessionOptionsSetEpSelectionPolicyDelegate,
+ typeof(DSessionOptionsSetEpSelectionPolicyDelegate));
}
internal class NativeLib
@@ -2278,28 +2284,49 @@ out IntPtr lora_adapter
#region Auto EP API related
//
// OrtKeyValuePairs
+
+ ///
+ /// Create an OrtKeyValuePairs instance.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtCreateKeyValuePairs(out IntPtr /* OrtKeyValuePairs** */ kvps);
+ ///
+ /// Add/replace a key-value pair in the OrtKeyValuePairs instance.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ key,
byte[] /* const char* */ value);
+ ///
+ /// Get the value for the provided key.
+ ///
+ /// Value. Returns IntPtr.Zero if key was not found.
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* const char* */ DOrtGetKeyValue(IntPtr /* const OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ key);
+ ///
+ /// Get all the key-value pairs in the OrtKeyValuePairs instance.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtGetKeyValuePairs(IntPtr /* const OrtKeyValuePairs* */ kvps,
out IntPtr /* const char* const** */ keys,
out IntPtr /* const char* const** */ values,
out UIntPtr /* size_t* */ numEntries);
+ ///
+ /// Remove a key-value pair from the OrtKeyValuePairs instance.
+ /// Ignores keys that are not present.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
byte[] /* const char* */ key);
+ ///
+ /// Release the OrtKeyValuePairs instance.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate void DOrtReleaseKeyValuePairs(IntPtr /* OrtKeyValuePairs* */ kvps);
@@ -2370,12 +2397,27 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
//
// Auto Selection EP registration and selection customization
+
+ ///
+ /// Register an execution provider library.
+ /// The library must implement CreateEpFactories and ReleaseEpFactory.
+ ///
+ /// Environment to add the EP library to.
+ /// Name to register the library under.
+ /// Absolute path to the library.
+ /// OrtStatus*
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibrary(
IntPtr /* OrtEnv* */ env,
byte[] /* const char* */ registration_name,
byte[] /* const ORTCHAR_T* */ path);
+ ///
+ /// Unregister an execution provider library.
+ ///
+ /// The environment to unregister the library from.
+ /// The name the library was registered under.
+ /// OrtStatus*
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtUnregisterExecutionProviderLibrary(
IntPtr /* OrtEnv* */ env,
@@ -2384,6 +2426,11 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary;
public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary;
+ ///
+ /// Get the OrtEpDevices that are available.
+ /// These are all the possible execution provider and device pairs.
+ ///
+ /// OrtStatus*
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtGetEpDevices(
IntPtr /* const OrtEnv* */ env,
@@ -2392,6 +2439,20 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
public static DOrtGetEpDevices OrtGetEpDevices;
+ ///
+ /// Add execution provider devices to the session options.
+ /// Priority is based on the order of the OrtEpDevice instances. Highest priority first.
+ /// All OrtEpDevice instances in ep_devices must be for the same execution provider.
+ /// e.g. selecting OpenVINO for GPU and NPU would have an OrtEpDevice for GPU and NPU.
+ ///
+ /// SessionOptions to add to.
+ /// Environment that the OrtEpDevice instances came from by calling GetEpDevices
+ /// One or more OrtEpDevice instances.
+ /// Number of OrtEpDevice instances.
+ /// User overrides for execution provider options. May be IntPtr.Zero.
+ /// User overrides for execution provider options. May be IntPtr.Zero.
+ /// Number of user overrides for execution provider options.
+ ///
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DOrtSessionOptionsAppendExecutionProvider_V2(
IntPtr /* OrtSessionOptions* */ sess_options,
@@ -2404,6 +2465,18 @@ public delegate void DOrtRemoveKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps,
public static DOrtSessionOptionsAppendExecutionProvider_V2 OrtSessionOptionsAppendExecutionProvider_V2;
+ ///
+ /// Delegate to do custom execution provider selection.
+ ///
+ /// Available OrtEpDevices to select from.
+ /// Number of OrtEpDevices.
+ /// Metadata from the ONNX model.
+ /// Runtime metadata. May be IntPtr.Zero.
+ /// OrtEpDevices that were selected. Pre-allocated array for delegate to update.
+ /// Maximum number of OrtEpDevices that can be selected.
+ /// Number of OrtEpDevices that were selected.
+ /// State that was provided in when the delegate was registered.
+ /// OrtStatus*
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr DOrtEpSelectionDelegate(
IntPtr /* OrtEpDevice** */ epDevices,
@@ -2412,16 +2485,36 @@ public delegate IntPtr DOrtEpSelectionDelegate(
IntPtr /* OrtKeyValuePairs* */ runtimeMetadata,
IntPtr /* OrtEpDevice** */ selected,
uint maxSelected,
- out UIntPtr numSelected
+ out UIntPtr numSelected,
+ IntPtr /* void* */ state
);
+ ///
+ /// Set the execution provider selection policy.
+ ///
+ /// SessionOptions to set the policy for.
+ /// Selection policy.
+ /// OrtStatus*
[UnmanagedFunctionPointer(CallingConvention.Winapi)]
public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicy(
IntPtr /* OrtSessionOptions* */ session_options,
- int /* OrtExecutionProviderDevicePolicy */ policy,
- IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate);
+ int /* OrtExecutionProviderDevicePolicy */ policy);
public static DSessionOptionsSetEpSelectionPolicy OrtSessionOptionsSetEpSelectionPolicy;
+ ///
+ /// Set the execution provider selection policy delegate.
+ ///
+ /// SessionOptions to set the policy for.
+ /// Selection policy delegate.
+ /// State that is passed through to the selection delegate.
+ /// OrtStatus*
+ [UnmanagedFunctionPointer(CallingConvention.Winapi)]
+ public delegate IntPtr /* OrtStatus* */ DSessionOptionsSetEpSelectionPolicyDelegate(
+ IntPtr /* OrtSessionOptions* */ session_options,
+ IntPtr /* DOrtEpSelectionDelegate* */ selection_delegate,
+ IntPtr /* void* */ state);
+ public static DSessionOptionsSetEpSelectionPolicyDelegate OrtSessionOptionsSetEpSelectionPolicyDelegate;
+
#endregion
#region Misc API
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
index e3947d900214e..0318e08519128 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEpDevice.shared.cs
@@ -10,18 +10,18 @@ namespace Microsoft.ML.OnnxRuntime
/// Represents the combination of an execution provider and a hardware device
/// that the execution provider can utilize.
///
- public class OrtEpDevice : SafeHandle
+ public class OrtEpDevice
{
///
/// Construct an OrtEpDevice from an existing native OrtEpDevice instance.
///
/// Native OrtEpDevice handle.
internal OrtEpDevice(IntPtr epDeviceHandle)
- : base(epDeviceHandle, ownsHandle: false)
{
+ _handle = epDeviceHandle;
}
- internal IntPtr Handle => handle;
+ internal IntPtr Handle => _handle;
///
/// The name of the execution provider.
@@ -30,7 +30,7 @@ public string EpName
{
get
{
- IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(handle);
+ IntPtr namePtr = NativeMethods.OrtEpDevice_EpName(_handle);
return NativeOnnxValueHelper.StringFromNativeUtf8(namePtr);
}
}
@@ -42,7 +42,7 @@ public string EpVendor
{
get
{
- IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(handle);
+ IntPtr vendorPtr = NativeMethods.OrtEpDevice_EpVendor(_handle);
return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr);
}
}
@@ -54,7 +54,7 @@ public OrtKeyValuePairs EpMetadata
{
get
{
- return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(handle));
+ return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpMetadata(_handle));
}
}
@@ -65,7 +65,7 @@ public OrtKeyValuePairs EpOptions
{
get
{
- return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(handle));
+ return new OrtKeyValuePairs(NativeMethods.OrtEpDevice_EpOptions(_handle));
}
}
@@ -76,23 +76,11 @@ public OrtHardwareDevice HardwareDevice
{
get
{
- IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(handle);
+ IntPtr devicePtr = NativeMethods.OrtEpDevice_Device(_handle);
return new OrtHardwareDevice(devicePtr);
}
}
- ///
- /// Indicates whether the native handle is invalid.
- ///
- public override bool IsInvalid => handle == IntPtr.Zero;
-
- ///
- /// No-op. OrtEpDevice is always read-only as the instance is owned by native ORT.
- ///
- /// True
- protected override bool ReleaseHandle()
- {
- return true;
- }
+ private readonly IntPtr _handle;
}
}
\ No newline at end of file
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs
index 8e7caae90ff79..af7115a92285e 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtHardwareDevice.shared.cs
@@ -21,16 +21,16 @@ public enum OrtHardwareDeviceType
///
/// Represents a hardware device that is available on the current system.
///
- public class OrtHardwareDevice : SafeHandle
+ public class OrtHardwareDevice
{
///
/// Construct an OrtHardwareDevice for a native OrtHardwareDevice instance.
///
/// Native OrtHardwareDevice handle.
- internal OrtHardwareDevice(IntPtr deviceHandle)
- : base(deviceHandle, ownsHandle: false)
+ internal OrtHardwareDevice(IntPtr deviceHandle)
{
+ _handle = deviceHandle;
}
///
@@ -40,7 +40,7 @@ public OrtHardwareDeviceType Type
{
get
{
- return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(handle);
+ return (OrtHardwareDeviceType)NativeMethods.OrtHardwareDevice_Type(_handle);
}
}
@@ -54,7 +54,7 @@ public uint VendorId
{
get
{
- return NativeMethods.OrtHardwareDevice_VendorId(handle);
+ return NativeMethods.OrtHardwareDevice_VendorId(_handle);
}
}
@@ -65,7 +65,7 @@ public string Vendor
{
get
{
- IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(handle);
+ IntPtr vendorPtr = NativeMethods.OrtHardwareDevice_Vendor(_handle);
return NativeOnnxValueHelper.StringFromNativeUtf8(vendorPtr);
}
}
@@ -82,7 +82,7 @@ public uint DeviceId
{
get
{
- return NativeMethods.OrtHardwareDevice_DeviceId(handle);
+ return NativeMethods.OrtHardwareDevice_DeviceId(_handle);
}
}
@@ -95,22 +95,10 @@ public OrtKeyValuePairs Metadata
{
get
{
- return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(handle));
+ return new OrtKeyValuePairs(NativeMethods.OrtHardwareDevice_Metadata(_handle));
}
}
- ///
- /// Indicates whether the native handle is invalid.
- ///
- public override bool IsInvalid => handle == IntPtr.Zero;
-
- ///
- /// No-op. OrtHardwareDevice is always read-only as the instance is owned by native ORT.
- ///
- /// True
- protected override bool ReleaseHandle()
- {
- return true;
- }
+ private readonly IntPtr _handle;
}
}
\ No newline at end of file
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
index de6189e105f78..d60bf75ccbd7c 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/SessionOptions.shared.cs
@@ -636,9 +636,32 @@ public void AddFreeDimensionOverrideByName(string dimName, long dimValue)
public void SetEpSelectionPolicy(ExecutionProviderDevicePolicy policy)
{
NativeApiStatus.VerifySuccess(
- NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy, IntPtr.Zero));
+ NativeMethods.OrtSessionOptionsSetEpSelectionPolicy(handle, (int)policy));
}
+ ///
+ /// Set the execution provider selection policy if using automatic execution provider selection.
+ /// Execution providers must be registered with the OrtEnv to be available for selection.
+ ///
+ /// Delegate that implements the custom selection policy.
+ public void SetEpSelectionPolicyDelegate(EpSelectionDelegate selectionDelegate = null)
+ {
+ _epSelectionPolicyConnector = new EpSelectionPolicyConnector(selectionDelegate);
+ _epSelectionPolicyDelegate = new NativeMethods.DOrtEpSelectionDelegate(
+ EpSelectionPolicyConnector.EpSelectionPolicyWrapper);
+
+ // make sure these stay alive. not sure if this is necessary when they're class members though
+ _epSelectionPolicyConnectorHandle = GCHandle.Alloc(_epSelectionPolicyConnector);
+ _epSelectionPolicyDelegateHandle = GCHandle.Alloc(_epSelectionPolicyDelegate);
+
+ IntPtr funcPtr = Marshal.GetFunctionPointerForDelegate(_epSelectionPolicyDelegate);
+
+ NativeApiStatus.VerifySuccess(
+ NativeMethods.OrtSessionOptionsSetEpSelectionPolicyDelegate(
+ handle,
+ funcPtr,
+ GCHandle.ToIntPtr(_epSelectionPolicyConnectorHandle)));
+ }
#endregion
internal IntPtr Handle
@@ -914,7 +937,98 @@ public void SetLoadCancellationFlag(bool value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsSetLoadCancellationFlag(handle, value));
}
+ #endregion
+
+ #region Selection Policy Delegate helpers
+ ///
+ /// Delegate to select execution provider devices from a list of available devices.
+ ///
+ /// OrtEpDevices to select from.
+ /// Model metadata.
+ /// Runtime metadata.
+ /// Maximum number of devices that can be selected.
+ /// Selected devices. Ordered by priority. Highest priority first.
+ public delegate List EpSelectionDelegate(IReadOnlyList epDevices,
+ OrtKeyValuePairs modelMetadata,
+ OrtKeyValuePairs runtimeMetadata,
+ uint maxSelections);
+
+ ///
+ /// Class to bridge the C# and native worlds for the EP selection policy delegate
+ ///
+ internal class EpSelectionPolicyConnector
+ {
+ private readonly EpSelectionDelegate _csharpDelegate;
+
+ internal EpSelectionPolicyConnector(EpSelectionDelegate selectionDelegate)
+ {
+ _csharpDelegate = selectionDelegate;
+ }
+
+ ///
+ /// Delegate to convert between the C and C# worlds
+ ///
+ /// OrtEpDevices to select from.
+ /// Number of OrtEpDevices.
+ /// Model metadata.
+ /// Runtime metadata.
+ /// Pre-allocated OrtEpDevice buffer to update with selected devices.
+ /// Number of entries in selectedOut.
+ /// Number of OrtEpDevies that were selected.
+ /// Opaque state.
+ /// nullptr for OrtStatus* to indicate success.
+ /// Currently we don't have a way to create an OrtStatus instance from the C# bindings.
+ /// Can add if we need to return an explicit error message.
+ ///
+ public static IntPtr EpSelectionPolicyWrapper(IntPtr /* OrtEpDevice** */ epDevicesIn,
+ uint numDevices,
+ IntPtr /* OrtKeyValuePairs* */ modelMetadataIn,
+ IntPtr /* OrtKeyValuePairs* */ runtimeMetadataIn,
+ IntPtr /* OrtEpDevice** */ selectedOut,
+ uint maxSelected,
+ out UIntPtr numSelected,
+ IntPtr state)
+ {
+ Span epDevicesIntPtrs;
+ Span selectedDevicesIntPtrs;
+ EpSelectionPolicyConnector connector = (EpSelectionPolicyConnector)GCHandle.FromIntPtr(state).Target;
+
+ unsafe
+ {
+ void* ptr = epDevicesIn.ToPointer();
+ epDevicesIntPtrs = new Span(ptr, checked((int)numDevices));
+ }
+
+ List epDevices = new List();
+ for (int i = 0; i < numDevices; i++)
+ {
+
+ epDevices.Add(new OrtEpDevice(epDevicesIntPtrs[i]));
+ }
+
+ OrtKeyValuePairs modelMetadata = new OrtKeyValuePairs(modelMetadataIn);
+ OrtKeyValuePairs runtimeMetadata = new OrtKeyValuePairs(runtimeMetadataIn);
+ var selected = connector._csharpDelegate(epDevices, modelMetadata, runtimeMetadata, maxSelected);
+
+ numSelected = (UIntPtr)selected.Count;
+
+ unsafe
+ {
+ void* ptr = selectedOut.ToPointer();
+ selectedDevicesIntPtrs = new Span(ptr, (int)maxSelected);
+ }
+
+ int idx = 0;
+ foreach (var epDevice in selected)
+ {
+ selectedDevicesIntPtrs[idx] = epDevice.Handle;
+ idx++;
+ }
+
+ return IntPtr.Zero;
+ }
+ }
#endregion
#region Private Methods
@@ -1000,8 +1114,43 @@ protected override bool ReleaseHandle()
{
NativeMethods.OrtReleaseSessionOptions(handle);
handle = IntPtr.Zero;
+
+ if (_epSelectionPolicyConnectorHandle.IsAllocated)
+ {
+ _epSelectionPolicyConnectorHandle.Free();
+ _epSelectionPolicyConnector = null;
+ }
+
+ if (_epSelectionPolicyDelegateHandle.IsAllocated)
+ {
+ _epSelectionPolicyDelegateHandle.Free();
+ _epSelectionPolicyDelegate = null;
+ }
+
+
return true;
}
#endregion
+
+ ///
+ /// Helper class to connect C and C# usage of the EP selection policy delegate.
+ ///
+ EpSelectionPolicyConnector _epSelectionPolicyConnector = null;
+
+ ///
+ /// Handle to the EP selection policy connector that is passed to the C API as state for the
+ /// EP selection policy delegate.
+ ///
+ GCHandle _epSelectionPolicyConnectorHandle = default;
+
+ ///
+ /// Delegate instance that is provided to the C API.
+ ///
+ NativeMethods.DOrtEpSelectionDelegate _epSelectionPolicyDelegate = null;
+
+ ///
+ /// Handle to the EP selection policy delegate that is passed to the C API.
+ ///
+ GCHandle _epSelectionPolicyDelegateHandle = default;
}
}
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs
index 1aa4db15d275c..d95a649bd95c5 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs
@@ -93,7 +93,7 @@ public void AppendToSessionOptionsV2()
{
var runTest = (Func> getEpOptions) =>
{
- SessionOptions sessionOptions = new SessionOptions();
+ using SessionOptions sessionOptions = new SessionOptions();
sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE;
var epDevices = ortEnvInstance.GetEpDevices();
@@ -138,7 +138,7 @@ public void AppendToSessionOptionsV2()
[Fact]
public void SetEpSelectionPolicy()
{
- SessionOptions sessionOptions = new SessionOptions();
+ using SessionOptions sessionOptions = new SessionOptions();
sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE;
var epDevices = ortEnvInstance.GetEpDevices();
@@ -150,7 +150,50 @@ public void SetEpSelectionPolicy()
var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
// session should load successfully
- using (var session = new InferenceSession(model))
+ using (var session = new InferenceSession(model, sessionOptions))
+ {
+ Assert.NotNull(session);
+ }
+ }
+
+ private static List SelectionPolicyDelegate(IReadOnlyList epDevices,
+ OrtKeyValuePairs modelMetadata,
+ OrtKeyValuePairs runtimeMetadata,
+ uint maxSelections)
+ {
+ Assert.NotEmpty(modelMetadata.Entries);
+ Assert.True(epDevices.Count > 0);
+
+ // select first device and last (if there are more than one).
+ var selected = new List();
+
+ selected.Add(epDevices[0]);
+
+ // add ORT CPU EP which is always last.
+ if (maxSelections > 2 && epDevices.Count > 1)
+ {
+ selected.Add(epDevices.Last());
+ }
+
+ return selected;
+ }
+
+ [Fact]
+ public void SetEpSelectionPolicyDelegate()
+ {
+ using SessionOptions sessionOptions = new SessionOptions();
+ sessionOptions.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_VERBOSE;
+
+ var epDevices = ortEnvInstance.GetEpDevices();
+ Assert.NotEmpty(epDevices);
+
+ // doesn't matter what the value is. should fallback to ORT CPU EP
+ sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegate);
+
+ var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx");
+
+ // session should load successfully
+ using (var session = new InferenceSession(model, sessionOptions))
{
Assert.NotNull(session);
}
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index cef5eab9a505e..6c7d910b4963b 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -438,18 +438,20 @@ typedef enum OrtExecutionProviderDevicePolicy {
* \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array.
Currently the maximum is 8.
* \param num_ep_devices The number of selected devices.
+ * \param state Opaque pointer. Required to use the delegate from other languages like C# and python.
*
* \return OrtStatus* Selection status. Return nullptr on success.
* Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
* ORT will release the OrtStatus* if not null.
*/
-typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices,
- _In_ size_t num_devices,
- _In_ const OrtKeyValuePairs* model_metadata,
- _In_opt_ const OrtKeyValuePairs* runtime_metadata,
- _Inout_ const OrtEpDevice** selected,
- _In_ size_t max_selected,
- _Out_ size_t* num_selected);
+typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices,
+ _In_ size_t num_devices,
+ _In_ const OrtKeyValuePairs* model_metadata,
+ _In_opt_ const OrtKeyValuePairs* runtime_metadata,
+ _Inout_ const OrtEpDevice** selected,
+ _In_ size_t max_selected,
+ _Out_ size_t* num_selected,
+ _In_ void* state);
/** \brief Algorithm to use for cuDNN Convolution Op
*/
@@ -5127,18 +5129,30 @@ struct OrtApi {
/** \brief Set the execution provider selection policy for the session.
*
- * Allows users to specify a device selection policy for automatic execution provider (EP) selection,
- * or provide a delegate callback for custom selection logic.
+ * Allows users to specify a device selection policy for automatic execution provider (EP) selection.
+ * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead.
*
* \param[in] session_options The OrtSessionOptions instance.
* \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy).
- * \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy.
*
* \since Version 1.22
*/
ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options,
- _In_ OrtExecutionProviderDevicePolicy policy,
- _In_opt_ EpSelectionDelegate* delegate);
+ _In_ OrtExecutionProviderDevicePolicy policy);
+
+ /** \brief Set the execution provider selection policy delegate for the session.
+ *
+ * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection.
+ *
+ * \param[in] session_options The OrtSessionOptions instance.
+ * \param[in] delegate Delegate callback for custom selection.
+ * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required.
+ *
+ * \since Version 1.22
+ */
+ ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options,
+ _In_ EpSelectionDelegate delegate,
+ _In_opt_ void* delegate_state);
/** \brief Get the hardware device type.
*
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index 6c175c606b4a1..bc6f381bb82a0 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1103,8 +1103,10 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl {
const std::unordered_map& ep_options);
/// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy
- SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
- EpSelectionDelegate* delegate = nullptr);
+ SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy);
+
+ /// Wraps OrtApi::SessionOptionsSetEpSelectionPolicyDelegate
+ SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr);
SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
index 1fdb8f16d9600..94ad2118fa4d6 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h
@@ -1150,9 +1150,14 @@ inline SessionOptionsImpl& SessionOptionsImpl::AppendExecutionProvider_V2(
}
template
-inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
- EpSelectionDelegate* delegate) {
- ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate));
+inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) {
+ ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy));
+ return *this;
+}
+
+template
+inline SessionOptionsImpl& SessionOptionsImpl::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) {
+ ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state));
return *this;
}
diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h
index 8f8a3d6634a7e..89a43c4f71ee6 100644
--- a/onnxruntime/core/framework/session_options.h
+++ b/onnxruntime/core/framework/session_options.h
@@ -96,7 +96,8 @@ struct EpSelectionPolicy {
// and no selection policy was explicitly provided.
bool enable{false};
OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT;
- EpSelectionDelegate* delegate{};
+ EpSelectionDelegate delegate{};
+ void* state{nullptr}; // state for the delegate
};
/**
diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc
index b1c0467da642e..c205e05baadb9 100644
--- a/onnxruntime/core/session/abi_session_options.cc
+++ b/onnxruntime/core/session/abi_session_options.cc
@@ -367,12 +367,24 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions*
}
ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options,
- _In_ OrtExecutionProviderDevicePolicy policy,
- _In_opt_ EpSelectionDelegate* delegate) {
+ _In_ OrtExecutionProviderDevicePolicy policy) {
API_IMPL_BEGIN
options->value.ep_selection_policy.enable = true;
options->value.ep_selection_policy.policy = policy;
+ options->value.ep_selection_policy.delegate = nullptr;
+ options->value.ep_selection_policy.state = nullptr;
+ return nullptr;
+ API_IMPL_END
+}
+
+ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* options,
+ _In_opt_ EpSelectionDelegate delegate,
+ _In_opt_ void* state) {
+ API_IMPL_BEGIN
+ options->value.ep_selection_policy.enable = true;
+ options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT;
options->value.ep_selection_policy.delegate = delegate;
+ options->value.ep_selection_policy.state = state;
return nullptr;
API_IMPL_END
}
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc
index 8ec7312cc6354..df70856a64e99 100644
--- a/onnxruntime/core/session/inference_session.cc
+++ b/onnxruntime/core/session/inference_session.cc
@@ -3266,6 +3266,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod
// save model metadata
model_metadata_.producer_name = model.ProducerName();
+ model_metadata_.producer_version = model.ProducerVersion();
model_metadata_.description = model.DocString();
model_metadata_.graph_description = model.GraphDocString();
model_metadata_.domain = model.Domain();
@@ -3430,6 +3431,10 @@ const Model& InferenceSession::GetModel() const {
return *model_;
}
+const Environment& InferenceSession::GetEnvironment() const {
+ return environment_;
+}
+
SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) {
ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK());
}
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index ba9812a59fec3..51350390a0456 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -80,6 +80,7 @@ struct ModelMetadata {
ModelMetadata& operator=(const ModelMetadata&) = delete;
std::string producer_name;
+ std::string producer_version;
std::string graph_name;
std::string domain;
std::string description;
@@ -603,6 +604,7 @@ class InferenceSession {
#endif
const Model& GetModel() const;
+ const Environment& GetEnvironment() const;
protected:
#if !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc
index 304966605c9cf..d03b98a9c1eb5 100644
--- a/onnxruntime/core/session/onnxruntime_c_api.cc
+++ b/onnxruntime/core/session/onnxruntime_c_api.cc
@@ -2980,6 +2980,7 @@ static constexpr OrtApi ort_api_1_to_23 = {
&OrtApis::GetEpDevices,
&OrtApis::SessionOptionsAppendExecutionProvider_V2,
&OrtApis::SessionOptionsSetEpSelectionPolicy,
+ &OrtApis::SessionOptionsSetEpSelectionPolicyDelegate,
&OrtApis::HardwareDevice_Type,
&OrtApis::HardwareDevice_VendorId,
@@ -3029,7 +3030,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo
// no additions in version 19, 20, and 21
static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change");
-static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change");
+static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");
// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.23.0",
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h
index 7be518a39480f..47d1a543b5a31 100644
--- a/onnxruntime/core/session/ort_apis.h
+++ b/onnxruntime/core/session/ort_apis.h
@@ -577,8 +577,11 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt
size_t num_ep_options);
ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options,
- _In_ OrtExecutionProviderDevicePolicy policy,
- _In_opt_ EpSelectionDelegate* delegate);
+ _In_ OrtExecutionProviderDevicePolicy policy);
+
+ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* sess_options,
+ _In_ EpSelectionDelegate delegate,
+ _In_opt_ void* state);
// OrtHardwareDevice accessors.
ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device);
diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc
index 4ce13fe36ea86..f706bd05d8494 100644
--- a/onnxruntime/core/session/provider_policy_context.cc
+++ b/onnxruntime/core/session/provider_policy_context.cc
@@ -94,8 +94,8 @@ std::vector OrderDevices(const std::vectorep_name < b->ep_name;
}
// one is the default CPU EP
@@ -104,31 +104,57 @@ std::vector OrderDevices(const std::vector GPU -> NPU
// TODO: Should environment.cc do the ordering?
- const auto& execution_devices = OrderDevices(env.GetOrtEpDevices());
+ std::vector execution_devices = OrderDevices(env.GetOrtEpDevices());
// The list of devices selected by policies
std::vector devices_selected;
// Run the delegate if it was passed in lieu of any other policy
if (options.value.ep_selection_policy.delegate) {
- auto policy_fn = options.value.ep_selection_policy.delegate;
+ auto model_metadata = GetModelMetadata(sess);
+ OrtKeyValuePairs runtime_metadata; // TODO: where should this come from?
+
std::vector delegate_devices(execution_devices.begin(), execution_devices.end());
std::array selected_devices{nullptr};
-
size_t num_selected = 0;
- auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(),
- nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected);
+
+ EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate;
+ auto* status = delegate(delegate_devices.data(), delegate_devices.size(),
+ &model_metadata, &runtime_metadata,
+ selected_devices.data(), selected_devices.size(), &num_selected,
+ options.value.ep_selection_policy.state);
// return or fall-through for both these cases
// going with explicit failure for now so it's obvious to user what is happening
@@ -142,6 +168,12 @@ Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const
if (num_selected == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything.");
}
+
+ // Copy the selected devices to the output vector
+ devices_selected.reserve(num_selected);
+ for (size_t i = 0; i < num_selected; ++i) {
+ devices_selected.push_back(selected_devices[i]);
+ }
} else {
// Create the selector for the chosen policy
std::unique_ptr selector;
diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc
index c05394039d8c7..8ca4ef6af1f44 100644
--- a/onnxruntime/core/session/utils.cc
+++ b/onnxruntime/core/session/utils.cc
@@ -176,20 +176,6 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op
env);
}
-#if !defined(ORT_MINIMAL_BUILD)
- // TEMPORARY for testing. Manually specify the EP to select.
- auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select");
- if (auto_select_ep_name) {
- ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env, *sess, *auto_select_ep_name));
- }
-
- // if there are no providers registered, and there's an ep selection policy set, do auto ep selection
- if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) {
- ProviderPolicyContext context;
- ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env, *options, *sess));
- }
-#endif // !defined(ORT_MINIMAL_BUILD)
-
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Add custom domains
if (options && !options->custom_op_domains_.empty()) {
@@ -231,22 +217,38 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options,
ORT_ENFORCE(session_logger != nullptr,
"Session logger is invalid, but should have been initialized during session construction.");
- // we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of
- // byte addressable memory
- std::vector> provider_list;
- if (options) {
+ const bool has_provider_factories = options != nullptr && !options->provider_factories.empty();
+
+ if (has_provider_factories) {
+ std::vector> provider_list;
for (auto& factory : options->provider_factories) {
auto provider = factory->CreateProvider(*options, *session_logger->ToExternal());
provider_list.push_back(std::move(provider));
}
+
+ // register the providers
+ for (auto& provider : provider_list) {
+ if (provider) {
+ ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider)));
+ }
+ }
}
+#if !defined(ORT_MINIMAL_BUILD)
+ else {
+ // TEMPORARY for testing. Manually specify the EP to select.
+ auto auto_select_ep_name = sess.GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select");
+ if (auto_select_ep_name) {
+ ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(sess.GetEnvironment(), sess, *auto_select_ep_name));
+ }
- // register the providers
- for (auto& provider : provider_list) {
- if (provider) {
- ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider)));
+ // if there are no providers registered, and there's an ep selection policy set, do auto ep selection.
+ // note: the model has already been loaded so model metadata should be available to the policy delegate callback.
+ if (options != nullptr && options->value.ep_selection_policy.enable) {
+ ProviderPolicyContext context;
+ ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(sess.GetEnvironment(), *options, sess));
}
}
+#endif // !defined(ORT_MINIMAL_BUILD)
if (prepacked_weights_container != nullptr) {
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer(
diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc
index 04b1b2ea0bdc4..cea1299adc26f 100644
--- a/onnxruntime/test/autoep/test_autoep_selection.cc
+++ b/onnxruntime/test/autoep/test_autoep_selection.cc
@@ -68,6 +68,7 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod
const std::function&)>& select_devices = nullptr,
// auto select using policy
std::optional policy = std::nullopt,
+ std::optional delegate = std::nullopt,
bool test_session_creation_only = false) {
Ort::SessionOptions session_options;
@@ -77,7 +78,9 @@ static void TestInference(Ort::Env& env, const std::basic_string& mod
}
if (auto_select) {
- if (policy) {
+ if (delegate) {
+ session_options.SetEpSelectionPolicy(*delegate, nullptr);
+ } else if (policy) {
session_options.SetEpSelectionPolicy(*policy);
} else {
// manually specify EP to select
@@ -353,6 +356,150 @@ TEST(AutoEpSelection, PreferNpu) {
OrtExecutionProviderDevicePolicy::OrtExecutionProviderDevicePolicy_PREFER_NPU);
}
+static OrtStatus* ORT_API_CALL PolicyDelegate(_In_ const OrtEpDevice** ep_devices,
+ _In_ size_t num_devices,
+ _In_ const OrtKeyValuePairs* model_metadata,
+ _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/,
+ _Inout_ const OrtEpDevice** selected,
+ _In_ size_t max_selected,
+ _Out_ size_t* num_selected,
+ _In_ void* /*state*/) {
+ *num_selected = 0;
+
+ if (max_selected <= 2) {
+ return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Expected to be able to select 2 devices.");
+ }
+
+ if (model_metadata->entries.empty()) {
+ return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Model metadata was empty.");
+ }
+
+ selected[0] = ep_devices[0];
+ *num_selected = 1;
+ if (num_devices > 1) {
+ // CPU EP is always last.
+ selected[1] = ep_devices[num_devices - 1];
+ *num_selected = 2;
+ }
+
+ return nullptr;
+}
+
+static OrtStatus* ORT_API_CALL PolicyDelegateSelectNone(_In_ const OrtEpDevice** /*ep_devices*/,
+ _In_ size_t /*num_devices*/,
+ _In_ const OrtKeyValuePairs* /*model_metadata*/,
+ _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/,
+ _Inout_ const OrtEpDevice** /*selected*/,
+ _In_ size_t /*max_selected*/,
+ _Out_ size_t* num_selected,
+ _In_ void* /*state*/) {
+ *num_selected = 0;
+
+ return nullptr;
+}
+
+static OrtStatus* ORT_API_CALL PolicyDelegateReturnError(_In_ const OrtEpDevice** /*ep_devices*/,
+ _In_ size_t /*num_devices*/,
+ _In_ const OrtKeyValuePairs* /*model_metadata*/,
+ _In_opt_ const OrtKeyValuePairs* /*runtime_metadata*/,
+ _Inout_ const OrtEpDevice** /*selected*/,
+ _In_ size_t /*max_selected*/,
+ _Out_ size_t* num_selected,
+ _In_ void* /*state*/) {
+ *num_selected = 0;
+
+ return Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Selection error.");
+}
+
+// test providing a delegate
+TEST(AutoEpSelection, PolicyDelegate) {
+ std::vector> inputs(1);
+ auto& input = inputs.back();
+ input.name = "X";
+ input.dims = {3, 2};
+ input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+ // prepare expected inputs and outputs
+ std::vector expected_dims_y = {3, 2};
+ std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+ const Ort::KeyValuePairs provider_options;
+
+ TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"),
+ "", // don't need EP name
+ std::nullopt,
+ provider_options,
+ inputs,
+ "Y",
+ expected_dims_y,
+ expected_values_y,
+ /* auto_select */ true,
+ /*select_devices*/ nullptr,
+ std::nullopt,
+ PolicyDelegate);
+}
+
+// test providing a delegate
+TEST(AutoEpSelection, PolicyDelegateSelectsNothing) {
+ std::vector> inputs(1);
+ auto& input = inputs.back();
+ input.name = "X";
+ input.dims = {3, 2};
+ input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+ // prepare expected inputs and outputs
+ std::vector expected_dims_y = {3, 2};
+ std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+ const Ort::KeyValuePairs provider_options;
+
+ ASSERT_THROW(
+ TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"),
+ "", // don't need EP name
+ std::nullopt,
+ provider_options,
+ inputs,
+ "Y",
+ expected_dims_y,
+ expected_values_y,
+ /* auto_select */ true,
+ /*select_devices*/ nullptr,
+ std::nullopt,
+ PolicyDelegateSelectNone,
+ /*test_session_creation_only*/ true),
+ Ort::Exception);
+}
+
+TEST(AutoEpSelection, PolicyDelegateReturnsError) {
+ std::vector> inputs(1);
+ auto& input = inputs.back();
+ input.name = "X";
+ input.dims = {3, 2};
+ input.values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+
+ // prepare expected inputs and outputs
+ std::vector expected_dims_y = {3, 2};
+ std::vector expected_values_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+ const Ort::KeyValuePairs provider_options;
+
+ ASSERT_THROW(
+ TestInference(*ort_env, ORT_TSTR("testdata/mul_1.onnx"),
+ "", // don't need EP name
+ std::nullopt,
+ provider_options,
+ inputs,
+ "Y",
+ expected_dims_y,
+ expected_values_y,
+ /* auto_select */ true,
+ /*select_devices*/ nullptr,
+ std::nullopt,
+ PolicyDelegateReturnError,
+ /*test_session_creation_only*/ true),
+ Ort::Exception);
+}
+
namespace {
struct ExamplePluginInfo {
const std::filesystem::path library_path =