Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -3725,5 +3725,58 @@ using UnownedSharedPrePackedWeightCache =

///< Wraps OrtEpApi::GetEnvConfigEntries()
Ort::KeyValuePairs GetEnvConfigEntries();

/// \brief Non-owning C++ wrapper for resource budget queries on OrtEpGraphSupportInfo.
///
/// Constructed from the OrtEpGraphSupportInfo* passed to OrtEp::GetCapability.
/// Provides convenient methods for resource-constrained node selection.
/// All costs and budgets use OrtResourceCount, the ABI-stable tagged union.
///
/// Example use in a plugin EP's GetCapability implementation:
/// \code
/// OrtStatus* GetCapabilityImpl(OrtEp*, const OrtGraph* graph,
/// OrtEpGraphSupportInfo* info) noexcept {
/// Ort::ResourceBudget budget(info);
/// if (budget.HasBudget()) {
/// OrtResourceCount remaining = budget.GetBudget();
/// OrtResourceCount consumed = budget.GetConsumedResources();
/// for (const OrtNode* node : candidates) {
/// OrtResourceCount cost = budget.ComputeNodeCost(Ort::ConstNode{node});
/// if (consumed.AsTotalBytes() + cost.AsTotalBytes() > remaining.AsTotalBytes()) {
/// budget.SignalStopAssignment();
/// break;
/// }
/// budget.ReportAcceptedNodeCost(Ort::ConstNode{node}, cost);
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
/// }
/// }
/// }
/// \endcode
struct ResourceBudget {
explicit ResourceBudget(OrtEpGraphSupportInfo* info) : info_(info) {}

/// Returns true if a resource budget is configured for this EP.
bool HasBudget() const;

/// Returns the total resource budget. Only valid if HasBudget() is true.
OrtResourceCount GetBudget() const;

/// Returns the amount of resources already consumed.
OrtResourceCount GetConsumedResources() const;

/// Computes the estimated resource cost of the given node.
OrtResourceCount ComputeNodeCost(ConstNode node) const;

/// Reports that the plugin accepted a node at the given cost.
void ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost);

/// Returns true if stop has been signaled (by this or another EP).
bool IsStopIssued() const;

/// Signals that this EP wants to stop receiving nodes.
void SignalStopAssignment();

private:
OrtEpGraphSupportInfo* info_;
};
} // namespace Ort
#include "onnxruntime_cxx_inline.h"
39 changes: 39 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -4169,4 +4169,43 @@ inline OpSchema GetOpSchema(const char* name, int max_inclusive_version, const c
ThrowOnError(GetEpApi().GetOpSchema(name, max_inclusive_version, domain, &schema));
return OpSchema{schema};
}

// ResourceBudget implementation
inline bool ResourceBudget::HasBudget() const {
bool has_budget = false;
ThrowOnError(GetEpApi().EpGraphSupportInfo_HasResourceBudget(info_, &has_budget));
return has_budget;
}

inline OrtResourceCount ResourceBudget::GetBudget() const {
OrtResourceCount budget = OrtResourceCount::None();
ThrowOnError(GetEpApi().EpGraphSupportInfo_GetResourceBudget(info_, &budget));
return budget;
}

inline OrtResourceCount ResourceBudget::GetConsumedResources() const {
OrtResourceCount consumed = OrtResourceCount::None();
ThrowOnError(GetEpApi().EpGraphSupportInfo_GetConsumedResources(info_, &consumed));
return consumed;
}

inline OrtResourceCount ResourceBudget::ComputeNodeCost(ConstNode node) const {
OrtResourceCount cost = OrtResourceCount::None();
ThrowOnError(GetEpApi().EpGraphSupportInfo_ComputeNodeResourceCost(info_, node, &cost));
return cost;
}

inline void ResourceBudget::ReportAcceptedNodeCost(ConstNode node, OrtResourceCount cost) {
ThrowOnError(GetEpApi().EpGraphSupportInfo_ReportAcceptedNodeCost(info_, node, cost));
}

inline bool ResourceBudget::IsStopIssued() const {
bool stop = false;
ThrowOnError(GetEpApi().EpGraphSupportInfo_IsStopIssued(info_, &stop));
return stop;
}

inline void ResourceBudget::SignalStopAssignment() {
ThrowOnError(GetEpApi().EpGraphSupportInfo_SignalStopAssignment(info_));
}
} // namespace Ort
172 changes: 172 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -946,6 +946,62 @@ struct OrtScanKernelHelper {
_In_ const OrtValue* input, _In_opt_ OrtSyncStream* stream, _Inout_ OrtValue* output);
};

