Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/onnxruntime/core/graph/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ constexpr size_t kMaxExecutionProviderNameLen = 30;

constexpr const char* kCpuExecutionProvider = "CPUExecutionProvider";
constexpr const char* kCudaExecutionProvider = "CUDAExecutionProvider";
// EP name string for the CUDA plugin execution provider. The device-matching heuristics in
// layering_annotations.cc (MatchEpDevice) accept this name wherever they accept kCudaExecutionProvider.
// Note the casing differs from kCudaExecutionProvider ("Cuda" vs "CUDA").
constexpr const char* kCudaPluginExecutionProvider = "CudaPluginExecutionProvider";
constexpr const char* kCudaNHWCExecutionProvider = "CUDANHWCExecutionProvider";
constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider";
constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider";
Expand Down
53 changes: 53 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3729,5 +3729,58 @@ using UnownedSharedPrePackedWeightCache =

///< Wraps OrtEpApi::GetEnvConfigEntries()
Ort::KeyValuePairs GetEnvConfigEntries();

/// \brief Non-owning C++ wrapper for resource budget queries on OrtEpGraphSupportInfo.
///
/// Constructed from the OrtEpGraphSupportInfo* passed to OrtEp::GetCapability.
/// Provides convenient methods for resource-constrained node selection.
/// All costs and budgets use OrtResourceCount, the ABI-stable tagged union.
///
/// Example use in a plugin EP's GetCapability implementation:
/// \code
/// OrtStatus* GetCapabilityImpl(OrtEp*, const OrtGraph* graph,
/// OrtEpGraphSupportInfo* info) noexcept {
/// Ort::ResourceBudget budget(info);
/// if (budget.HasBudget()) {
/// OrtResourceCount remaining = budget.GetBudget();
/// OrtResourceCount consumed = budget.GetConsumedResources();
/// for (const OrtNode* node : candidates) {
/// OrtResourceCount cost = budget.ComputeNodeCost(Ort::ConstNode{node});
/// if (cost.AsTotalBytes() > remaining.AsTotalBytes() - consumed.AsTotalBytes()) {
/// budget.SignalStopAssignment();
/// break;
/// }
/// budget.ReportAcceptedNodeCost(Ort::ConstNode{node}, cost);
/// }
/// }
/// }
/// \endcode
struct ResourceBudget {
explicit ResourceBudget(OrtEpGraphSupportInfo* info) : info_(info) {}

/// Returns true if a resource budget is configured for this EP.
bool HasBudget() const;

/// Returns the total resource budget. Only valid if HasBudget() is true.
OrtResourceCount GetBudget() const;

/// Returns the amount of resources already consumed.
OrtResourceCount GetConsumedResources() const;

/// Computes the estimated resource cost of the given node.
OrtResourceCount ComputeNodeCost(ConstNode node) const;

/// Reports that the plugin accepted a node at the given cost.
void ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost);

/// Returns true if stop has been signaled (by this or another EP).
bool IsStopIssued() const;

/// Signals that this EP wants to stop receiving nodes.
void SignalStopAssignment();

private:
OrtEpGraphSupportInfo* info_;
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
39 changes: 39 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -4174,4 +4174,43 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c
ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema));
return OpSchema{schema};
}

// ResourceBudget implementation
// Asks the EP API whether a resource budget is attached to the wrapped graph-support info.
// The status from the API call is routed through ThrowOnError.
inline bool ResourceBudget::HasBudget() const {
  bool budget_configured{false};
  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.EpGraphSupportInfo_HasResourceBudget(info_, &budget_configured));
  return budget_configured;
}

// Fetches the total budget via EpGraphSupportInfo_GetResourceBudget.
// The out-value starts as OrtResourceCount::None() and is filled in by the API call.
inline OrtResourceCount ResourceBudget::GetBudget() const {
  auto total = OrtResourceCount::None();
  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.EpGraphSupportInfo_GetResourceBudget(info_, &total));
  return total;
}

// Fetches the already-consumed amount via EpGraphSupportInfo_GetConsumedResources.
// The out-value starts as OrtResourceCount::None() and is filled in by the API call.
inline OrtResourceCount ResourceBudget::GetConsumedResources() const {
  auto used = OrtResourceCount::None();
  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.EpGraphSupportInfo_GetConsumedResources(info_, &used));
  return used;
}

// Asks the EP API for the estimated cost of `node` via EpGraphSupportInfo_ComputeNodeResourceCost.
// The out-value starts as OrtResourceCount::None() and is filled in by the API call.
inline OrtResourceCount ResourceBudget::ComputeNodeCost(ConstNode node) const {
  auto estimated = OrtResourceCount::None();
  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.EpGraphSupportInfo_ComputeNodeResourceCost(info_, node, &estimated));
  return estimated;
}

// Forwards the accepted node and its cost to EpGraphSupportInfo_ReportAcceptedNodeCost.
// The returned status is routed through ThrowOnError.
inline void ResourceBudget::ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost) {
  auto* status = GetEpApi().EpGraphSupportInfo_ReportAcceptedNodeCost(info_, node, cost);
  ThrowOnError(status);
}

