Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 48 additions & 42 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ struct OrtEp {
*
* \since Version 1.22.
*/
const char*(ORT_API_CALL* GetName)(_In_ const OrtEp* this_ptr);
ORT_API_T(const char*, GetName, _In_ const OrtEp* this_ptr);

/** \brief Get information about the nodes supported by the OrtEp instance.
*
Expand All @@ -376,8 +376,8 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* GetCapability)(_In_ OrtEp* this_ptr, _In_ const OrtGraph* graph,
_Inout_ OrtEpGraphSupportInfo* graph_support_info);
ORT_API2_STATUS(GetCapability, _In_ OrtEp* this_ptr, _In_ const OrtGraph* graph,
_Inout_ OrtEpGraphSupportInfo* graph_support_info);

/** \brief Compile OrtGraph instances assigned to the OrtEp. Implementer must set an OrtNodeComputeInfo instance
* for each OrtGraph in order to define its computation function.
Expand Down Expand Up @@ -416,10 +416,10 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* Compile)(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
_In_ const OrtNode** fused_nodes, _In_ size_t count,
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
_Out_writes_(count) OrtNode** ep_context_nodes);
ORT_API2_STATUS(Compile, _In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
_In_ const OrtNode** fused_nodes, _In_ size_t count,
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
_Out_writes_(count) OrtNode** ep_context_nodes);

/** \brief Release OrtNodeComputeInfo instances.
*
Expand All @@ -429,9 +429,9 @@ struct OrtEp {
*
* \since Version 1.23.
*/
void(ORT_API_CALL* ReleaseNodeComputeInfos)(_In_ OrtEp* this_ptr,
OrtNodeComputeInfo** node_compute_infos,
_In_ size_t num_node_compute_infos);
ORT_API_T(void, ReleaseNodeComputeInfos, _In_ OrtEp* this_ptr,
OrtNodeComputeInfo** node_compute_infos,
_In_ size_t num_node_compute_infos);

/** \brief Get the EP's preferred data layout.
*
Expand All @@ -445,8 +445,7 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* GetPreferredDataLayout)(_In_ OrtEp* this_ptr,
_Out_ OrtEpDataLayout* preferred_data_layout);
ORT_API2_STATUS(GetPreferredDataLayout, _In_ OrtEp* this_ptr, _Out_ OrtEpDataLayout* preferred_data_layout);

/** \brief Given an op with domain `domain` and type `op_type`, determine whether an associated node's data layout
* should be converted to `target_data_layout`.
Expand All @@ -470,11 +469,10 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* ShouldConvertDataLayoutForOp)(_In_ OrtEp* this_ptr,
_In_z_ const char* domain,
_In_z_ const char* op_type,
_In_ OrtEpDataLayout target_data_layout,
_Outptr_ int* should_convert);
ORT_API2_STATUS(ShouldConvertDataLayoutForOp, _In_ OrtEp* this_ptr,
_In_z_ const char* domain, _In_z_ const char* op_type,
_In_ OrtEpDataLayout target_data_layout,
_Outptr_ int* should_convert);

/** \brief Set dynamic options on this EP.
*
Expand All @@ -492,10 +490,10 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* SetDynamicOptions)(_In_ OrtEp* this_ptr,
_In_reads_(num_options) const char* const* option_keys,
_In_reads_(num_options) const char* const* option_values,
_In_ size_t num_options);
ORT_API2_STATUS(SetDynamicOptions, _In_ OrtEp* this_ptr,
_In_reads_(num_options) const char* const* option_keys,
_In_reads_(num_options) const char* const* option_values,
_In_ size_t num_options);

/** \brief Called by ORT to notify the EP of the start of a run.
*
Expand All @@ -508,8 +506,7 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* OnRunStart)(_In_ OrtEp* this_ptr,
_In_ const OrtRunOptions* run_options);
ORT_API2_STATUS(OnRunStart, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options);

/** \brief Called by ORT to notify the EP of the end of a run.
*
Expand All @@ -524,9 +521,7 @@ struct OrtEp {
*
* \since Version 1.23.
*/
OrtStatus*(ORT_API_CALL* OnRunEnd)(_In_ OrtEp* this_ptr,
_In_ const OrtRunOptions* run_options,
_In_ bool sync_stream);
ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream);
};