/**
* \brief Discriminator for the resource count type stored in an OrtResourceCount.
*
* New resource accounting types can be added by appending new enum values.
* The OrtResourceCount union storage is large enough to hold all current and future types.
*
* \since Version 1.26.
*/
typedef enum OrtResourceCountKind {
OrtResourceCountKind_None = 0, ///< Unset / zero-cost sentinel.
OrtResourceCountKind_TotalBytes = 1, ///< Single size_t: total estimated bytes.
} OrtResourceCountKind;

/**
* \brief ABI-stable tagged union representing a resource cost or budget.
*
* This struct is a C-safe variant that can be passed by value across the plugin DLL boundary.
* The `kind` field selects which union member is active. The `_storage` member reserves space
* for future resource types without changing the struct layout.
*
* Adding new resource types requires only: (a) a new OrtResourceCountKind enum value,
* (b) a new union member. No new C API functions are needed.
*
* \since Version 1.26.
*/
typedef struct OrtResourceCount {
OrtResourceCountKind kind;
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
uint32_t reserved_; ///< Alignment padding + future flags. Must be zero.

union {
size_t total_bytes; ///< Active when kind == OrtResourceCountKind_TotalBytes.
uint8_t _storage[48]; ///< ABI reserve: all types must fit within 48 bytes.
} value;

#ifdef __cplusplus
/// Create a zero-cost (None) resource count.
static OrtResourceCount None() {
OrtResourceCount rc{};
return rc;
}

/// Create a TotalBytes resource count.
static OrtResourceCount FromTotalBytes(size_t bytes) {
OrtResourceCount rc{};
rc.kind = OrtResourceCountKind_TotalBytes;
rc.value.total_bytes = bytes;
return rc;
}

/// Extract total_bytes. Caller must check kind == OrtResourceCountKind_TotalBytes first.
size_t AsTotalBytes() const {
return value.total_bytes;
}
#endif
} OrtResourceCount;

/**
* \brief The OrtEpApi struct provides functions that are relevant to the implementation of an execution provider.
*
Expand Down Expand Up @@ -2010,6 +2066,122 @@ struct OrtEpApi {
ORT_API2_STATUS(ProfilingEventsContainer_AddEvents, _In_ OrtProfilingEventsContainer* events_container,
_In_reads_(num_events) const OrtProfilingEvent* const* events,
_In_ size_t num_events);

/** \brief Query whether resource accounting is active for this GetCapability call.
*
* Returns true if a resource accountant is attached to the given OrtEpGraphSupportInfo instance,
* meaning the session was configured with resource-constrained partitioning settings.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] has_budget Output parameter set to true if a resource budget is active.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_HasResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ bool* has_budget);

/** \brief Get the total resource budget.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
* If the accountant has no explicit threshold (e.g. auto-detection mode),
* the returned OrtResourceCount will have kind == OrtResourceCountKind_TotalBytes with
* value.total_bytes set to SIZE_MAX.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] budget Output parameter set to the total resource budget.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_GetResourceBudget, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ OrtResourceCount* budget);

/** \brief Get the amount of resources already consumed from prior partitioning passes or previously assigned nodes.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] consumed Output parameter set to the consumed resource amount.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_GetConsumedResources, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ OrtResourceCount* consumed);

/** \brief Compute the estimated resource cost for a node.
*
* Uses pre-recorded memory statistics if available, otherwise estimates from initializer sizes
* and static output shapes with a safety multiplier.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[in] node The OrtNode for which to compute the resource cost.
* \param[out] cost Output parameter set to the estimated resource cost.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_ComputeNodeResourceCost, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _Out_ OrtResourceCount* cost);

/** \brief Report that a node was accepted and its cost should be tracked.
*
* The cost is stored internally so the host can attach it to the IndexedSubGraph after
* GetCapability returns. This does NOT commit the cost to the accountant's consumed amount —
* that happens later during node assignment by the graph partitioner.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[in] node The OrtNode whose cost is being reported.
* \param[in] cost The cost (as returned by EpGraphSupportInfo_ComputeNodeResourceCost).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_ReportAcceptedNodeCost, _In_ OrtEpGraphSupportInfo* graph_support_info,
_In_ const OrtNode* node, _In_ OrtResourceCount cost);

/** \brief Query whether a previous GetCapability pass already signaled stop.
*
* Returns true if EpGraphSupportInfo_SignalStopAssignment was called in a prior pass
* (or by another mechanism). The plugin can use this to early-exit from GetCapability
* without re-evaluating nodes.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
* \param[out] is_stopped Output parameter set to true if stop was previously signaled.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_IsStopIssued, _In_ const OrtEpGraphSupportInfo* graph_support_info,
_Out_ bool* is_stopped);

/** \brief Signal that the EP wants to stop accepting further nodes due to budget exhaustion.
*
* After this call, the accountant's stop flag is set. Subsequent GetCapability calls for this EP
* will see EpGraphSupportInfo_IsStopIssued() returning true and can return early.
*
* Only valid if EpGraphSupportInfo_HasResourceBudget returns true.
*
* \param[in] graph_support_info The OrtEpGraphSupportInfo instance from OrtEp::GetCapability().
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.26.
*/
ORT_API2_STATUS(EpGraphSupportInfo_SignalStopAssignment, _In_ OrtEpGraphSupportInfo* graph_support_info);
};

