diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 1ae7b5c9eb991..26b22f2499ad6 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -451,6 +451,10 @@ public struct OrtApi public IntPtr Graph_GetModelMetadata; public IntPtr GetModelCompatibilityForEpDevices; public IntPtr CreateExternalInitializerInfo; + + public IntPtr TensorTypeAndShape_HasShape; + public IntPtr KernelInfo_GetConfigEntries; + public IntPtr RegisterExecutionProviderLibraryWithOptions; } internal static class NativeMethods @@ -847,7 +851,7 @@ static NativeMethods() api_.CreateSyncStreamForEpDevice, typeof(DOrtCreateSyncStreamForEpDevice)); - OrtSyncStream_GetHandle = + OrtSyncStream_GetHandle = (DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer( api_.SyncStream_GetHandle, typeof(DOrtSyncStream_GetHandle)); @@ -861,6 +865,11 @@ static NativeMethods() (DOrtCopyTensors)Marshal.GetDelegateForFunctionPointer( api_.CopyTensors, typeof(DOrtCopyTensors)); + + OrtRegisterExecutionProviderLibraryWithOptions = + (DOrtRegisterExecutionProviderLibraryWithOptions)Marshal.GetDelegateForFunctionPointer( + api_.RegisterExecutionProviderLibraryWithOptions, + typeof(DOrtRegisterExecutionProviderLibraryWithOptions)); } internal class NativeLib @@ -2780,6 +2789,22 @@ out IntPtr /* OrtSyncStream** */ stream byte[] /* const char* */ registration_name, byte[] /* const ORTCHAR_T* */ path); + /// + /// Register an execution provider library. The provided options are passed to EP factories after creation. + /// The library must implement CreateEpFactories and ReleaseEpFactory. + /// + /// Environment to add the EP library to. + /// Name to register the library under. + /// Absolute path to the library. + /// Options passed to OrtEpFactory::SetEnvironmentOptions after creation. + /// OrtStatus* + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /* OrtStatus* */ DOrtRegisterExecutionProviderLibraryWithOptions( + IntPtr /* OrtEnv* */ env, + byte[] /* const char* */ registration_name, + byte[] /* const ORTCHAR_T* */ path, + IntPtr /* const OrtKeyValuePairs* */ options); + /// /// Unregister an execution provider library. /// @@ -2792,6 +2817,7 @@ out IntPtr /* OrtSyncStream** */ stream byte[] /* const char* */ registration_name); public static DOrtRegisterExecutionProviderLibrary OrtRegisterExecutionProviderLibrary; + public static DOrtRegisterExecutionProviderLibraryWithOptions OrtRegisterExecutionProviderLibraryWithOptions; public static DOrtUnregisterExecutionProviderLibrary OrtUnregisterExecutionProviderLibrary; /// diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs index 6fcff438c5cf3..10b8ae2cd258a 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtEnv.shared.cs @@ -523,9 +523,6 @@ public OrtLoggingLevel EnvLogLevel /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. - /// - /// Coming: A selection policy can be specified and ORT will automatically select the best execution providers - /// and devices for the model. /// /// The name to register the library under. /// The path to the library to register. @@ -540,6 +537,48 @@ public void RegisterExecutionProviderLibrary(string registrationName, string lib NativeMethods.OrtRegisterExecutionProviderLibrary(handle, registrationNameUtf8, pathUtf8)); } + /// + /// Register an execution provider library with the OrtEnv instance. The provided options are passed to + /// EP factory instances after creation. + /// + /// A registered execution provider library can be used by all sessions created with the OrtEnv instance. + /// Devices the execution provider can utilize are added to the values returned by GetEpDevices() and can + /// be used in SessionOptions.AppendExecutionProvider to select an execution provider for a device. + /// + /// The name to register the library under. + /// The path to the library to register. + /// Optional options to pass to each EP factory after creation. May be null. + /// + /// + public void RegisterExecutionProviderLibrary(string registrationName, string libraryPath, + IReadOnlyDictionary options) + { + var registrationNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(registrationName); + var pathUtf8 = NativeOnnxValueHelper.GetPlatformSerializedString(libraryPath); + + if (options != null && options.Count > 0) + { + // this creates an OrtKeyValuePairs instance with a backing native instance + using var optionsKvps = new OrtKeyValuePairs(options); + + NativeApiStatus.VerifySuccess( + NativeMethods.OrtRegisterExecutionProviderLibraryWithOptions( + handle, + registrationNameUtf8, + pathUtf8, + optionsKvps.Handle)); + } + else + { + NativeApiStatus.VerifySuccess( + NativeMethods.OrtRegisterExecutionProviderLibraryWithOptions( + handle, + registrationNameUtf8, + pathUtf8, + IntPtr.Zero)); // Options OrtKeyValuePairs + } + } + /// /// Unregister an execution provider library from the OrtEnv instance. /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs index 1be0b6e9530ed..80c318372ab30 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtAutoEpTests.cs @@ -93,6 +93,37 @@ public void RegisterUnregisterLibrary() } } + [Fact] + public void RegisterLibraryWithOptions() + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + const string registrationName = "example_plugin_ep_kernel_registry"; + const string epName = "ExampleKernelEp"; + + string libFullPath = Path.Combine(Directory.GetCurrentDirectory(), "example_plugin_ep_kernel_registry.dll"); + Assert.True(File.Exists(libFullPath), $"Expected lib {libFullPath} does not exist."); + + Dictionary options = new Dictionary { { "some_env_config", "2" } }; + ortEnvInstance.RegisterExecutionProviderLibrary(registrationName, libFullPath, options); + try + { + // check OrtEpDevice was found + var epDevices = ortEnvInstance.GetEpDevices(); + var epDevice = epDevices.FirstOrDefault(d => string.Equals(epName, d.EpName, StringComparison.OrdinalIgnoreCase)); + Assert.NotNull(epDevice); + + // The example EP stores the env config in the OrtEpDevice metadata for testing convenience. + var epMetadata = epDevice.EpMetadata.Entries; + Assert.Equal("2", epMetadata["some_env_config"]); + } + finally + { // unregister + ortEnvInstance.UnregisterExecutionProviderLibrary(registrationName); + } + } + } + [Fact] public void AppendToSessionOptionsV2() { @@ -194,7 +225,7 @@ public void SetEpSelectionPolicyDelegate() // doesn't matter what the value is. should fallback to ORT CPU EP sessionOptions.SetEpSelectionPolicyDelegate(SelectionPolicyDelegate); - + var model = TestDataLoader.LoadModelFromEmbeddedResource("squeezenet.onnx"); // session should load successfully diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 59ca1a1df762e..114c2d9487a14 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -132,7 +132,8 @@ class Environment { const OrtArenaCfg* arena_cfg = nullptr); #if !defined(ORT_MINIMAL_BUILD) - Status RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path); + Status RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path, + const OrtKeyValuePairs* options); Status UnregisterExecutionProviderLibrary(const std::string& registration_name); // convert an OrtEpFactory* to EpFactoryInternal* if possible. @@ -206,14 +207,16 @@ class Environment { Status RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, - const std::vector& internal_factories = {}); + const std::vector& internal_factories = {}, + const OrtKeyValuePairs* options = nullptr); struct EpInfo { // calls EpLibrary::Load // for each factory gets the OrtEpDevice instances and adds to execution_devices // internal_factory is set if this is an internal EP static Status Create(std::unique_ptr library_in, std::unique_ptr& out, - const std::vector& internal_factories = {}); + const std::vector& internal_factories = {}, + const OrtKeyValuePairs* options = nullptr); // removes entries for this library from execution_devices // calls EpLibrary::Unload diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 68899d75e9294..443931070bf67 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -6608,6 +6608,24 @@ struct OrtApi { * \since Version 1.24 */ ORT_API2_STATUS(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Register an execution provider library with ORT. The provided options are passed to + * OrtEpFactory::SetEnvironmentOptions after factory creation. + * + * The library must export 'CreateEpFactories' and 'ReleaseEpFactory' functions. + * See OrtEpApi for more details. + * + * \param[in] env The OrtEnv instance to register the library in. + * \param[in] registration_name The name to register the execution provider library under. + * \param[in] path The path to the execution provider library. + * \param[in] options Map of options to pass to each OrtEpFactory via OrtEpFactory::SetEnvironmentOptions. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(RegisterExecutionProviderLibraryWithOptions, _In_ OrtEnv* env, _In_ const char* registration_name, + _In_ const ORTCHAR_T* path, _In_opt_ const OrtKeyValuePairs* options); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index fd4d9a683b7cd..f20a5af6d553e 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1199,8 +1199,14 @@ struct Env : detail::Base { void ReleaseSharedAllocator(const OrtEpDevice* ep_device, OrtDeviceMemoryType mem_type); ///< Wraps OrtApi::ReleaseSharedAllocator - Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); ///< Wraps OrtApi::RegisterExecutionProviderLibrary - Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary + ///< Wraps OrtApi::RegisterExecutionProviderLibrary + Env& RegisterExecutionProviderLibrary(const char* registration_name, const std::basic_string& path); + + ///< Wraps OrtApi::RegisterExecutionProviderLibraryWithOptions + Env& RegisterExecutionProviderLibraryWithOptions(const char* registration_name, const std::basic_string& path, + const OrtKeyValuePairs* options); + + Env& UnregisterExecutionProviderLibrary(const char* registration_name); ///< Wraps OrtApi::UnregisterExecutionProviderLibrary std::vector GetEpDevices() const; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 63dfc85560a39..2254e6da2cb2d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -835,6 +835,13 @@ inline Env& Env::RegisterExecutionProviderLibrary(const char* registration_name, return *this; } +inline Env& Env::RegisterExecutionProviderLibraryWithOptions(const char* registration_name, + const std::basic_string& path, + const OrtKeyValuePairs* options) { + ThrowOnError(GetApi().RegisterExecutionProviderLibraryWithOptions(p_, registration_name, path.c_str(), options)); + return *this; +} + inline Env& Env::UnregisterExecutionProviderLibrary(const char* registration_name) { ThrowOnError(GetApi().UnregisterExecutionProviderLibrary(p_, registration_name)); return *this; diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index c67e73d1cd4a0..30b79393c551c 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -1510,12 +1510,13 @@ struct OrtEpFactory { _In_opt_ const OrtKeyValuePairs* stream_options, _Outptr_ OrtSyncStreamImpl** stream); - /** \brief Set environment options on this EP factory. + /** \brief Sets environment options that are provided by the application during EP library registration. * - * Environment options can be set by ORT after calling the library's 'CreateEpFactories' function to - * create EP factories. + * If defined, ORT calls this function during EP library registration directly after creating the factory instance. + * Valid option keys and values are defined by the EP library. However, some common EP-agnostic options are listed + * below. * - * Supported options: + * Common EP options: * "allow_virtual_devices": Allows EP factory to specify OrtEpDevice instances that use custom * virtual OrtHardwareDevices, which can be created via OrtEpApi::CreateHardwareDevice(). * @@ -1528,10 +1529,11 @@ struct OrtEpFactory { * -# "1": Creation of virtual devices is allowed. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] options The configuration options. + * \param[in] options The configuration options. Do not cache pointers to the OrtKeyValuePairs instance or its + * keys and values. Key and value strings should be copied if necessary. * * \note Implementation of this function is optional. - * An EP factory should only implement this if it needs to handle any environment options. + * An EP factory should implement this if it needs to handle any environment options. * * \snippet{doc} snippets.dox OrtStatus Return Value * diff --git a/onnxruntime/core/session/abi_key_value_pairs.h b/onnxruntime/core/session/abi_key_value_pairs.h index 72d5007b84e6f..63bfded8fa9fb 100644 --- a/onnxruntime/core/session/abi_key_value_pairs.h +++ b/onnxruntime/core/session/abi_key_value_pairs.h @@ -39,6 +39,10 @@ struct OrtKeyValuePairs { Sync(); } + bool HasKey(const std::string& key) const { + return entries_.find(key) != entries_.end(); + } + void Add(const char* key, const char* value) { // ignore if either are nullptr. if (key && value) { diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index cde77eeed8aa5..9ff26f6ddfd8a 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -498,7 +498,8 @@ Status CreateDataTransferForFactory(OrtEpFactory& ep_factory, Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, - const std::vector& internal_factories) { + const std::vector& internal_factories, + const OrtKeyValuePairs* options) { if (ep_libraries_.count(registration_name) > 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "library is already registered under ", registration_name); } @@ -508,7 +509,7 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra ORT_TRY { // create the EpInfo which loads the library if required std::unique_ptr ep_info = nullptr; - ORT_RETURN_IF_ERROR(EpInfo::Create(std::move(ep_library), ep_info)); + ORT_RETURN_IF_ERROR(EpInfo::Create(std::move(ep_library), ep_info, internal_factories, options)); // add the pointers to the OrtEpDevice instances to our global list execution_devices_.reserve(execution_devices_.size() + ep_info->execution_devices.size()); @@ -562,13 +563,15 @@ Status Environment::CreateAndRegisterInternalEps() { auto* internal_library_ptr = ep_library.get(); ORT_RETURN_IF_ERROR(RegisterExecutionProviderLibrary(internal_library_ptr->RegistrationName(), std::move(ep_library), - {&internal_library_ptr->GetInternalFactory()})); + {&internal_library_ptr->GetInternalFactory()}, + /*options*/ nullptr)); } return Status::OK(); } -Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path) { +Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, const ORTCHAR_T* lib_path, + const OrtKeyValuePairs* options) { std::lock_guard lock{mutex_}; std::vector internal_factories = {}; @@ -578,7 +581,7 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra ORT_RETURN_IF_ERROR(LoadPluginOrProviderBridge(registration_name, lib_path, ep_library, internal_factories)); - return RegisterExecutionProviderLibrary(registration_name, std::move(ep_library), internal_factories); + return RegisterExecutionProviderLibrary(registration_name, std::move(ep_library), internal_factories, options); } Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_name) { @@ -755,24 +758,30 @@ bool AreVirtualDevicesAllowed(std::string_view lib_registration_name) { suffix.size(), suffix) == 0; } -Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name) { +Status SetEpFactoryEnvironmentOptions(OrtEpFactory& factory, std::string_view lib_registration_name, + const OrtKeyValuePairs* options) { // OrtEpFactory::SetEnvironmentOptions was added in ORT 1.24 if (factory.ort_version_supported < 24 || factory.SetEnvironmentOptions == nullptr) { return Status::OK(); } - // We only set one option now but this can be generalized if necessary. - OrtKeyValuePairs options; - options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0"); + OrtKeyValuePairs factory_options = options != nullptr ? *options : OrtKeyValuePairs{}; - ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &options))); + // Add option for "allow_virtual_devices" if not already specified by application's options. + if (!factory_options.HasKey("allow_virtual_devices")) { + // "allow_virtual_devices" is set to "1" if the library registration name ends with ".virtual". + factory_options.Add("allow_virtual_devices", AreVirtualDevicesAllowed(lib_registration_name) ? "1" : "0"); + } + + ORT_RETURN_IF_ERROR(ToStatusAndRelease(factory.SetEnvironmentOptions(&factory, &factory_options))); return Status::OK(); } } // namespace Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, - const std::vector& internal_factories) { + const std::vector& internal_factories, + const OrtKeyValuePairs* options) { if (!library_in) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "EpLibrary was null"); } @@ -795,7 +804,7 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u auto& factory = *factory_ptr; - ORT_RETURN_IF_ERROR(SetEpFactoryEnvironmentOptions(factory, instance.library->RegistrationName())); + ORT_RETURN_IF_ERROR(SetEpFactoryEnvironmentOptions(factory, instance.library->RegistrationName(), options)); std::array ep_devices{nullptr}; size_t num_ep_devices = 0; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 82f7cef4aec49..2f5847ac883b1 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3287,7 +3287,18 @@ ORT_API(void, OrtApis::ReleaseKeyValuePairs, _Frees_ptr_opt_ OrtKeyValuePairs* k ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, const ORTCHAR_T* path) { API_IMPL_BEGIN - ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path)); + ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path, + nullptr)); + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibraryWithOptions, _In_ OrtEnv* env, + _In_ const char* registration_name, _In_ const ORTCHAR_T* path, + _In_opt_ const OrtKeyValuePairs* options) { + API_IMPL_BEGIN + ORT_API_RETURN_IF_STATUS_NOT_OK(env->GetEnvironment().RegisterExecutionProviderLibrary(registration_name, path, + options)); return nullptr; API_IMPL_END } @@ -3544,6 +3555,15 @@ ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*en API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibraryWithOptions, _In_ OrtEnv* /*env*/, + _In_ const char* /*registration_name*/, _In_ const ORTCHAR_T* /*path*/, + _In_opt_ const OrtKeyValuePairs* /*options*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "RegisterExecutionProviderLibraryWithOptions is not supported in a minimal build."); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::UnregisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, _In_ const char* /*registration_name*/) { API_IMPL_BEGIN @@ -4238,6 +4258,7 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::TensorTypeAndShape_HasShape, &OrtApis::KernelInfo_GetConfigEntries, + &OrtApis::RegisterExecutionProviderLibraryWithOptions, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index f3525d8de7b95..dcf90f5e7ca1e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -755,4 +755,6 @@ ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, ORT_API_STATUS_IMPL(KernelInfo_GetConfigEntries, _In_ const OrtKernelInfo* info, _Outptr_ OrtKeyValuePairs** out); +ORT_API_STATUS_IMPL(RegisterExecutionProviderLibraryWithOptions, _In_ OrtEnv* env, _In_ const char* registration_name, + _In_ const ORTCHAR_T* path, _In_opt_ const OrtKeyValuePairs* options); } // namespace OrtApis diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f0d8906d99c14..c3c65e4e5dc04 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1505,7 +1505,8 @@ void addGlobalMethods(py::module& m) { "register_execution_provider_library", [](const std::string& registration_name, const PathString& library_path) -> void { #if !defined(ORT_MINIMAL_BUILD) - OrtPybindThrowIfError(GetEnv().RegisterExecutionProviderLibrary(registration_name, library_path.c_str())); + OrtPybindThrowIfError(GetEnv().RegisterExecutionProviderLibrary(registration_name, library_path.c_str(), + nullptr)); #else ORT_UNUSED_PARAMETER(registration_name); ORT_UNUSED_PARAMETER(library_path); @@ -1513,6 +1514,23 @@ void addGlobalMethods(py::module& m) { #endif }, R"pbdoc(Register an execution provider library with ONNX Runtime.)pbdoc"); + m.def( + "register_execution_provider_library", + [](const std::string& registration_name, const PathString& library_path, + const std::map& options) -> void { +#if !defined(ORT_MINIMAL_BUILD) + OrtKeyValuePairs key_value_pairs; + key_value_pairs.CopyFromMap(options); + OrtPybindThrowIfError(GetEnv().RegisterExecutionProviderLibrary(registration_name, library_path.c_str(), + &key_value_pairs)); +#else + ORT_UNUSED_PARAMETER(registration_name); + ORT_UNUSED_PARAMETER(library_path); + ORT_UNUSED_PARAMETER(options); + ORT_THROW("Execution provider libraries are not supported in this build."); +#endif + }, + R"pbdoc(Register an execution provider library with ONNX Runtime. Options are passed to EP factories after creation.)pbdoc"); m.def( "unregister_execution_provider_library", [](const std::string& registration_name) -> void { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc index a520b02c20cba..3f142cac399f0 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc @@ -24,6 +24,7 @@ ExampleKernelEpFactory::ExampleKernelEpFactory(const OrtApi& ort_api, const OrtE GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; GetVersion = GetVersionImpl; + SetEnvironmentOptions = SetEnvironmentOptionsImpl; GetSupportedDevices = GetSupportedDevicesImpl; @@ -120,6 +121,23 @@ const char* ORT_API_CALL ExampleKernelEpFactory::GetVersionImpl(const OrtEpFacto return factory->ep_version_.c_str(); } +/*static*/ +OrtStatus* ORT_API_CALL ExampleKernelEpFactory::SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept { + // This factory just gets some trivial value from the environment and stores it in the factory as an example. + // + // An actual EP factory implementation could use these environment options to set common configurations + // for all EPs created by this factory. + auto* factory = static_cast(this_ptr); + const char* value = factory->ort_api_.GetKeyValue(options, "some_env_config"); + + if (value != nullptr) { + factory->some_env_config_ = value; + } + + return nullptr; +} + /*static*/ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* hw_devices, @@ -146,6 +164,11 @@ OrtStatus* ORT_API_CALL ExampleKernelEpFactory::GetSupportedDevicesImpl(OrtEpFac factory->ort_api_.AddKeyValuePair(ep_metadata, "supported_devices", "CrackGriffin 7+"); factory->ort_api_.AddKeyValuePair(ep_options, "run_really_fast", "true"); + if (!factory->some_env_config_.empty()) { + // Store config obtained from OrtEpFactory::SetEnvironmentOptions into the EP metadata for testing convenience. + factory->ort_api_.AddKeyValuePair(ep_metadata, "some_env_config", factory->some_env_config_.c_str()); + } + // OrtEpDevice copies ep_metadata and ep_options. OrtEpDevice* ep_device = nullptr; auto* status = factory->ort_api_.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h index a2340b8b1499d..e64751500a20f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h @@ -35,6 +35,8 @@ class ExampleKernelEpFactory : public OrtEpFactory { static uint32_t ORT_API_CALL GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept; static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL SetEnvironmentOptionsImpl(OrtEpFactory* this_ptr, + const OrtKeyValuePairs* options) noexcept; static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, const OrtHardwareDevice* const* devices, @@ -76,6 +78,7 @@ class ExampleKernelEpFactory : public OrtEpFactory { const std::string vendor_{"Contoso2"}; // EP vendor name const uint32_t vendor_id_{0xB358}; // EP vendor ID const std::string ep_version_{"0.1.0"}; // EP version + std::string some_env_config_; Ort::MemoryInfo default_memory_info_; Ort::MemoryInfo readonly_memory_info_; diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc index 7415c5e138874..a97ead2fa153e 100644 --- a/onnxruntime/test/autoep/test_registration.cc +++ b/onnxruntime/test/autoep/test_registration.cc @@ -167,5 +167,59 @@ TEST(OrtEpLibrary, LoadUnloadPluginVirtGpuLibraryCxxApi) { ort_env->UnregisterExecutionProviderLibrary(registration_name_for_virtual_devices.c_str()); } } + +TEST(OrtEpLibrary, RegisterExecutionProviderLibraryWithOptions) { + const std::filesystem::path& library_path = Utils::example_ep_kernel_registry_info.library_path; + const std::string& registration_name = Utils::example_ep_kernel_registry_info.registration_name; + const std::string& ep_name = Utils::example_ep_kernel_registry_info.ep_name; + + auto get_plugin_ep_devices = [&]() -> std::vector { + std::vector all_ep_devices = ort_env->GetEpDevices(); + std::vector ep_devices; + + std::copy_if(all_ep_devices.begin(), all_ep_devices.end(), std::back_inserter(ep_devices), + [&](Ort::ConstEpDevice& device) { + return device.EpName() == ep_name; + }); + + return ep_devices; + }; + + // Test registering example kernel EP with options. These options are passed to OrtEpFactory::SetEnvironmentOptions. + // For testing convenience, the example EP stores a specific env option into the OrtEpDevice metadata. + { + Ort::KeyValuePairs ep_env_options; + ep_env_options.Add("some_env_config", "2"); + + ort_env->RegisterExecutionProviderLibraryWithOptions(registration_name.c_str(), + library_path.c_str(), + ep_env_options); + + std::vector ep_devices = get_plugin_ep_devices(); + ASSERT_EQ(ep_devices.size(), 1); + + // The example EP stores the env config in the OrtEpDevice metadata just for testing convenience. + auto metadata = ep_devices[0].EpMetadata(); + ASSERT_STREQ(metadata.GetValue("some_env_config"), "2"); + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); + } + + // Test calling RegisterExecutionProviderLibraryWithOptions using a NULL options argument. + { + ort_env->RegisterExecutionProviderLibraryWithOptions(registration_name.c_str(), + library_path.c_str(), + /*options*/ nullptr); + + std::vector ep_devices = get_plugin_ep_devices(); + ASSERT_EQ(ep_devices.size(), 1); + + // The example EP stores the env config in the OrtEpDevice metadata just for testing convenience. + auto metadata = ep_devices[0].EpMetadata(); + ASSERT_EQ(metadata.GetValue("some_env_config"), nullptr); // Does not exist + + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/autoep_helper.py b/onnxruntime/test/python/autoep_helper.py index e3b214afa6e62..5ff5ec4097b40 100644 --- a/onnxruntime/test/python/autoep_helper.py +++ b/onnxruntime/test/python/autoep_helper.py @@ -37,6 +37,10 @@ def tearDownClass(cls): cls._tmp_model_dir.cleanup() def register_execution_provider_library(self, ep_registration_name: str, ep_lib_path: os.PathLike | str): + """ + Test utility that registers an execution provider library with ORT. + Ensures that the library is unregistered if the unit test doesn't do it. + """ if ep_registration_name in self._registered_providers: return # Already registered @@ -52,6 +56,32 @@ def register_execution_provider_library(self, ep_registration_name: str, ep_lib_ # If the unit test itself does not unregister the library, tearDown() will try. self._registered_providers.add(ep_registration_name) + def register_execution_provider_library_with_options( + self, + ep_registration_name: str, + ep_lib_path: os.PathLike | str, + options: dict[str, str], + ): + """ + Test utility that registers an execution provider library with ORT. + Ensures that the library is unregistered if the unit test doesn't do it. + The provided options are passed to EP factories after creation. + """ + if ep_registration_name in self._registered_providers: + return # Already registered + + try: + onnxrt.register_execution_provider_library(ep_registration_name, ep_lib_path, options) + except Fail as onnxruntime_error: + if "already registered" in str(onnxruntime_error): + pass # Allow register to fail if the EP library was previously registered. + else: + raise onnxruntime_error + + # Add this EP library to set of registered EP libraries. + # If the unit test itself does not unregister the library, tearDown() will try. + self._registered_providers.add(ep_registration_name) + def unregister_execution_provider_library(self, ep_registration_name: str): if ep_registration_name not in self._registered_providers: return # Not registered diff --git a/onnxruntime/test/python/onnxruntime_test_python_autoep.py b/onnxruntime/test/python/onnxruntime_test_python_autoep.py index a24269a312e9b..b9f8057c4b12b 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_autoep.py +++ b/onnxruntime/test/python/onnxruntime_test_python_autoep.py @@ -341,6 +341,30 @@ def test_copy_tensors(self): self.unregister_execution_provider_library(ep_name) + def test_register_execution_provider_with_options(self): + """ + Test registration of an example EP plugin with options. + """ + ep_lib_path = "example_plugin_ep_kernel_registry.dll" + try: + ep_lib_path = get_name("example_plugin_ep_kernel_registry.dll") + except FileNotFoundError: + self.skipTest(f"Skipping test because EP library '{ep_lib_path}' cannot be found") + + registration_name = "example_plugin_ep_kernel_registry" + ep_name = "ExampleKernelEp" + options = {"some_env_config": "2"} # The environment options to pass to OrtEpFactory::SetEnvironmentOptions + self.register_execution_provider_library_with_options(registration_name, os.path.realpath(ep_lib_path), options) + + ep_devices = onnxrt.get_ep_devices() + ep_device = next((ep_device for ep_device in ep_devices if ep_device.ep_name == ep_name), None) + self.assertIsNotNone(ep_device) + + # The example EP stores the env config in the OrtEpDevice metadata for testing convenience. + ep_metadata = ep_device.ep_metadata + self.assertEqual(ep_metadata["some_env_config"], "2") + self.unregister_execution_provider_library(registration_name) + if __name__ == "__main__": unittest.main(verbosity=1)