/** \brief The function signature that ORT will call to create OrtEpFactory instances.
Expand Down Expand Up @@ -586,7 +581,7 @@ struct OrtEpFactory {
*
* \since Version 1.22.
*/
const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr);
ORT_API_T(const char*, GetName, const OrtEpFactory* this_ptr);

/** \brief Get the name of the vendor that owns the execution provider that the factory creates.
*
Expand All @@ -597,7 +592,7 @@ struct OrtEpFactory {
*
* \since Version 1.22.
*/
const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor
ORT_API_T(const char*, GetVendor, const OrtEpFactory* this_ptr); // return EP vendor

/** \brief Get information from the execution provider about OrtHardwareDevice support.
*
Expand All @@ -616,12 +611,12 @@ struct OrtEpFactory {
*
* \since Version 1.22.
*/
OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_Inout_ OrtEpDevice** ep_devices,
_In_ size_t max_ep_devices,
_Out_ size_t* num_ep_devices);
ORT_API2_STATUS(GetSupportedDevices, _In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_ size_t num_devices,
_Inout_ OrtEpDevice** ep_devices,
_In_ size_t max_ep_devices,
_Out_ size_t* num_ep_devices);

/** \brief Function to create an OrtEp instance for use in a Session.
*
Expand All @@ -647,12 +642,12 @@ struct OrtEpFactory {
*
* \since Version 1.22.
*/
OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs,
_In_ size_t num_devices,
_In_ const OrtSessionOptions* session_options,
_In_ const OrtLogger* logger, _Outptr_ OrtEp** ep);
ORT_API2_STATUS(CreateEp, _In_ OrtEpFactory* this_ptr,
_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
_In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs,
_In_ size_t num_devices,
_In_ const OrtSessionOptions* session_options,
_In_ const OrtLogger* logger, _Outptr_ OrtEp** ep);

/** \brief Release the OrtEp instance.
*
Expand All @@ -661,7 +656,18 @@ struct OrtEpFactory {
*
* \since Version 1.22.
*/
void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep);
ORT_API_T(void, ReleaseEp, OrtEpFactory* this_ptr, struct OrtEp* ep);

/** \brief Get the ID of the vendor that owns the execution provider that the factory creates.
*
* This is typically the PCI vendor ID. See https://pcisig.com/membership/member-companies
*
* \param[in] this_ptr The OrtEpFactory instance.
* \return The vendor ID of the execution provider the factory creates.
*
* \since Version 1.23.
*/
ORT_API_T(uint32_t, GetVendorId, const OrtEpFactory* this_ptr);

/** \brief Get the version of the execution provider that the factory creates.
*
Expand All @@ -675,7 +681,7 @@ struct OrtEpFactory {
*
* \since Version 1.23.
*/
const char*(ORT_API_CALL* GetVersion)(_In_ const OrtEpFactory* this_ptr);
ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr);

