Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ typedef enum OrtExecutionProviderDevicePolicy {
* \param max_ep_devices The maximum number of devices that can be selected in the pre-allocated array.
Currently the maximum is 8.
* \param num_ep_devices The number of selected devices.
* \param state Opaque pointer. Required to use the delegate from other languages like C# and python.
*
* \return OrtStatus* Selection status. Return nullptr on success.
* Use CreateStatus to provide error info. Use ORT_FAIL as the error code.
Expand All @@ -449,7 +450,8 @@ typedef OrtStatus* (*EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices,
_In_opt_ const OrtKeyValuePairs* runtime_metadata,
_Inout_ const OrtEpDevice** selected,
_In_ size_t max_selected,
_Out_ size_t* num_selected);
_Out_ size_t* num_selected,
_In_ void* state);

/** \brief Algorithm to use for cuDNN Convolution Op
*/
Expand Down Expand Up @@ -5127,18 +5129,30 @@ struct OrtApi {

/** \brief Set the execution provider selection policy for the session.
*
* Allows users to specify a device selection policy for automatic execution provider (EP) selection,
* or provide a delegate callback for custom selection logic.
* Allows users to specify a device selection policy for automatic execution provider (EP) selection.
* If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead.
*
* \param[in] session_options The OrtSessionOptions instance.
* \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy).
* \param[in] delegate Optional delegate callback for custom selection. Pass nullptr to use the built-in policy.
*
* \since Version 1.22
*/
ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options,
_In_ OrtExecutionProviderDevicePolicy policy,
_In_opt_ EpSelectionDelegate* delegate);
_In_ OrtExecutionProviderDevicePolicy policy);

/** \brief Set the execution provider selection policy delegate for the session.
*
* Allows users to provide a custom device selection policy for automatic execution provider (EP) selection.
*
* \param[in] session_options The OrtSessionOptions instance.
* \param[in] delegate Delegate callback for custom selection.
* \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required.
*
* \since Version 1.22
*/
ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options,
_In_ EpSelectionDelegate delegate,
_In_opt_ void* delegate_state);

