Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 52 additions & 13 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,9 @@ typedef struct OrtModelEditorApi OrtModelEditorApi;
struct OrtCompileApi;
typedef struct OrtCompileApi OrtCompileApi;

// Forward declaration of the API used when implementing an execution provider.
// The instance is obtained through OrtApi::GetEpApi.
struct OrtEpApi;
typedef struct OrtEpApi OrtEpApi;

/** \brief The helper interface to get the right version of OrtApi
*
* Get a pointer to this structure through ::OrtGetApiBase
Expand Down Expand Up @@ -5186,6 +5189,12 @@ struct OrtApi {
* \since Version 1.22.
*/
const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device);

/** \brief Get the OrtEpApi instance for implementing an execution provider.
*
* \return Pointer to the OrtEpApi instance.
* NOTE(review): the C++ API wrapper GetEpApi() treats a null return as "EP API is not available in this
* build" (minimal build), so C callers should presumably check for nullptr as well - confirm.
*
* \since Version 1.22.
*/
const OrtEpApi*(ORT_API_CALL* GetEpApi)();
};

/*
Expand Down Expand Up @@ -5889,6 +5898,29 @@ struct OrtCompileApi {
ORT_RUNTIME_CLASS(Ep);
ORT_RUNTIME_CLASS(EpFactory);

/**
* \brief The OrtEpApi struct provides functions used when implementing an execution provider.
*
* Obtain the instance with OrtApi::GetEpApi.
*
* \since Version 1.22.
*/
struct OrtEpApi {
/** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice.
* \param[in] ep_factory Execution provider factory that is creating the instance.
* \param[in] hardware_device Hardware device that the EP can utilize.
* \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
* during execution provider selection and passed to CreateEp.
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
* \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
* to the Session configuration options if the execution provider is selected.
* ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
* \param[out] ep_device OrtEpDevice that is created.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.22.
*/
ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
_In_ const OrtHardwareDevice* hardware_device,
_In_opt_ const OrtKeyValuePairs* ep_metadata,
_In_opt_ const OrtKeyValuePairs* ep_options,
_Out_ OrtEpDevice** ep_device);

/* Release entry for OrtEpDevice instances created with CreateEpDevice.
* Instances returned to ORT from OrtEpFactory::GetSupportedDevices are owned by ORT and must not be
* released by the execution provider.
*/
ORT_CLASS_RELEASE(EpDevice);
};

/**
* \brief The OrtEp struct provides functions to implement for an execution provider.
* \since Version 1.22.
Expand Down Expand Up @@ -5993,33 +6025,40 @@ struct OrtEpFactory {
/** \brief Get the OrtEpDevice instances for the OrtHardwareDevice instances that the execution provider supports.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] device The OrtHardwareDevice instance.
* \param[out] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
* during execution provider selection and/or CreateEp.
* \param[out] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
* to the Session configuration options if the execution provider is selected.
* Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice.
* \param[in] devices The OrtHardwareDevice instances that are available.
* \param[in] num_devices The number of OrtHardwareDevice instances.
* \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use.
* The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice
* instances to this pre-allocated array. ORT will take ownership of the values returned.
* i.e. usage is `ep_devices[0] = <ptr to OrtEpDevice created with OrtEpApi::CreateEpDevice>;`
* \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices.
* Current default is 8. This can be increased if needed.
* \param[out] num_ep_devices The number of EP devices added to ep_devices.
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note ORT takes ownership of the OrtEpDevice instances returned in ep_devices.
*
* \since Version 1.22.
*/
bool(ORT_API_CALL* GetDeviceInfoIfSupported)(const OrtEpFactory* this_ptr,
_In_ const OrtHardwareDevice* device,
_Out_opt_ OrtKeyValuePairs** ep_metadata,
_Out_opt_ OrtKeyValuePairs** ep_options);
OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_Inout_ OrtEpDevice** ep_devices,
_In_ size_t max_ep_devices,
_Out_ size_t* num_ep_devices);