/**
Expand Down
71 changes: 66 additions & 5 deletions onnxruntime/core/providers/cuda/plugin/cuda_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
#include "ep/get_capability_utils.h"

#include <cstring>
#include <limits>

Check warning on line 13 in onnxruntime/core/providers/cuda/plugin/cuda_ep.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Found C++ system header after other header. Should be: cuda_ep.h, c system, c++ system, other. [build/include_order] [4] Raw Output: onnxruntime/core/providers/cuda/plugin/cuda_ep.cc:13: Found C++ system header after other header. Should be: cuda_ep.h, c system, c++ system, other. [build/include_order] [4]
#include <string>
#include <string_view>
#include <unordered_set>

#include "core/common/safeint.h"

namespace onnxruntime {
namespace cuda_plugin {

Expand Down Expand Up @@ -98,6 +101,18 @@
auto* ep = static_cast<CudaEp*>(this_ptr);
const OrtEpApi& ep_api = ep->factory_.GetEpApi();

// Early exit if a previous GetCapability pass already signaled stop.
// This mirrors the in-tree CUDA EP's check at the top of GetCapability().
Ort::ResourceBudget resource_budget(graph_support_info);
bool has_budget = resource_budget.HasBudget();
if (has_budget && resource_budget.IsStopIssued()) {
Ort::Status log_status(Ort::GetApi().Logger_LogMessage(
&ep->logger_, ORT_LOGGING_LEVEL_WARNING,
"CUDA Plugin EP returning due to Stop Set",
ORT_FILE, __LINE__, __FUNCTION__));
return nullptr;
Comment thread
yuslepukhin marked this conversation as resolved.
Outdated
}

Ort::ConstGraph graph{ort_graph};
std::vector<Ort::ConstNode> all_nodes = graph.GetNodes();

Expand Down Expand Up @@ -144,13 +159,59 @@
gsl::span<const OrtNode* const>(tentative_nodes.data(), tentative_nodes.size()),
cpu_preferred_nodes));

// Phase 3: Add final supported nodes (tentative minus CPU-preferred).
// Phase 3: Add final supported nodes (tentative minus CPU-preferred),
// respecting the optional resource budget.
// resource_budget and has_budget were computed at the top of this function.
size_t budget_bytes = std::numeric_limits<size_t>::max();
size_t consumed_bytes = 0;
if (has_budget) {
budget_bytes = resource_budget.GetBudget().AsTotalBytes();
consumed_bytes = resource_budget.GetConsumedResources().AsTotalBytes();
}

for (const OrtNode* ort_node : candidate_nodes) {
if (cpu_preferred_nodes.count(ort_node) == 0) {
Ort::ConstNode node{ort_node};
RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
graph_support_info, node));
if (cpu_preferred_nodes.count(ort_node) != 0) {
continue;
}

// Previously assigned nodes (ep_name matched) are already accounted for.
Ort::ConstNode node{ort_node};
bool previously_assigned = !node.GetEpName().empty();

if (has_budget && !previously_assigned) {
OrtResourceCount cost = resource_budget.ComputeNodeCost(node);
size_t cost_bytes = cost.AsTotalBytes();
size_t would_be_consumed = SafeInt<size_t>(consumed_bytes) + cost_bytes;

{
// Log per-node cost information (mirrors in-tree CUDA EP logging)
std::string msg = "CUDA Plugin EP Node: " + node.GetName() +
" Memory usage: " + std::to_string(cost_bytes) +
" would be consumed: " + std::to_string(would_be_consumed) +
" threshold: " + std::to_string(budget_bytes);
Ort::Status log_status(Ort::GetApi().Logger_LogMessage(
&ep->logger_, ORT_LOGGING_LEVEL_INFO,
msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__));
}

if (would_be_consumed > budget_bytes) {
{
std::string msg = "CUDA Plugin EP Halting assignment due to capacity threshold at node: " +
node.GetName();
Ort::Status log_status(Ort::GetApi().Logger_LogMessage(
&ep->logger_, ORT_LOGGING_LEVEL_WARNING,
msg.c_str(), ORT_FILE, __LINE__, __FUNCTION__));
}
resource_budget.SignalStopAssignment();
break; // topological-order halt
}

consumed_bytes = would_be_consumed;
resource_budget.ReportAcceptedNodeCost(node, cost);
}

RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_AddSingleNode(
graph_support_info, node));
}

return nullptr;
Expand Down
Loading
Loading