diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 5288296fd4750..9a5891f9e236d 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -700,6 +700,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi; struct OrtCompileApi; typedef struct OrtCompileApi OrtCompileApi; +struct OrtEpApi; +typedef struct OrtEpApi OrtEpApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -5186,6 +5189,12 @@ struct OrtApi { * \since Version 1.22. */ const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the OrtEpApi instance for implementing an execution provider. + * + * \since Version 1.22. + */ + const OrtEpApi*(ORT_API_CALL* GetEpApi)(); }; /* @@ -5889,6 +5898,29 @@ struct OrtCompileApi { ORT_RUNTIME_CLASS(Ep); ORT_RUNTIME_CLASS(EpFactory); +struct OrtEpApi { + /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. + * \param[in] ep_factory Execution provider factory that is creating the instance. + * \param[in] hardware_device Hardware device that the EP can utilize. + * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used + * during execution provider selection and passed to CreateEp. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added + * to the Session configuration options if the execution provider is selected. + * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. + * \param[out] ep_device OrtEpDevice that is created. + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ep_device); + + ORT_CLASS_RELEASE(EpDevice); +}; + /** * \brief The OrtEp struct provides functions to implement for an execution provider. * \since Version 1.22. @@ -5993,21 +6025,28 @@ struct OrtEpFactory { /** \brief Get information from the execution provider if it supports the OrtHardwareDevice. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] device The OrtHardwareDevice instance. - * \param[out] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used - * during execution provider selection and/or CreateEp. - * \param[out] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added - * to the Session configuration options if the execution provider is selected. + * Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice. + * \param[in] devices The OrtHardwareDevice instances that are available. + * \param[in] num_devices The number of OrtHardwareDevice instances. + * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use. + * The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice + * instances to this pre-allocated array. ORT will take ownership of the values returned. + * i.e. usage is `ep_devices[0] = ep_device;` where ep_device was created by OrtEpApi::CreateEpDevice. + * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. + * Current default is 8. This can be increased if needed. + * \param[out] num_ep_devices The number of EP devices added to ep_devices. * \return true if the factory can create an execution provider that uses `device`. * * \note ORT will take ownership or ep_metadata and/or ep_options if they are not null. * * \since Version 1.22. 
*/ - bool(ORT_API_CALL* GetDeviceInfoIfSupported)(const OrtEpFactory* this_ptr, - _In_ const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** ep_metadata, - _Out_opt_ OrtKeyValuePairs** ep_options); + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); /** \brief Function to create an OrtEp instance for use in a Session. * @@ -6015,11 +6054,11 @@ struct OrtEpFactory { * * \param[in] this_ptr The OrtEpFactory instance. * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. - * \param[in] ep_metadata_pairs Execution provider metadata that was returned in GetDeviceInfoIfSupported, for each + * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each * device. * \param[in] num_devices The number of devices the execution provider was selected for. * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the - * session. This will include ep_options from GetDeviceInfoIfSupported as well as any + * session. This will include ep_options from GetSupportedDevices as well as any * user provided overrides. * Execution provider options will have been added with a prefix of 'ep..'. * The OrtSessionOptions instance will NOT be valid after this call and should not be @@ -6029,7 +6068,7 @@ struct OrtEpFactory { * * \snippet{doc} snippets.dox OrtStatus Return Value * - * \since Version 1.22. + * \since Version . This is a placeholder. */ OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, _In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -6043,7 +6082,7 @@ struct OrtEpFactory { * \param[in] this_ptr The OrtEpFactory instance. * \param[in] ep The OrtEp instance to release. 
* - * \since Version 1.22. + * \since Version . This is a placeholder. */ void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index a9deb2dd3e341..0ecc27c59dc28 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -172,6 +172,20 @@ inline const OrtCompileApi& GetCompileApi() { return *api; } +/// +/// This returns a reference to the ORT C EP API. Used if authoring a plugin execution provider. +/// +/// ORT C EP API reference +inline const OrtEpApi& GetEpApi() { + auto* api = GetApi().GetEpApi(); + if (api == nullptr) { + // minimal build + ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL); + } + + return *api; +} + /** \brief IEEE 754 half-precision floating point data type * * \details This struct is used for converting float to float16 and back @@ -561,6 +575,7 @@ ORT_DEFINE_RELEASE(Graph); ORT_DEFINE_RELEASE(Model); ORT_DEFINE_RELEASE(KeyValuePairs) ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi); +ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi); #undef ORT_DEFINE_RELEASE #undef ORT_DEFINE_RELEASE_FROM_API_STRUCT @@ -763,10 +778,16 @@ struct KeyValuePairs : detail::KeyValuePairsImpl { /// Take ownership of a pointer created by C API explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl{p} {} + /// \brief Wraps OrtApi::CreateKeyValuePairs explicit KeyValuePairs(); + + /// \brief Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair explicit KeyValuePairs(const std::unordered_map& kv_pairs); + /// \brief Wraps OrtApi::AddKeyValuePair void Add(const char* key, const char* value); + + /// \brief Wraps OrtApi::RemoveKeyValuePair void Remove(const char* key); ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; } @@ -806,10 +827,21 @@ struct 
EpDeviceImpl : Ort::detail::Base { } // namespace detail /** \brief Wrapper around ::OrtEpDevice - * \remarks EpDevice is always read-only for API users. + * \remarks EpDevice is always read-only for ORT API users. */ using ConstEpDevice = detail::EpDeviceImpl>; +/** \brief Mutable EpDevice that is created by EpApi users. + */ +struct EpDevice : detail::EpDeviceImpl { + explicit EpDevice(std::nullptr_t) {} ///< No instance is created + explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl{p} {} ///< Take ownership of a pointer created by C API + + /// \brief Wraps OrtEpApi::CreateEpDevice + EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, + ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {}); +}; + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 57b4f1b3ead66..48b3b80cced55 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -593,6 +593,11 @@ inline ConstHardwareDevice EpDeviceImpl::Device() const { } } // namespace detail +inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device, + ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) { + ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_)); +} + inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); if (strcmp(logid, "onnxruntime-node") == 0) { diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index ed8d6ea71aea4..c4520fe38cd2a 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ 
-314,7 +314,7 @@ struct CudaEpFactory : OrtEpFactory { CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; } @@ -329,18 +329,26 @@ struct CudaEpFactory : OrtEpFactory { return factory->vendor.c_str(); } - static bool GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** /*ep_metadata*/, - _Out_opt_ OrtKeyValuePairs** /*ep_options*/) { - const auto* factory = static_cast(this_ptr); - - if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(device) == 0x10de) { - return true; + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory->ort_api.HardwareDevice_VendorId(&device) == 0x10de) { + ORT_API_RETURN_IF_ERROR( + factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; } static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, @@ -385,7 +393,7 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase } OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { - delete factory; + delete static_cast(factory); return nullptr; } } diff --git a/onnxruntime/core/session/environment.cc 
b/onnxruntime/core/session/environment.cc index ad7eb9cdfff25..351c44aae185d 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -3,8 +3,11 @@ #include "core/session/environment.h" +#include + #include "core/common/basic_types.h" #include "core/framework/allocator_utils.h" +#include "core/framework/error_code_helper.h" #include "core/graph/constants.h" #include "core/graph/op.h" #include "core/platform/device_discovery.h" @@ -468,6 +471,28 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam return status; } +namespace { +std::vector SortDevicesByType() { + auto& devices = DeviceDiscovery::GetDevices(); + std::vector sorted_devices; + sorted_devices.reserve(devices.size()); + + const auto select_by_type = [&](OrtHardwareDeviceType type) { + for (const auto& device : devices) { + if (device.type == type) { + sorted_devices.push_back(&device); + } + } + }; + + select_by_type(OrtHardwareDeviceType_NPU); + select_by_type(OrtHardwareDeviceType_GPU); + select_by_type(OrtHardwareDeviceType_CPU); + + return sorted_devices; +} +} // namespace + Status Environment::EpInfo::Create(std::unique_ptr library_in, std::unique_ptr& out, const std::vector& internal_factories) { if (!library_in) { @@ -482,36 +507,25 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u ORT_RETURN_IF_ERROR(instance.library->Load()); const auto& factories = instance.library->GetFactories(); + // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. + // the set of hardware devices is static so this can also be static. + const static std::vector sorted_devices = SortDevicesByType(); + for (auto* factory_ptr : factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. 
Library:", instance.library->RegistrationName()); auto& factory = *factory_ptr; - // for each device - for (const auto& device : DeviceDiscovery::GetDevices()) { - OrtKeyValuePairs* ep_metadata = nullptr; - OrtKeyValuePairs* ep_options = nullptr; - - if (factory.GetDeviceInfoIfSupported(&factory, &device, &ep_metadata, &ep_options)) { - auto ed = std::make_unique(); - ed->ep_name = factory.GetName(&factory); - ed->ep_vendor = factory.GetVendor(&factory); - ed->device = &device; - - if (ep_metadata) { - ed->ep_metadata = std::move(*ep_metadata); - delete ep_metadata; - } - - if (ep_options) { - ed->ep_options = std::move(*ep_options); - delete ep_options; - } - - ed->ep_factory = &factory; + std::array ep_devices{nullptr}; + size_t num_ep_devices = 0; + ORT_RETURN_IF_ERROR(ToStatus( + factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(), + ep_devices.data(), ep_devices.size(), &num_ep_devices))); - instance.execution_devices.push_back(std::move(ed)); + for (size_t i = 0; i < num_ep_devices; ++i) { + if (ep_devices[i] != nullptr) { // should never happen but just in case... + instance.execution_devices.emplace_back(ep_devices[i]); // take ownership } } } diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc new file mode 100644 index 0000000000000..0cac00326392c --- /dev/null +++ b/onnxruntime/core/session/ep_api.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/session/ep_api.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/ort_apis.h" + +using namespace onnxruntime; +namespace OrtExecutionProviderApi { +ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ort_ep_device) { + API_IMPL_BEGIN + auto ep_device = std::make_unique(); + ep_device->device = hardware_device; + ep_device->ep_factory = ep_factory; + ep_device->ep_name = ep_factory->GetName(ep_factory); + ep_device->ep_vendor = ep_factory->GetVendor(ep_factory); + + if (ep_metadata) { + ep_device->ep_metadata = *ep_metadata; + } + + if (ep_options) { + ep_device->ep_options = *ep_options; + } + + *ort_ep_device = ep_device.release(); + return nullptr; + API_IMPL_END +} + +ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device) { + delete device; +} + +static constexpr OrtEpApi ort_ep_api = { + // NOTE: ABI compatibility depends on the order within this struct so all additions must be at the end, + // and no functions can be removed (the implementation needs to change to return an error). 
+ + &OrtExecutionProviderApi::CreateEpDevice, + &OrtExecutionProviderApi::ReleaseEpDevice, +}; + +// checks that we don't violate the rule that the functions must remain in the slots they were originally assigned +static_assert(offsetof(OrtEpApi, ReleaseEpDevice) / sizeof(void*) == 1, + "Size of version 22 API cannot change"); // initial version in ORT 1.22 + +} // namespace OrtExecutionProviderApi + +ORT_API(const OrtEpApi*, OrtExecutionProviderApi::GetEpApi) { + return &ort_ep_api; +} diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/ep_api.h new file mode 100644 index 0000000000000..23cd31cbdd861 --- /dev/null +++ b/onnxruntime/core/session/ep_api.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/session/onnxruntime_c_api.h" + +namespace OrtExecutionProviderApi { +// implementation that returns the API struct +ORT_API(const OrtEpApi*, GetEpApi); + +ORT_API_STATUS_IMPL(CreateEpDevice, _In_ OrtEpFactory* ep_factory, + _In_ const OrtHardwareDevice* hardware_device, + _In_opt_ const OrtKeyValuePairs* ep_metadata, + _In_opt_ const OrtKeyValuePairs* ep_options, + _Out_ OrtEpDevice** ep_device); + +ORT_API(void, ReleaseEpDevice, _Frees_ptr_opt_ OrtEpDevice* device); +} // namespace OrtExecutionProviderApi diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index 1626d9c091893..23c25b4e7befb 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -16,12 +16,14 @@ struct ForwardToFactory { return static_cast(this_ptr)->GetVendor(); } - static bool ORT_API_CALL GetDeviceInfoIfSupported(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_device_metadata, - OrtKeyValuePairs** ep_options_for_device) { - return static_cast(this_ptr)->GetDeviceInfoIfSupported(device, ep_device_metadata, - ep_options_for_device); + static 
OrtStatus* ORT_API_CALL GetSupportedDevices(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) { + return static_cast(this_ptr)->GetSupportedDevices(devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices); } static OrtStatus* ORT_API_CALL CreateEp(OrtEpFactory* this_ptr, diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index 71774e11a7246..fd907302b6b8d 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,25 +14,27 @@ namespace onnxruntime { using Forward = ForwardToFactory; EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, - IsSupportedFunc&& is_supported_func, + GetSupportedFunc&& get_supported_func, CreateFunc&& create_func) : ep_name_{ep_name}, vendor_{vendor}, - is_supported_func_{std::move(is_supported_func)}, + get_supported_func_{std::move(get_supported_func)}, create_func_{create_func} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; OrtEpFactory::GetVendor = Forward::GetVendor; - OrtEpFactory::GetDeviceInfoIfSupported = Forward::GetDeviceInfoIfSupported; + OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; } -bool EpFactoryInternal::GetDeviceInfoIfSupported(const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_device_metadata, - OrtKeyValuePairs** ep_options_for_device) const { - return is_supported_func_(device, ep_device_metadata, ep_options_for_device); +OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) { + return get_supported_func_(this, devices, num_devices, 
ep_devices, max_ep_devices, num_ep_devices); } OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, @@ -57,7 +59,7 @@ OrtStatus* EpFactoryInternal::CreateIExecutionProvider(const OrtHardwareDevice* "EpFactoryInternal currently only supports one device at a time."); } - return create_func_(devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); + return create_func_(this, devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); } void EpFactoryInternal::ReleaseEp(OrtEp* /*ep*/) { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index cfe3685e3e8e6..2dcc769ec635e 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -16,26 +16,34 @@ struct SessionOptions; class EpFactoryInternal : public OrtEpFactory { public: - using IsSupportedFunc = std::function; - - using CreateFunc = std::function; + + using CreateFunc = std::function* ep)>; EpFactoryInternal(const std::string& ep_name, const std::string& vendor, - IsSupportedFunc&& is_supported_func, + GetSupportedFunc&& get_supported_func, CreateFunc&& create_func); const char* GetName() const { return ep_name_.c_str(); } const char* GetVendor() const { return vendor_.c_str(); } - bool GetDeviceInfoIfSupported(_In_ const OrtHardwareDevice* device, - _Out_ OrtKeyValuePairs** ep_device_metadata, - _Out_ OrtKeyValuePairs** ep_options_for_device) const; + OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); // we don't implement this. CreateIExecutionProvider should be used. 
OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -55,10 +63,10 @@ class EpFactoryInternal : public OrtEpFactory { void ReleaseEp(OrtEp* ep); private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const IsSupportedFunc is_supported_func_; // function to check if the device is supported - const CreateFunc create_func_; // function to create the EP instance + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const GetSupportedFunc get_supported_func_; // function to return supported devices + const CreateFunc create_func_; // function to create the EP instance std::vector> eps_; // EP instances created by this factory }; diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index 0684e358b93e9..c515195c7e6bf 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -3,11 +3,13 @@ #include "core/session/ep_library_internal.h" +#include "core/framework/error_code_helper.h" #include "core/framework/session_options.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" #include "core/session/abi_session_options_impl.h" +#include "core/session/ep_api.h" #include "core/session/ort_apis.h" #if defined(USE_DML) @@ -20,17 +22,27 @@ namespace onnxruntime { std::unique_ptr EpLibraryInternal::CreateCpuEp() { - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, - OrtKeyValuePairs** /*ep_options*/) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - return true; + const auto get_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + 
size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; }; - const auto create_cpu_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_cpu_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, @@ -49,33 +61,47 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { }; std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", is_supported, create_cpu_ep); + auto cpu_factory = std::make_unique(ep_name, "Microsoft", get_supported, create_cpu_ep); return std::make_unique(std::move(cpu_factory)); } #if defined(USE_DML) std::unique_ptr EpLibraryInternal::CreateDmlEp() { static const std::string ep_name = kDmlExecutionProvider; - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, - OrtKeyValuePairs** ep_options) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with - // a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? 
- if (auto it = device->metadata.entries.find("DxgiAdapterNumber"); it != device->metadata.entries.end()) { - auto options = std::make_unique(); - options->Add("device_id", it->second.c_str()); - *ep_options = options.release(); + const auto is_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + std::unique_ptr ep_options; + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with + // a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? + if (auto it = device.metadata.entries.find("DxgiAdapterNumber"); it != device.metadata.entries.end()) { + ep_options = std::make_unique(); + ep_options->Add("device_id", it->second.c_str()); + } + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices++]); + + if (api_status != nullptr) { + return api_status; + } } - - return true; } - return false; + return nullptr; }; - const auto create_dml_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_dml_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, @@ -106,20 +132,27 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { static const std::string ep_name = kWebGpuExecutionProvider; - const auto is_supported = [](const OrtHardwareDevice* device, - OrtKeyValuePairs** /*ep_metadata*/, 
- OrtKeyValuePairs** /*ep_options*/) -> bool { - if (device->type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // What is the correct behavior here to match the device if there are multiple GPUs? - // Should WebGPU default to picking the GPU with HighPerformanceIndex of 0? - // Or should we be setting the 'deviceId', 'webgpuInstance' and 'webgpuDevice' options for each GPU? - return true; + const auto is_supported = [](OrtEpFactory* factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) -> OrtStatus* { + size_t& num_ep_devices = *p_num_ep_devices; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } } - return false; + return nullptr; }; - const auto create_webgpu_ep = [](const OrtHardwareDevice* const* /*devices*/, + const auto create_webgpu_ep = [](OrtEpFactory* /*factory*/, + const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t num_devices, const OrtSessionOptions* session_options, diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/ep_library_plugin.cc index 0cd03b2a4be07..3c873ec4a9aeb 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/ep_library_plugin.cc @@ -51,6 +51,8 @@ Status EpLibraryPlugin::Load() { } Status EpLibraryPlugin::Unload() { + std::lock_guard lock{mutex_}; + // Call ReleaseEpFactory for all factories and unload the library. // Current implementation assumes any error is permanent so does not leave pieces around to re-attempt Unload. 
if (handle_) { diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 790a5a782de1c..73423a4744576 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -3,6 +3,7 @@ #include "core/session/ep_library_provider_bridge.h" +#include "core/common/status.h" #include "core/framework/error_code_helper.h" #include "core/framework/session_options.h" #include "core/providers/cuda/cuda_provider_options.h" @@ -13,17 +14,48 @@ namespace onnxruntime { Status EpLibraryProviderBridge::Load() { - // wrap the EpLibraryPlugin factories that were created by calling CreateEpFactories. - // use GetDeviceInfoIfSupported from the factory. + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. 
for (const auto& factory : ep_library_plugin_->GetFactories()) { - const auto is_supported_fn = [&factory](const OrtHardwareDevice* device, - OrtKeyValuePairs** ep_metadata, - OrtKeyValuePairs** ep_options) -> bool { - return factory->GetDeviceInfoIfSupported(factory, device, ep_metadata, ep_options); + const auto is_supported_fn = [&factory](OrtEpFactory* ep_factory_internal, // from factory_ptrs_ + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) -> OrtStatus* { + ORT_API_RETURN_IF_ERROR(factory->GetSupportedDevices(factory, devices, num_devices, ep_devices, max_ep_devices, + num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = ep_factory_internal; + } + } + + return nullptr; }; - const auto create_fn = [this, &factory](const OrtHardwareDevice* const* devices, + const auto create_fn = [this, &factory](OrtEpFactory* /*ep_factory_internal from factory_ptrs_*/, + const OrtHardwareDevice* const* devices, const OrtKeyValuePairs* const* ep_metadata_pairs, size_t num_devices, const OrtSessionOptions* session_options, @@ -42,7 +74,6 @@ Status EpLibraryProviderBridge::Load() { factory->GetVendor(factory), is_supported_fn, create_fn); - factory_ptrs_.push_back(internal_factory.get()); internal_factory_ptrs_.push_back(internal_factory.get()); factories_.push_back(std::move(internal_factory)); @@ -52,7 +83,19 @@ Status EpLibraryProviderBridge::Load() { } Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. 
+ ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + provider_library_->Unload(); + provider_library_ = nullptr; + return Status::OK(); } diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/ep_library_provider_bridge.h index 5f85192866cf4..3c7f083df227e 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/ep_library_provider_bridge.h @@ -3,6 +3,7 @@ #pragma once #include +#include #include "core/session/ep_library.h" #include "core/session/ep_factory_internal.h" @@ -44,10 +45,11 @@ class EpLibraryProviderBridge : public EpLibrary { ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibraryProviderBridge); private: + std::mutex mutex_; std::unique_ptr provider_library_; // provider bridge EP library // EpLibraryPlugin that provides the CreateEpFactories and ReleaseEpFactory implementations. - // we wrap the factories it contains to pass through GetDeviceInfoIfSupported calls, and + // we wrap the OrtEpFactory instances it contains to pass through GetSupportedDevices calls, and // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. 
std::unique_ptr ep_library_plugin_; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index f2d03610bec1e..dd5165eb12190 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -33,6 +33,7 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/ep_api.h" #include "core/session/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" @@ -2511,7 +2512,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS return nullptr; API_IMPL_END } -#else // defined(ORT_MINIMAL_BUILD) + +ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { + return OrtExecutionProviderApi::GetEpApi(); +} + +#else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { API_IMPL_BEGIN @@ -2545,6 +2551,12 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS API_IMPL_END } + +ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { + fprintf(stderr, "The Execution Provider API is not supported in a minimal build.\n"); + return nullptr; +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -3012,6 +3024,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::EpDevice_EpMetadata, &OrtApis::EpDevice_EpOptions, &OrtApis::EpDevice_Device, + + &OrtApis::GetEpApi, // End of Version 22 - DO NOT MODIFY ABOVE (see above text for more information) }; @@ -3047,7 +3061,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo // no additions in version 19, 20, and 21 static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change"); -static_assert(offsetof(OrtApi, EpDevice_Device) / 
sizeof(void*) == 314, "Size of version 22 API cannot change"); +static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 315, "Size of version 22 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.23.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 76c5e7bf9c26b..0033eb0d604f2 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -588,4 +588,6 @@ ORT_API(const char*, EpDevice_EpVendor, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtKeyValuePairs*, EpDevice_EpMetadata, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtKeyValuePairs*, EpDevice_EpOptions, _In_ const OrtEpDevice* ep_device); ORT_API(const OrtHardwareDevice*, EpDevice_Device, _In_ const OrtEpDevice* ep_device); + +ORT_API(const OrtEpApi*, GetEpApi); } // namespace OrtApis diff --git a/onnxruntime/test/autoep/library/example_plugin_ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep.cc index b88ad3e896ea8..2c82b9ace3c61 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep.cc @@ -42,9 +42,10 @@ struct ExampleEp : OrtEp, ApiPtrs { struct ExampleEpFactory : OrtEpFactory, ApiPtrs { ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with. 
GetName = GetNameImpl; GetVendor = GetVendorImpl; - GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; } @@ -59,25 +60,56 @@ struct ExampleEpFactory : OrtEpFactory, ApiPtrs { return factory->vendor_.c_str(); } - static bool ORT_API_CALL GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr, - const OrtHardwareDevice* device, - _Out_opt_ OrtKeyValuePairs** ep_metadata, - _Out_opt_ OrtKeyValuePairs** ep_options) { - const auto* factory = static_cast(this_ptr); - - if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - // these can be returned as nullptr if you have nothing to add. - factory->ort_api.CreateKeyValuePairs(ep_metadata); - factory->ort_api.CreateKeyValuePairs(ep_options); - - // random example using made up values - factory->ort_api.AddKeyValuePair(*ep_metadata, "version", "0.1"); - factory->ort_api.AddKeyValuePair(*ep_options, "run_really_fast", "true"); + static OrtStatus* ORT_API_CALL GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast(this_ptr); - return true; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + // C API + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // these can be returned as nullptr if you have nothing to add. 
+ OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_metadata); + factory->ort_api.CreateKeyValuePairs(&ep_options); + + // random example using made up values + factory->ort_api.AddKeyValuePair(ep_metadata, "version", "0.1"); + factory->ort_api.AddKeyValuePair(ep_options, "run_really_fast", "true"); + + // OrtEpDevice copies ep_metadata and ep_options. + auto* status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, + &ep_devices[num_ep_devices++]); + + factory->ort_api.ReleaseKeyValuePairs(ep_metadata); + factory->ort_api.ReleaseKeyValuePairs(ep_options); + + if (status != nullptr) { + return status; + } + } + + // C++ API equivalent. Throws on error. + //{ + // Ort::ConstHardwareDevice device(devices[i]); + // if (device.Type() == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + // Ort::KeyValuePairs ep_metadata; + // Ort::KeyValuePairs ep_options; + // ep_metadata.Add("version", "0.1"); + // ep_options.Add("run_really_fast", "true"); + // Ort::EpDevice ep_device{*this_ptr, device, ep_metadata.GetConst(), ep_options.GetConst()}; + // ep_devices[num_ep_devices++] = ep_device.release(); + // } + //} } - return false; + return nullptr; } static OrtStatus* ORT_API_CALL CreateEpImpl(OrtEpFactory* this_ptr, @@ -88,6 +120,7 @@ struct ExampleEpFactory : OrtEpFactory, ApiPtrs { _In_ const OrtLogger* logger, _Out_ OrtEp** ep) { auto* factory = static_cast(this_ptr); + *ep = nullptr; if (num_devices != 1) { // we only registered for CPU and only expected to be selected for one CPU diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_autoep_selection.cc index a7fa1bccf3210..b3e4e7221b560 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_autoep_selection.cc @@ -231,6 +231,8 @@ TEST(AutoEpSelection, MiscApiTests) { c_api->AddKeyValuePair(kvps, "key2", ""); // empty value is 
allowed ASSERT_EQ(c_api->GetKeyValue(kvps, "key2"), std::string("")); + + c_api->ReleaseKeyValuePairs(kvps); } // construct KVP from std::unordered_map