/** \brief Function to create an OrtEp instance for use in a Session.
*
* ORT will call ReleaseEp to release the instance when it is no longer needed.
*
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use.
* \param[in] ep_metadata_pairs Execution provider metadata that was returned in GetDeviceInfoIfSupported, for each
* \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each
* device.
* \param[in] num_devices The number of devices the execution provider was selected for.
* \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the
* session. This will include ep_options from GetDeviceInfoIfSupported as well as any
* session. This will include ep_options from GetSupportedDevices as well as any
* user provided overrides.
* Execution provider options will have been added with a prefix of 'ep.<ep name>.'.
* The OrtSessionOptions instance will NOT be valid after this call and should not be
Expand All @@ -6029,7 +6068,7 @@ struct OrtEpFactory {
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.22.
* \since Version <coming soon>. This is a placeholder.
*/
OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
Expand All @@ -6043,7 +6082,7 @@ struct OrtEpFactory {
* \param[in] this_ptr The OrtEpFactory instance.
* \param[in] ep The OrtEp instance to release.
*
* \since Version 1.22.
* \since Version <coming soon>. This is a placeholder.
*/
void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep);
};
Expand Down
34 changes: 33 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,20 @@ inline const OrtCompileApi& GetCompileApi() {
return *api;
}

/// <summary>
/// Accessor for the ORT C EP API, which is needed when authoring a plugin execution provider.
/// Raises an ORT_FAIL error through ORT_CXX_API_THROW when the C API reports no EP API (e.g. minimal build).
/// </summary>
/// <returns>ORT C EP API reference</returns>
inline const OrtEpApi& GetEpApi() {
  const OrtEpApi* ep_api = GetApi().GetEpApi();
  if (!ep_api) {
    // minimal build
    ORT_CXX_API_THROW("EP API is not available in this build", ORT_FAIL);
  }

  return *ep_api;
}

/** \brief IEEE 754 half-precision floating point data type
*
* \details This struct is used for converting float to float16 and back
Expand Down Expand Up @@ -561,6 +575,7 @@ ORT_DEFINE_RELEASE(Graph);
ORT_DEFINE_RELEASE(Model);
ORT_DEFINE_RELEASE(KeyValuePairs)
ORT_DEFINE_RELEASE_FROM_API_STRUCT(ModelCompilationOptions, GetCompileApi);
ORT_DEFINE_RELEASE_FROM_API_STRUCT(EpDevice, GetEpApi);

#undef ORT_DEFINE_RELEASE
#undef ORT_DEFINE_RELEASE_FROM_API_STRUCT
Expand Down Expand Up @@ -763,10 +778,16 @@ struct KeyValuePairs : detail::KeyValuePairsImpl<OrtKeyValuePairs> {
/// Take ownership of a pointer created by C API
explicit KeyValuePairs(OrtKeyValuePairs* p) : KeyValuePairsImpl<OrtKeyValuePairs>{p} {}

/// \brief Wraps OrtApi::CreateKeyValuePairs
explicit KeyValuePairs();

/// \brief Wraps OrtApi::CreateKeyValuePairs and OrtApi::AddKeyValuePair
explicit KeyValuePairs(const std::unordered_map<std::string, std::string>& kv_pairs);

/// \brief Wraps OrtApi::AddKeyValuePair
void Add(const char* key, const char* value);

/// \brief Wraps OrtApi::RemoveKeyValuePair
void Remove(const char* key);

ConstKeyValuePairs GetConst() const { return ConstKeyValuePairs{this->p_}; }
Expand Down Expand Up @@ -806,10 +827,21 @@ struct EpDeviceImpl : Ort::detail::Base<T> {
} // namespace detail

/** \brief Wrapper around ::OrtEpDevice
* \remarks EpDevice is always read-only for API users.
* \remarks EpDevice is always read-only for ORT API users.
*/
using ConstEpDevice = detail::EpDeviceImpl<Ort::detail::Unowned<const OrtEpDevice>>;

/** \brief Mutable EpDevice that is created by EpApi users.
*
* Owning wrapper around ::OrtEpDevice for execution provider authors. ORT API users that only inspect
* devices use the read-only ConstEpDevice instead.
*/
struct EpDevice : detail::EpDeviceImpl<OrtEpDevice> {
explicit EpDevice(std::nullptr_t) {} ///< No instance is created
explicit EpDevice(OrtEpDevice* p) : EpDeviceImpl<OrtEpDevice>{p} {} ///< Take ownership of a pointer created by C API

/// \brief Wraps OrtEpApi::CreateEpDevice
/// \param ep_factory Execution provider factory that is creating the device.
/// \param hardware_device Hardware device that the execution provider can utilize.
/// \param ep_metadata Optional execution provider metadata; defaults to an empty wrapper.
/// \param ep_options Optional execution provider options; defaults to an empty wrapper.
EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
ConstKeyValuePairs ep_metadata = {}, ConstKeyValuePairs ep_options = {});
};