/** \brief Get the hardware device type.
*
Expand Down
6 changes: 4 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1103,8 +1103,10 @@ struct SessionOptionsImpl : ConstSessionOptionsImpl<T> {
const std::unordered_map<std::string, std::string>& ep_options);

/// Wraps OrtApi::SessionOptionsSetEpSelectionPolicy
SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
EpSelectionDelegate* delegate = nullptr);
SessionOptionsImpl& SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy);

/// Wraps OrtApi::SessionOptionsSetEpSelectionPolicyDelegate
SessionOptionsImpl& SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state = nullptr);

SessionOptionsImpl& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn
SessionOptionsImpl& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions
Expand Down
11 changes: 8 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -1150,9 +1150,14 @@ inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::AppendExecutionProvider_V2(
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy,
EpSelectionDelegate* delegate) {
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy, delegate));
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy policy) {
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicy(this->p_, policy));
return *this;
}

template <typename T>
inline SessionOptionsImpl<T>& SessionOptionsImpl<T>::SetEpSelectionPolicy(EpSelectionDelegate delegate, void* state) {
ThrowOnError(GetApi().SessionOptionsSetEpSelectionPolicyDelegate(this->p_, delegate, state));
return *this;
}

Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/framework/session_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ struct EpSelectionPolicy {
// and no selection policy was explicitly provided.
bool enable{false};
OrtExecutionProviderDevicePolicy policy = OrtExecutionProviderDevicePolicy_DEFAULT;
EpSelectionDelegate* delegate{};
EpSelectionDelegate delegate{};
void* state{nullptr}; // state for the delegate
};

/**
Expand Down
16 changes: 14 additions & 2 deletions onnxruntime/core/session/abi_session_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -367,12 +367,24 @@ ORT_API_STATUS_IMPL(OrtApis::SetDeterministicCompute, _Inout_ OrtSessionOptions*
}

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* options,
_In_ OrtExecutionProviderDevicePolicy policy,
_In_opt_ EpSelectionDelegate* delegate) {
_In_ OrtExecutionProviderDevicePolicy policy) {
API_IMPL_BEGIN
options->value.ep_selection_policy.enable = true;
options->value.ep_selection_policy.policy = policy;
options->value.ep_selection_policy.delegate = nullptr;
options->value.ep_selection_policy.state = nullptr;
return nullptr;
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* options,
_In_opt_ EpSelectionDelegate delegate,
_In_opt_ void* state) {
API_IMPL_BEGIN
options->value.ep_selection_policy.enable = true;
options->value.ep_selection_policy.policy = OrtExecutionProviderDevicePolicy_DEFAULT;
options->value.ep_selection_policy.delegate = delegate;
options->value.ep_selection_policy.state = state;
return nullptr;
API_IMPL_END
}
Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3264,6 +3264,7 @@ common::Status InferenceSession::SaveModelMetadata(const onnxruntime::Model& mod

// save model metadata
model_metadata_.producer_name = model.ProducerName();
model_metadata_.producer_version = model.ProducerVersion();
model_metadata_.description = model.DocString();
model_metadata_.graph_description = model.GraphDocString();
model_metadata_.domain = model.Domain();
Expand Down Expand Up @@ -3428,6 +3429,10 @@ const Model& InferenceSession::GetModel() const {
return *model_;
}

const Environment& InferenceSession::GetEnvironment() const {
return environment_;
}

SessionIOBinding::SessionIOBinding(InferenceSession* session) : sess_(session) {
ORT_ENFORCE(session->NewIOBinding(&binding_).IsOK());
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ struct ModelMetadata {
ModelMetadata& operator=(const ModelMetadata&) = delete;

std::string producer_name;
std::string producer_version;
std::string graph_name;
std::string domain;
std::string description;
Expand Down Expand Up @@ -601,6 +602,7 @@ class InferenceSession {
#endif

const Model& GetModel() const;
const Environment& GetEnvironment() const;

protected:
#if !defined(ORT_MINIMAL_BUILD)
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3013,6 +3013,7 @@ static constexpr OrtApi ort_api_1_to_22 = {
&OrtApis::GetEpDevices,
&OrtApis::SessionOptionsAppendExecutionProvider_V2,
&OrtApis::SessionOptionsSetEpSelectionPolicy,
&OrtApis::SessionOptionsSetEpSelectionPolicyDelegate,

&OrtApis::HardwareDevice_Type,
&OrtApis::HardwareDevice_VendorId,
Expand Down Expand Up @@ -3062,7 +3063,7 @@ static_assert(offsetof(OrtApi, AddExternalInitializersFromFilesInMemory) / sizeo
// no additions in version 19, 20, and 21
static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Size of version 20 API cannot change");

static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 316, "Size of version 22 API cannot change");
static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change");

// So that nobody forgets to finish an API version, this check will serve as a reminder:
static_assert(std::string_view(ORT_VERSION) == "1.22.0",
Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -576,8 +576,11 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOpt
size_t num_ep_options);

ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* sess_options,
_In_ OrtExecutionProviderDevicePolicy policy,
_In_opt_ EpSelectionDelegate* delegate);
_In_ OrtExecutionProviderDevicePolicy policy);

ORT_API_STATUS_IMPL(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* sess_options,
_In_ EpSelectionDelegate delegate,
_In_opt_ void* state);

// OrtHardwareDevice accessors.
ORT_API(OrtHardwareDeviceType, HardwareDevice_Type, _In_ const OrtHardwareDevice* device);
Expand Down
52 changes: 42 additions & 10 deletions onnxruntime/core/session/provider_policy_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@
bool bIsDefaultCpuEp = IsDefaultCpuEp(b);
if (!aIsDefaultCpuEp && !bIsDefaultCpuEp) {
// neither are default CPU EP. both do/don't match vendor.
// TODO: implement tie-breaker for this scenario. arbitrarily prefer the first for now
return true;
// TODO: implement tie-breaker for this scenario. arbitrarily sort by ep name

Check warning on line 97 in onnxruntime/core/session/provider_policy_context.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/session/provider_policy_context.cc:97: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
return a->ep_name < b->ep_name;
}

// one is the default CPU EP
Expand All @@ -104,31 +104,57 @@

return sorted_devices;
}

OrtKeyValuePairs GetModelMetadata(const InferenceSession& session) {
OrtKeyValuePairs metadata;
auto status_and_metadata = session.GetModelMetadata();

if (!status_and_metadata.first.IsOK()) {
return metadata;
}

// use field names from onnx.proto
const auto& model_metadata = *status_and_metadata.second;
metadata.Add("producer_name", model_metadata.producer_name);
metadata.Add("producer_version", model_metadata.producer_version);
metadata.Add("domain", model_metadata.domain);
metadata.Add("model_version", std::to_string(model_metadata.version));
metadata.Add("doc_string", model_metadata.description);
metadata.Add("graph_name", model_metadata.graph_name); // name from main GraphProto
metadata.Add("graph_description", model_metadata.graph_description); // descriptions from main GraphProto
for (const auto& entry : model_metadata.custom_metadata_map) {
metadata.Add(entry.first, entry.second);
}

return metadata;
}
} // namespace

// Select execution providers based on the device policy and available devices and add to session
Status ProviderPolicyContext::SelectEpsForSession(const Environment& env, const OrtSessionOptions& options,
InferenceSession& sess) {
ORT_ENFORCE(options.value.ep_selection_policy.delegate == nullptr,
"EP selection delegate support is not implemented yet.");

// Get the list of devices from the environment and order them.
// Ordered by preference within each type. NPU -> GPU -> NPU
// TODO: Should environment.cc do the ordering?
const auto& execution_devices = OrderDevices(env.GetOrtEpDevices());
std::vector<const OrtEpDevice*> execution_devices = OrderDevices(env.GetOrtEpDevices());

// The list of devices selected by policies
std::vector<const OrtEpDevice*> devices_selected;

// Run the delegate if it was passed in lieu of any other policy
if (options.value.ep_selection_policy.delegate) {
auto policy_fn = options.value.ep_selection_policy.delegate;
auto model_metadata = GetModelMetadata(sess);
OrtKeyValuePairs runtime_metadata; // TODO: where should this come from?

Check warning on line 147 in onnxruntime/core/session/provider_policy_context.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: onnxruntime/core/session/provider_policy_context.cc:147: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]

std::vector<const OrtEpDevice*> delegate_devices(execution_devices.begin(), execution_devices.end());
std::array<const OrtEpDevice*, 8> selected_devices{nullptr};

size_t num_selected = 0;
auto* status = (*policy_fn)(delegate_devices.data(), delegate_devices.size(),
nullptr, nullptr, selected_devices.data(), selected_devices.size(), &num_selected);

EpSelectionDelegate delegate = options.value.ep_selection_policy.delegate;
auto* status = delegate(delegate_devices.data(), delegate_devices.size(),
&model_metadata, &runtime_metadata,
selected_devices.data(), selected_devices.size(), &num_selected,
options.value.ep_selection_policy.state);

// return or fall-through for both these cases
// going with explicit failure for now so it's obvious to user what is happening
Expand All @@ -142,6 +168,12 @@
if (num_selected == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "EP selection delegate did not select anything.");
}

// Copy the selected devices to the output vector
devices_selected.reserve(num_selected);
for (size_t i = 0; i < num_selected; ++i) {
devices_selected.push_back(selected_devices[i]);
}
} else {
// Create the selector for the chosen policy
std::unique_ptr<IEpPolicySelector> selector;
Expand Down
46 changes: 24 additions & 22 deletions onnxruntime/core/session/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -172,20 +172,6 @@
env->GetEnvironment());
}

#if !defined(ORT_MINIMAL_BUILD)
// TEMPORARY for testing. Manually specify the EP to select.
auto auto_select_ep_name = sess->GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select");
if (auto_select_ep_name) {
ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(env->GetEnvironment(), *sess, *auto_select_ep_name));
}

// if there are no providers registered, and there's an ep selection policy set, do auto ep selection
if (options != nullptr && options->provider_factories.empty() && options->value.ep_selection_policy.enable) {
ProviderPolicyContext context;
ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(env->GetEnvironment(), *options, *sess));
}
#endif // !defined(ORT_MINIMAL_BUILD)

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Add custom domains
if (options && !options->custom_op_domains_.empty()) {
Expand Down Expand Up @@ -216,22 +202,38 @@
ORT_ENFORCE(session_logger != nullptr,
"Session logger is invalid, but should have been initialized during session construction.");

// we need to disable mem pattern if DML is one of the providers since DML doesn't have the concept of
// byte addressable memory
std::vector<std::unique_ptr<IExecutionProvider>> provider_list;
if (options) {
const bool has_provider_factories = options != nullptr && !options->provider_factories.empty();

if (has_provider_factories) {
std::vector<std::unique_ptr<IExecutionProvider>> provider_list;
for (auto& factory : options->provider_factories) {
auto provider = factory->CreateProvider(*options, *session_logger->ToExternal());
provider_list.push_back(std::move(provider));
}

// register the providers
for (auto& provider : provider_list) {
if (provider) {
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider)));
}
}
}
#if !defined(ORT_MINIMAL_BUILD)
else {

Check warning on line 222 in onnxruntime/core/session/utils.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/core/session/utils.cc:222: If an else has a brace on one side, it should have it on both [readability/braces] [5]
// TEMPORARY for testing. Manually specify the EP to select.
auto auto_select_ep_name = sess.GetSessionOptions().config_options.GetConfigEntry("test.ep_to_select");
if (auto_select_ep_name) {
ORT_API_RETURN_IF_STATUS_NOT_OK(TestAutoSelectEPsImpl(sess.GetEnvironment(), sess, *auto_select_ep_name));
}

// register the providers
for (auto& provider : provider_list) {
if (provider) {
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.RegisterExecutionProvider(std::move(provider)));
// if there are no providers registered, and there's an ep selection policy set, do auto ep selection.
// note: the model has already been loaded so model metadata should be available to the policy delegate callback.
if (options != nullptr && options->value.ep_selection_policy.enable) {
ProviderPolicyContext context;
ORT_API_RETURN_IF_STATUS_NOT_OK(context.SelectEpsForSession(sess.GetEnvironment(), *options, sess));
}
}
#endif // !defined(ORT_MINIMAL_BUILD)

if (prepacked_weights_container != nullptr) {
ORT_API_RETURN_IF_STATUS_NOT_OK(sess.AddPrePackedWeightsContainer(
Expand Down
Loading
Loading