// Reads the stop flag via EpGraphSupportInfo_IsStopIssued.
// The status from the API call is routed through ThrowOnError.
inline bool ResourceBudget::IsStopIssued() const {
  bool stop_requested{false};
  const auto& ep_api = GetEpApi();
  ThrowOnError(ep_api.EpGraphSupportInfo_IsStopIssued(info_, &stop_requested));
  return stop_requested;
}

// Raises the stop flag via EpGraphSupportInfo_SignalStopAssignment.
// The returned status is routed through ThrowOnError.
inline void ResourceBudget::SignalStopAssignment() {
  auto* status = GetEpApi().EpGraphSupportInfo_SignalStopAssignment(info_);
  ThrowOnError(status);
}
} // namespace Ort
172 changes: 172 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,62 @@ struct OrtScanKernelHelper {
_In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output);
};

/**
 * \brief Discriminator for the resource count type stored in an OrtResourceCount.
 *
 * New resource accounting types can be added by appending new enum values.
 * The OrtResourceCount union storage is large enough to hold all current and future types.
 *
 * \since Version 1.26.
 */
typedef enum OrtResourceCountKind {
  OrtResourceCountKind_None = 0,        ///< Unset / zero-cost sentinel.
  OrtResourceCountKind_TotalBytes = 1,  ///< Single uint64_t: total estimated bytes.
} OrtResourceCountKind;

/**
 * \brief ABI-stable tagged union representing a resource cost or budget.
 *
 * This struct is a C-safe variant that can be passed by value across the plugin DLL boundary.
 * The `kind` field selects which union member is active. The `_storage` member reserves space
 * for future resource types without changing the struct layout.
 *
 * Layout: two uint32_t header fields (8 bytes) followed by a 48-byte union, 56 bytes in total.
 * When compiled as C++ the total size is enforced by a static_assert placed after the definition.
 *
 * Adding new resource types requires only: (a) a new OrtResourceCountKind enum value,
 * (b) a new union member. No new C API functions are needed.
 *
 * \since Version 1.26.
 */
typedef struct OrtResourceCount {
  uint32_t kind;       ///< OrtResourceCountKind discriminator. Use uint32_t for ABI stability.
  uint32_t reserved_;  ///< Alignment padding + future flags. Must be zero.

  union {
    uint64_t total_bytes;  ///< Active when kind == OrtResourceCountKind_TotalBytes.
    uint64_t _storage[6];  ///< ABI reserve (48 bytes): all types must fit within this.
  } value;

#ifdef __cplusplus
  /// Create a zero-cost (None) resource count: kind is OrtResourceCountKind_None and the
  /// struct is value-initialized.
  static OrtResourceCount None() noexcept {
    return OrtResourceCount{};
  }

  /// Create a TotalBytes resource count holding \p bytes. reserved_ is left at zero.
  static OrtResourceCount FromTotalBytes(uint64_t bytes) noexcept {
    OrtResourceCount rc{};
    rc.kind = OrtResourceCountKind_TotalBytes;
    rc.value.total_bytes = bytes;
    return rc;
  }

  /// Extract total_bytes. Caller must check kind == OrtResourceCountKind_TotalBytes first;
  /// no check is performed here. For a value produced by None() this returns 0.
  uint64_t AsTotalBytes() const noexcept {
    return value.total_bytes;
  }
#endif
} OrtResourceCount;

#ifdef __cplusplus
// The struct is passed by value across the plugin boundary, so its size is part of the ABI.
// Fail the build if a new union member ever outgrows the 48-byte reserve.
static_assert(sizeof(OrtResourceCount) == 56,
              "OrtResourceCount layout changed: expected an 8-byte header plus a 48-byte union");
#endif