/** \brief The Env (Environment)
*
* The Env holds the logging state used by all other objects.
Expand Down
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,11 @@ inline ConstHardwareDevice EpDeviceImpl<T>::Device() const {
}
} // namespace detail

// Creates the underlying OrtEpDevice through OrtEpApi::CreateEpDevice and stores it in this wrapper.
// Any error status returned by the C API is handed to ThrowOnError.
inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardware_device,
                          ConstKeyValuePairs ep_metadata, ConstKeyValuePairs ep_options) {
  const OrtEpApi& ep_api = GetEpApi();
  OrtStatus* create_status = ep_api.CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_);
  ThrowOnError(create_status);
}

inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) {
ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_));
if (strcmp(logid, "onnxruntime-node") == 0) {
Expand Down
32 changes: 20 additions & 12 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ struct CudaEpFactory : OrtEpFactory {
CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetDeviceInfoIfSupported = GetDeviceInfoIfSupportedImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
CreateEp = CreateEpImpl;
ReleaseEp = ReleaseEpImpl;
}
Expand All @@ -329,18 +329,26 @@ struct CudaEpFactory : OrtEpFactory {
return factory->vendor.c_str();
}

static bool GetDeviceInfoIfSupportedImpl(const OrtEpFactory* this_ptr,
const OrtHardwareDevice* device,
_Out_opt_ OrtKeyValuePairs** /*ep_metadata*/,
_Out_opt_ OrtKeyValuePairs** /*ep_options*/) {
const auto* factory = static_cast<const CudaEpFactory*>(this_ptr);

if (factory->ort_api.HardwareDevice_Type(device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&
factory->ort_api.HardwareDevice_VendorId(device) == 0x10de) {
return true;
// OrtEpFactory::GetSupportedDevices implementation for the CUDA execution provider factory.
//
// Walks `devices` and, for every GPU whose vendor id is 0x10de (NVIDIA's PCI vendor id), creates an
// OrtEpDevice via OrtEpApi::CreateEpDevice and stores it in the next free slot of `ep_devices`.
// No metadata or options are attached to the created devices.
//
// this_ptr          The CudaEpFactory instance (passed as its OrtEpFactory base).
// devices           Hardware devices to consider; `num_devices` entries.
// ep_devices        Pre-allocated output array with room for `max_ep_devices` entries.
// p_num_ep_devices  Receives the number of entries written to `ep_devices`.
//
// Returns nullptr on success, or the failing status from CreateEpDevice. On failure the count only covers
// the entries that were created successfully before the error.
static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr,
                                          const OrtHardwareDevice* const* devices,
                                          size_t num_devices,
                                          OrtEpDevice** ep_devices,
                                          size_t max_ep_devices,
                                          size_t* p_num_ep_devices) {
  size_t& num_ep_devices = *p_num_ep_devices;
  // The count is an _Out_ parameter: start from zero instead of relying on the caller's initial value,
  // since it is read by the loop condition below.
  num_ep_devices = 0;

  auto* factory = static_cast<CudaEpFactory*>(this_ptr);
  const OrtApi& ort_api = factory->ort_api;

  for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) {
    const OrtHardwareDevice& device = *devices[i];
    if (ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU &&
        ort_api.HardwareDevice_VendorId(&device) == 0x10de) {
      // Create into a local first so the slot and the count are only updated once creation succeeded.
      OrtEpDevice* ep_device = nullptr;
      ORT_API_RETURN_IF_ERROR(
          ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, &ep_device));
      ep_devices[num_ep_devices++] = ep_device;
    }
  }

  return nullptr;
}

static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/,
Expand Down Expand Up @@ -385,7 +393,7 @@ OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase
}