/** \brief Create an OrtAllocator for the given OrtMemoryInfo.
*
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,10 @@ CUDA_Provider* GetProvider() {
// OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection.
struct CudaEpFactory : OrtEpFactory {
CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} {
ort_version_supported = ORT_API_VERSION;
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetVendorId = GetVendorIdImpl;
GetVersion = GetVersionImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
CreateEp = CreateEpImpl;
Expand All @@ -331,6 +333,11 @@ struct CudaEpFactory : OrtEpFactory {
return factory->vendor.c_str();
}

static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const CudaEpFactory*>(this_ptr);
return factory->vendor_id;
}

// Implements OrtEpFactory::GetVersion for the CUDA EP factory.
// The factory instance is not consulted; the ORT_VERSION macro is returned as the version string.
static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
return ORT_VERSION;
}
Expand Down Expand Up @@ -374,6 +381,7 @@ struct CudaEpFactory : OrtEpFactory {
const OrtApi& ort_api;
const std::string ep_name{kCudaExecutionProvider}; // EP name
const std::string vendor{"Microsoft"}; // EP vendor name
uint32_t vendor_id{0x1414}; // Microsoft vendor ID
};

extern "C" {
Expand Down
14 changes: 11 additions & 3 deletions onnxruntime/core/providers/qnn/qnn_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,10 @@ struct QnnEpFactory : OrtEpFactory {
OrtHardwareDeviceType hw_type,
const char* qnn_backend_type)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} {
ort_version_supported = ORT_API_VERSION;
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetVendorId = GetVendorIdImpl;
GetVersion = GetVersionImpl;
GetSupportedDevices = GetSupportedDevicesImpl;
CreateEp = CreateEpImpl;
Expand All @@ -142,7 +144,12 @@ struct QnnEpFactory : OrtEpFactory {

static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const QnnEpFactory*>(this_ptr);
return factory->vendor.c_str();
return factory->ep_vendor.c_str();
}

static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept {
const auto* factory = static_cast<const QnnEpFactory*>(this_ptr);
return factory->ep_vendor_id;
}

static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
Expand Down Expand Up @@ -195,8 +202,9 @@ struct QnnEpFactory : OrtEpFactory {
}

const OrtApi& ort_api;
const std::string ep_name; // EP name
const std::string vendor{"Microsoft"}; // EP vendor name
const std::string ep_name; // EP name
const std::string ep_vendor{"Microsoft"}; // EP vendor name
uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID

// Qualcomm vendor ID. Refer to the ACPI ID registry (search Qualcomm): https://uefi.org/ACPI_ID_List
const uint32_t vendor_id{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)};
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/ep_api_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ struct ForwardToFactory {
return static_cast<const TFactory*>(this_ptr)->GetVendor();
}

// C-ABI trampoline for OrtEpFactory::GetVendorId: recovers the concrete TFactory from the
// base pointer and forwards to its GetVendorId() member.
static uint32_t ORT_API_CALL GetVendorId(const OrtEpFactory* this_ptr) noexcept {
  const TFactory* const factory = static_cast<const TFactory*>(this_ptr);
  return factory->GetVendorId();
}

// C-ABI trampoline for OrtEpFactory::GetVersion: recovers the concrete TFactory from the
// base pointer and forwards to its GetVersion() member.
static const char* ORT_API_CALL GetVersion(const OrtEpFactory* this_ptr) noexcept {
  const TFactory* const factory = static_cast<const TFactory*>(this_ptr);
  return factory->GetVersion();
}
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/session/ep_factory_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,19 @@

using Forward = ForwardToFactory<EpFactoryInternal>;

EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor,
EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id,

Check warning on line 17 in onnxruntime/core/session/ep_factory_internal.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/ep_factory_internal.cc:17: Add #include <string> for string [build/include_what_you_use] [4]
GetSupportedFunc&& get_supported_func,
CreateFunc&& create_func)
: ep_name_{ep_name},
vendor_{vendor},
vendor_id_{vendor_id},
get_supported_func_{std::move(get_supported_func)},
create_func_{create_func} {
ort_version_supported = ORT_API_VERSION;

OrtEpFactory::GetName = Forward::GetFactoryName;
OrtEpFactory::GetVendor = Forward::GetVendor;
OrtEpFactory::GetVendorId = Forward::GetVendorId;
OrtEpFactory::GetVersion = Forward::GetVersion;
OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices;
OrtEpFactory::CreateEp = Forward::CreateEp;
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/core/session/ep_factory_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ class EpFactoryInternal : public OrtEpFactory {
const OrtSessionOptions* session_options,
const OrtLogger* logger, std::unique_ptr<IExecutionProvider>* ep)>;

EpFactoryInternal(const std::string& ep_name, const std::string& vendor,
EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id,
GetSupportedFunc&& get_supported_func,
CreateFunc&& create_func);

// EP name as a null-terminated string backed by the ep_name_ member.
const char* GetName() const noexcept { return ep_name_.c_str(); }
// EP vendor name as a null-terminated string backed by the vendor_ member.
const char* GetVendor() const noexcept { return vendor_.c_str(); }
// Numeric EP vendor ID held in the vendor_id_ member.
uint32_t GetVendorId() const noexcept { return vendor_id_; }
const char* GetVersion() const noexcept;

OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices,
Expand Down Expand Up @@ -67,6 +68,7 @@ class EpFactoryInternal : public OrtEpFactory {
private:
const std::string ep_name_; // EP name library was registered with
const std::string vendor_; // EP vendor name
const uint32_t vendor_id_; // EP vendor ID
const GetSupportedFunc get_supported_func_; // function to return supported devices
const CreateFunc create_func_; // function to create the EP instance

Expand Down
9 changes: 6 additions & 3 deletions onnxruntime/core/session/ep_library_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ std::unique_ptr<EpLibraryInternal> EpLibraryInternal::CreateCpuEp() {
};

std::string ep_name = kCpuExecutionProvider;
auto cpu_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", get_supported, create_cpu_ep);
auto cpu_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT,
get_supported, create_cpu_ep);
return std::make_unique<EpLibraryInternal>(std::move(cpu_factory));
}

Expand Down Expand Up @@ -122,7 +123,8 @@ std::unique_ptr<EpLibraryInternal> EpLibraryInternal::CreateDmlEp() {
return nullptr;
};

auto dml_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", is_supported, create_dml_ep);
auto dml_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT,
is_supported, create_dml_ep);

return std::make_unique<EpLibraryInternal>(std::move(dml_factory));
}
Expand Down Expand Up @@ -170,7 +172,8 @@ std::unique_ptr<EpLibraryInternal> EpLibraryInternal::CreateWebGpuEp() {
return nullptr;
};

auto webgpu_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", is_supported, create_webgpu_ep);
auto webgpu_factory = std::make_unique<EpFactoryInternal>(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT,
is_supported, create_webgpu_ep);

return std::make_unique<EpLibraryInternal>(std::move(webgpu_factory));
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/ep_library_provider_bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ Status EpLibraryProviderBridge::Load() {

auto internal_factory = std::make_unique<EpFactoryInternal>(factory->GetName(factory),
factory->GetVendor(factory),
factory->GetVendorId(factory),
is_supported_fn,
create_fn);
factory_ptrs_.push_back(internal_factory.get());
Expand Down
8 changes: 7 additions & 1 deletion onnxruntime/core/session/provider_policy_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
namespace onnxruntime {
namespace {
// Returns true when the hardware device referenced by `d` has the same vendor as the EP.
//
// The vendor ID reported by d->ep_factory is compared first; an ID of 0 is treated as
// "not provided". If the IDs do not match (or no ID was provided), the device's vendor name is
// compared against d->ep_vendor instead.
bool MatchesEpVendor(const OrtEpDevice* d) {
  const uint32_t ep_vendor_id = d->ep_factory->GetVendorId(d->ep_factory);
  const bool id_matches = ep_vendor_id != 0 && d->device->vendor_id == ep_vendor_id;

  return id_matches || d->device->vendor == d->ep_vendor;
}

Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/test/autoep/library/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* graph) {

/*static*/
OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph,
OrtEpGraphSupportInfo* graph_support_info) {
OrtEpGraphSupportInfo* graph_support_info) noexcept {
ExampleEp* ep = static_cast<ExampleEp*>(this_ptr);

size_t num_nodes = 0;
Expand Down Expand Up @@ -290,7 +290,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs,
_In_ const OrtNode** fused_nodes, _In_ size_t count,
_Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos,
_Out_writes_(count) OrtNode** ep_context_nodes) {
_Out_writes_(count) OrtNode** ep_context_nodes) noexcept {
ExampleEp* ep = static_cast<ExampleEp*>(this_ptr);
const OrtApi& ort_api = ep->ort_api;

Expand Down Expand Up @@ -354,7 +354,7 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
/*static*/
void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr,
OrtNodeComputeInfo** node_compute_infos,
size_t num_node_compute_infos) {
size_t num_node_compute_infos) noexcept {
(void)this_ptr;
for (size_t i = 0; i < num_node_compute_infos; i++) {
delete node_compute_infos[i];
Expand Down
Loading
Loading