Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,7 @@ struct OrtEp {
* \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary
* \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the
* version of ORT that the library was compiled against.
* \param[in] default_logger The default ORT logger that can be used for logging outside of an inference session.
* \param[in,out] factories The implementation should create and add OrtEpFactory instances to this
* pre-allocated array.
* i.e. usage is `factories[0] = new MyEpFactory();`
Expand All @@ -689,6 +690,7 @@ struct OrtEp {
* \since Version 1.22.
*/
typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base,
_In_ const OrtLogger* default_logger,
_Inout_ OrtEpFactory** factories, _In_ size_t max_factories,
_Out_ size_t* num_factories);

Expand Down
11 changes: 7 additions & 4 deletions onnxruntime/core/providers/cuda/cuda_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,10 @@ struct CudaSyncStreamImpl : OrtSyncStreamImpl {
struct CudaEpFactory : OrtEpFactory {
using MemoryInfoUniquePtr = std::unique_ptr<OrtMemoryInfo, std::function<void(OrtMemoryInfo*)>>;

CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in},
ep_api{*ort_api_in.GetEpApi()},
data_transfer_impl{ort_api_in} {
CudaEpFactory(const OrtApi& ort_api_in, const OrtLogger& default_logger_in) : ort_api{ort_api_in},
default_logger{default_logger_in},
ep_api{*ort_api_in.GetEpApi()},
data_transfer_impl{ort_api_in} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetVendorId = GetVendorIdImpl;
Expand Down Expand Up @@ -911,6 +912,7 @@ struct CudaEpFactory : OrtEpFactory {

const OrtApi& ort_api;
const OrtEpApi& ep_api;
const OrtLogger& default_logger;
const std::string ep_name{kCudaExecutionProvider}; // EP name
const std::string vendor{"Microsoft"}; // EP vendor name
uint32_t vendor_id{0x1414}; // Microsoft vendor ID
Expand All @@ -935,12 +937,13 @@ extern "C" {
// Public symbols
//
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
const OrtLogger* default_logger,
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
ErrorHelper::ort_api = ort_api; // setup our error helper

// Factory could use registration_name or define its own EP name.
std::unique_ptr<OrtEpFactory> factory = std::make_unique<CudaEpFactory>(*ort_api);
std::unique_ptr<OrtEpFactory> factory = std::make_unique<CudaEpFactory>(*ort_api, *default_logger);

if (max_factories < 1) {
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,9 +156,10 @@
// OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection.
struct NvTensorRtRtxEpFactory : OrtEpFactory {
NvTensorRtRtxEpFactory(const OrtApi& ort_api_in,
const OrtLogger& default_logger_in,
const char* ep_name,
OrtHardwareDeviceType hw_type)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
: ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} {
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
GetVersion = GetVersionImpl;
Expand Down Expand Up @@ -228,6 +229,7 @@
}

const OrtApi& ort_api;
const OrtLogger& default_logger;
const std::string ep_name;
const std::string vendor{"NVIDIA"};

Expand All @@ -241,11 +243,12 @@
// Public symbols
//
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
const OrtLogger* default_logger,
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);

// Factory could use registration_name or define its own EP name.
auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api,
auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api, *default_logger,

Check warning on line 251 in onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc:251: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
onnxruntime::kNvTensorRTRTXExecutionProvider,
OrtHardwareDeviceType_GPU);

Expand Down
11 changes: 9 additions & 2 deletions onnxruntime/core/providers/qnn/qnn_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,15 @@
// OrtEpApi infrastructure to be able to use the QNN EP as an OrtEpFactory for auto EP selection.
struct QnnEpFactory : OrtEpFactory {
QnnEpFactory(const OrtApi& ort_api_in,
const OrtLogger& default_logger_in,
const char* ep_name,
OrtHardwareDeviceType hw_type,
const char* qnn_backend_type)
: ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, qnn_backend_type{qnn_backend_type} {
: ort_api{ort_api_in},
default_logger{default_logger_in},
ep_name{ep_name},
ort_hw_device_type{hw_type},
qnn_backend_type{qnn_backend_type} {
ort_version_supported = ORT_API_VERSION;
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
Expand Down Expand Up @@ -245,6 +250,7 @@
}

const OrtApi& ort_api;
const OrtLogger& default_logger;
const std::string ep_name; // EP name
const std::string ep_vendor{"Microsoft"}; // EP vendor name
uint32_t ep_vendor_id{0x1414}; // Microsoft vendor ID
Expand All @@ -260,11 +266,12 @@
// Public symbols
//
OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base,
const OrtLogger* default_logger,
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);

// Factory could use registration_name or define its own EP name.
auto factory_npu = std::make_unique<QnnEpFactory>(*ort_api,
auto factory_npu = std::make_unique<QnnEpFactory>(*ort_api, *default_logger,

Check warning on line 274 in onnxruntime/core/providers/qnn/qnn_provider_factory.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for make_unique<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/qnn/qnn_provider_factory.cc:274: Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
onnxruntime::kQnnExecutionProvider,
OrtHardwareDeviceType_NPU, "htp");

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ep_library_plugin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "core/session/ep_library_plugin.h"

#include "core/common/logging/logging.h"
#include "core/framework/error_code_helper.h"
#include "core/session/environment.h"

Expand All @@ -24,6 +25,7 @@ Status EpLibraryPlugin::Load() {

size_t num_factories = 0;
ORT_RETURN_IF_ERROR(ToStatusAndRelease(create_fn_(registration_name_.c_str(), OrtGetApiBase(),
logging::LoggingManager::DefaultLogger().ToExternal(),
factories.data(), factories.size(), &num_factories)));

for (size_t i = 0; i < num_factories; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/test/autoep/library/ep_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
#include "ep_data_transfer.h"
#include "ep_stream_support.h"

ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis)
: ApiPtrs(apis), ep_name_{ep_name} {
ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger)
: ApiPtrs(apis), default_logger_{default_logger}, ep_name_{ep_name} {
ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with.
GetName = GetNameImpl;
GetVendor = GetVendorImpl;
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/autoep/library/ep_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
/// </summary>
class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
public:
ExampleEpFactory(const char* ep_name, ApiPtrs apis);
ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtLogger& default_logger);

OrtDataTransferImpl* GetDataTransfer() const {
return data_transfer_impl_.get();
Expand Down Expand Up @@ -59,6 +59,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
const OrtKeyValuePairs* stream_options,
OrtSyncStreamImpl** stream) noexcept;

const OrtLogger& default_logger_; // default logger for the EP factory
const std::string ep_name_; // EP name
const std::string vendor_{"Contoso"}; // EP vendor name
const uint32_t vendor_id_{0xB357}; // EP vendor ID
Expand Down
4 changes: 3 additions & 1 deletion onnxruntime/test/autoep/library/example_plugin_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ extern "C" {
// Public symbols
//
EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const OrtApiBase* ort_api_base,
const OrtLogger* default_logger,
OrtEpFactory** factories, size_t max_factories, size_t* num_factories) {
const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION);
const OrtEpApi* ep_api = ort_api->GetEpApi();
Expand All @@ -23,7 +24,8 @@ EXPORT_SYMBOL OrtStatus* CreateEpFactories(const char* registration_name, const
// Factory could use registration_name or define its own EP name.
std::unique_ptr<OrtEpFactory> factory = std::make_unique<ExampleEpFactory>(registration_name,
ApiPtrs{*ort_api, *ep_api,
*model_editor_api});
*model_editor_api},
*default_logger);

if (max_factories < 1) {
return ort_api->CreateStatus(ORT_INVALID_ARGUMENT,
Expand Down
Loading