// Destroys an execution provider factory handed out by this library.
//
// The pointer is cast back to CudaEpFactory before deletion: OrtEpFactory is a plain C struct without a
// virtual destructor, so deleting through the base pointer would not run the CudaEpFactory destructor.
// The factory must be deleted exactly once.
// NOTE(review): assumes every factory passed here is a CudaEpFactory created by CreateEpFactories - confirm.
//
// Always returns nullptr (success).
OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) {
  delete static_cast<CudaEpFactory*>(factory);
  return nullptr;
}
}
60 changes: 37 additions & 23 deletions onnxruntime/core/session/environment.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@

#include "core/session/environment.h"

#include <array>
#include <vector>

#include "core/common/basic_types.h"
#include "core/framework/allocator_utils.h"
#include "core/framework/error_code_helper.h"
#include "core/graph/constants.h"
#include "core/graph/op.h"
#include "core/platform/device_discovery.h"
Expand Down Expand Up @@ -468,6 +471,28 @@
return status;
}

namespace {
std::vector<const OrtHardwareDevice*> SortDevicesByType() {
auto& devices = DeviceDiscovery::GetDevices();
std::vector<const OrtHardwareDevice*> sorted_devices;
sorted_devices.reserve(devices.size());

const auto select_by_type = [&](OrtHardwareDeviceType type) {
for (const auto& device : devices) {
if (device.type == type) {
sorted_devices.push_back(&device);
}
}
};

select_by_type(OrtHardwareDeviceType_NPU);
select_by_type(OrtHardwareDeviceType_GPU);
select_by_type(OrtHardwareDeviceType_CPU);

return sorted_devices;
}
} // namespace

Status Environment::EpInfo::Create(std::unique_ptr<EpLibrary> library_in, std::unique_ptr<EpInfo>& out,
const std::vector<EpFactoryInternal*>& internal_factories) {
if (!library_in) {
Expand All @@ -482,36 +507,25 @@
ORT_RETURN_IF_ERROR(instance.library->Load());
const auto& factories = instance.library->GetFactories();

// OrtHardwareDevice instances to pass to GetSupportedDevices, sorted by type to be slightly more structured.
// The set of hardware devices is static, so the sorted list is computed once and reused on later calls.
static const std::vector<const OrtHardwareDevice*> sorted_devices = SortDevicesByType();

Check warning on line 512 in onnxruntime/core/session/environment.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Storage-class specifier (static, extern, typedef, etc) should be at the beginning of the declaration. [build/storage_class] [5] Raw Output: onnxruntime/core/session/environment.cc:512: Storage-class specifier (static, extern, typedef, etc) should be at the beginning of the declaration. [build/storage_class] [5]

Check warning on line 512 in onnxruntime/core/session/environment.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/environment.cc:512: Add #include <vector> for vector<> [build/include_what_you_use] [4]

for (auto* factory_ptr : factories) {
ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:",
instance.library->RegistrationName());

auto& factory = *factory_ptr;

// for each device
for (const auto& device : DeviceDiscovery::GetDevices()) {
OrtKeyValuePairs* ep_metadata = nullptr;
OrtKeyValuePairs* ep_options = nullptr;

if (factory.GetDeviceInfoIfSupported(&factory, &device, &ep_metadata, &ep_options)) {
auto ed = std::make_unique<OrtEpDevice>();
ed->ep_name = factory.GetName(&factory);
ed->ep_vendor = factory.GetVendor(&factory);
ed->device = &device;

if (ep_metadata) {
ed->ep_metadata = std::move(*ep_metadata);
delete ep_metadata;
}

if (ep_options) {
ed->ep_options = std::move(*ep_options);
delete ep_options;
}

ed->ep_factory = &factory;
std::array<OrtEpDevice*, 8> ep_devices{nullptr};
size_t num_ep_devices = 0;
ORT_RETURN_IF_ERROR(ToStatus(
factory.GetSupportedDevices(&factory, sorted_devices.data(), sorted_devices.size(),
ep_devices.data(), ep_devices.size(), &num_ep_devices)));

instance.execution_devices.push_back(std::move(ed));
for (size_t i = 0; i < num_ep_devices; ++i) {
if (ep_devices[i] != nullptr) { // should never happen but just in case...
instance.execution_devices.emplace_back(ep_devices[i]); // take ownership
}
}
}
Expand Down
Loading
Loading