/**
* \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider.
*
Expand Down Expand Up @@ -2010,6 +2066,122 @@ struct OrtEpApi {
ORT_API2_STATUS(ProfilingEventsContainer_AddEvents, _In_ OrtProfilingEventsContainer* events_container,
_In_reads_(num_events) const OrtProfilingEvent* const* events,
_In_ size_t num_events);

/** \brief Query whether resource accounting is active for this GetCapability call.
*
* Returns true if a resource accountant is attached to the given OrtEpGraphSupportInfo instance,
* meaning the session was configured with resource-constrained partitioning settings.
*
* The other EpGraphSupportInfo_* resource functions are documented as valid only when this
* function reports true, so call it first.
*
* \note C++ callers can use Ort::ResourceBudget::HasBudget(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] has_budget Output parameter set to true if a resource budget is active.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ bool* has_budget);

/** \brief Get the total resource budget.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
* If the accountant has no explicit threshold (e.g. auto-detection mode),
* the returned OrtResourceCount will have kind == OrtResourceCountKind_TotalBytes with
* value.total_bytes set to UINT64_MAX.
*
* This is the total budget, not the remaining amount; combine it with
* EpGraphSupportInfo_GetConsumedResources to work out what is still available.
*
* \note C++ callers can use Ort::ResourceBudget::GetBudget(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] budget Output parameter set to the total resource budget.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ OrtResourceCount* budget);

/** \brief Get the amount of resources already consumed from prior partitioning passes or previously assigned nodes.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* Costs passed to EpGraphSupportInfo_ReportAcceptedNodeCost are committed to the consumed amount
* later, during node assignment by the graph partitioner, so they are not included in this value
* while GetCapability is still running.
*
* \note C++ callers can use Ort::ResourceBudget::GetConsumedResources(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] consumed Output parameter set to the consumed resource amount.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ OrtResourceCount* consumed);

/** \brief Compute the estimated resource cost for a node.
*
* Uses pre-recorded memory statistics if available, otherwise estimates from initializer sizes
* and static output shapes with a safety multiplier.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* Unlike the pure query functions, this one takes a non-const OrtEpGraphSupportInfo pointer.
*
* \note C++ callers can use Ort::ResourceBudget::ComputeNodeCost(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[in] node The OrtNode for which to compute the resource cost.
* \param[out] cost Output parameter set to the estimated resource cost.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Out_ OrtResourceCount* cost);

/** \brief Report that a node was accepted and its cost should be tracked.
*
* The cost is stored internally so the host can attach it to the IndexedSubGraph after
* GetCapability returns. This does NOT commit the cost to the accountant's consumed amount —
* that happens later during node assignment by the graph partitioner. A plugin that accepts
* several nodes in one GetCapability call therefore has to keep its own running total.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \note The cost is passed by value. C++ callers can use Ort::ResourceBudget::ReportAcceptedNodeCost(),
* which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[in] node The OrtNode whose cost is being reported.
* \param[in] cost The cost (as returned by EpGraphSupportInfo_ComputeNodeResourceCost).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _In_ OrtResourceCount cost);

/** \brief Query whether a previous GetCapability pass already signaled stop.
*
* Returns true if EpGraphSupportInfo_SignalStopAssignment was called in a prior pass
* (or by another mechanism). The plugin can use this to early-exit from GetCapability
* without re-evaluating nodes.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \note C++ callers can use Ort::ResourceBudget::IsStopIssued(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] is_stopped Output parameter set to true if stop was previously signaled.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ bool* is_stopped);

/** \brief Signal that the EP wants to stop accepting further nodes due to budget exhaustion.
*
* After this call, the accountant's stop flag is set. Subsequent GetCapability calls for this EP
* will see EpGraphSupportInfo_IsStopIssued() returning true and can return early.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \note C++ callers can use Ort::ResourceBudget::SignalStopAssignment(), which wraps this function.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info);
};

/**
Expand Down
15 changes: 11 additions & 4 deletions onnxruntime/core/framework/layering_annotations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ bool MatchEpDevice(const EpDeviceView& ep,
if (target_specifier.empty()) {
if (ep.device_type == OrtDevice::GPU) return true;
// Heuristic fallback for common GPU EPs if hardware info is missing
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kDmlExecutionProvider;
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider ||
ep.ep_name == kDmlExecutionProvider;
}
// "gpu:<vendor>" or "gpu:<index>"
if (ep.device_type == OrtDevice::GPU) {
Expand All @@ -203,7 +204,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
ep.vendor_id == OrtDevice::VendorIds::INTEL) return true;
// Heuristic: gpu:nvidia -> CUDA
if (CaseInsensitiveCompare(target_specifier, "nvidia") &&
ep.ep_name == kCudaExecutionProvider) return true;
(ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider)) return true;
}
return false;
}
Expand All @@ -225,7 +226,7 @@ bool MatchEpDevice(const EpDeviceView& ep,
}
// "cuda"
if (CaseInsensitiveCompare(target_type_str, "cuda")) {
return ep.ep_name == kCudaExecutionProvider;
return ep.ep_name == kCudaExecutionProvider || ep.ep_name == kCudaPluginExecutionProvider;
}
// "dml"
if (CaseInsensitiveCompare(target_type_str, "dml")) {
Expand Down Expand Up @@ -284,7 +285,13 @@ std::optional<std::string> EpLayeringMatcher::Match(gsl::span<const OrtEpDevice*
ep_device.ep_name,
device_type,
has_hw ? ep_device.device->vendor_id : 0u,
has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id) : OrtDevice::DeviceId{},
// Prefer the device ordinal from device_memory_info (set by the EP factory to
// a runtime device ordinal such as a CUDA ordinal) over the OrtHardwareDevice::device_id
// which is a hardware-type identifier and not guaranteed to be a stable runtime ordinal.
ep_device.device_memory_info
? ep_device.device_memory_info->device.Id()
: (has_hw ? static_cast<OrtDevice::DeviceId>(ep_device.device->device_id)
: OrtDevice::DeviceId{}),
has_hw ? std::string_view(ep_device.device->vendor) : std::string_view{}};

if (MatchEpDevice(view, target_type_str, target_specifier, rule.device)) {
Expand Down
Loading
Loading