Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e1735d4
update
chilo-ms Dec 3, 2025
3a3e63d
fix compile error
chilo-ms Dec 3, 2025
982d6dc
lintrunner -a
chilo-ms Dec 3, 2025
72fff82
add back test Check_Graph_GetSubgraph()
chilo-ms Dec 3, 2025
eb8c938
Add implementation for GetConsumerNodes() for MINIMAL_BUILD
chilo-ms Dec 3, 2025
ade76c5
Add check node in EpGraph when getting GetProducerInfo/GetConsumerInfo
chilo-ms Dec 8, 2025
059c918
Log warning if the node is outside of the subgraph
chilo-ms Dec 9, 2025
0eafb99
Make EpGraph create parent node EpNode if the graph is a subgraph of …
chilo-ms Dec 16, 2025
d94ba2a
add macro for non minimal build
chilo-ms Dec 18, 2025
3bdca16
Handle outer scope initializers for the subgraph
chilo-ms Dec 19, 2025
0f96d23
update ort_graph_to_proto.h
chilo-ms Jan 5, 2026
cfadecd
Merge branch 'main' into chi/update_graph_view_api
chilo-ms Jan 5, 2026
373d828
update ort_graph_to_proto.h
chilo-ms Jan 5, 2026
2e81538
use graph.GetInitializer() instead of graph.GetConstantInitializer()
chilo-ms Jan 5, 2026
d245852
update ort_graph_to_proto.h to include missing initializers
chilo-ms Jan 13, 2026
b863464
Use the unified implementation for node arg to consumer nodes across …
chilo-ms Jan 14, 2026
492efa3
address reviewer's comment
chilo-ms Jan 14, 2026
53f1553
add comments to functions
chilo-ms Jan 14, 2026
7fbaa70
address reveiwer's comments
chilo-ms Jan 15, 2026
428b2e9
address reviewr's comments
chilo-ms Jan 16, 2026
e9c1f8e
address reviewer's comment
chilo-ms Jan 19, 2026
e833a4a
Revert the code in GraphViewer so that it stays the old behavior that…
chilo-ms Jan 19, 2026
78774d1
address reviewer's comments
chilo-ms Jan 19, 2026
ef35e91
address reviewer's comments
chilo-ms Jan 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions include/onnxruntime/core/providers/utils/ort_graph_to_proto.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
#define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_

#include <functional>
#include <optional>
#include "core/session/onnxruntime_cxx_api.h"
#include "onnx/onnx_pb.h"

Expand Down Expand Up @@ -317,9 +318,11 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph,

// Don't add graph inputs or graph outputs to GraphProto's list of value_infos.
// Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors.
// For values defined in an outer scope, just add the value info but not the initializer.
if (is_from_outer_scope) {
value_infos.emplace(value_name, ort_value_info);
if (is_constant_initializer) {
initializer_value_infos.emplace(value_name, ort_value_info);
}
} else if (is_optional_graph_input) {
initializer_value_infos.emplace(value_name, ort_value_info);
} else if (is_constant_initializer) {
Expand Down Expand Up @@ -413,6 +416,16 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph,
ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto));
}

// There may be initializers in the original OrtGraph that have not been added yet.
// For example, an initializer may not be used by any node but is still a graph output.
// Iterating through all nodes to collect initializer value info is therefore not sufficient,
// initializers must also be obtained from ort_graph.GetInitializers().
// Add those missing initializers and skip the ones that already in `initializer_value_infos`
std::vector<Ort::ConstValueInfo> ort_graph_initializers = ort_graph.GetInitializers();
for (const auto& initializer : ort_graph_initializers) {
initializer_value_infos.emplace(initializer.GetName(), initializer);
}

// Add initializers to GraphProto as TensorProto objects.
for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) {
std::vector<int64_t> initializer_dims;
Expand Down Expand Up @@ -490,10 +503,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph,
onnx::ModelProto& model_proto,
HandleInitializerDataFunc handle_initializer_data_func) {
try {
// Check that OrtGraph is a top-level graph (no parent node).
Ort::ConstGraph ort_graph{&graph};
Ort::ConstNode parent_node = ort_graph.GetParentNode();
ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto");

// Set model description.
model_proto.set_doc_string("Serialized from OrtGraph");
Expand Down
6 changes: 5 additions & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5917,7 +5917,8 @@ struct OrtApi {
/** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph.
*
* \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference
* the same underlying graph.
* the same underlying graph. "dst_graph" preserves the input order of "src_graph", and
* its output order corresponds to the outputs produced by the nodes in "nodes" with the given order.
*
* \param[in] src_graph The source OrtGraph instance.
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
Expand Down Expand Up @@ -6206,6 +6207,9 @@ struct OrtApi {
/** \brief Get the node's parent OrtGraph instance.
*
* Can return NULL if the OrtNode was created without an owning graph.
* In another case, this API may also return NULL if `node` is obtained by calling Graph_GetParentNode()
* on an OrtGraph that is a subgraph of a control-flow op, and the parent graph has not been created yet,
* for example during ORT's GetCapability() when processing the innermost subgraph.
*
* \param[in] node The OrtNode instance.
* \param[out] graph Output parameter set to the node's OrtGraph. Can be set to NULL
Expand Down
49 changes: 43 additions & 6 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ Status EpValueInfo::GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) c
producer_info.output_index = 0;

if (graph_ == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to get producer node for OrtValueInfo '", name_,
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_FOUND, "Unable to get producer node for OrtValueInfo '", name_,
"' that is not owned by a OrtGraph.");
}

Expand All @@ -379,7 +379,15 @@ Status EpValueInfo::GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) c

const EpNode* ep_node = graph_->GetNode(node->Index());
if (ep_node == nullptr) {
return Status::OK(); // Node is not in this GraphViewer
producer_info.node = nullptr;
producer_info.output_index = 0;
#if !defined(ORT_MINIMAL_BUILD)
const auto& logger = graph_->GetGraphViewer().GetGraph().GetLogger();
LOGS(logger, WARNING) << "Unable to get producer node for OrtValueInfo '"
<< name_
<< "' that is not owned by an OrtGraph.";
#endif // !defined(ORT_MINIMAL_BUILD)
return Status::OK();
}

size_t output_index = 0;
Expand Down Expand Up @@ -543,6 +551,9 @@ void EpGraph::IndexToEpNodeMap::Resize(NodeIndex min_node_index, NodeIndex max_n
}

EpNode* EpGraph::IndexToEpNodeMap::GetEpNode(NodeIndex node_index) const {
if (node_index < min_node_index_ || node_index > (min_node_index_ + nodes_.size() - 1)) {
return nullptr;
}
size_t i = node_index - min_node_index_;
assert(i < nodes_.size());
return nodes_[i];
Expand All @@ -566,10 +577,10 @@ EpGraph::EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
return CreateImpl(std::move(ep_graph), graph_viewer, result, create_parent_node);
}

// Static class function to create a std::unique_ptr<EpGraph>.
Expand All @@ -584,7 +595,8 @@ Status EpGraph::Create(std::unique_ptr<GraphViewer> src_graph_viewer,
return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node) {
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;

Expand Down Expand Up @@ -687,13 +699,30 @@ Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer&
}
}

std::unique_ptr<EpNode> ep_parent_node = nullptr;
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> parent_node_value_infos_map;

// If this is a subgraph, add the OrtValueInfo and OrtValue objects that come from the outer scope.
// Wait until we have already processed OrtValueInfos consumed and produced by nodes so that we only add
// outer OrtValueInfo/OrtValue if they are actually used by the nodes in this GraphViewer.
if (graph_viewer.IsSubgraph()) {
gsl::not_null<const Graph*> parent_graph = graph_viewer.GetGraph().ParentGraph();
gsl::not_null<const Node*> parent_node = graph_viewer.ParentNode();

// If the subgraph of a control-flow op is created before its parent node (for example, when constructing
// the graph during ORT's GetCapability() in a bottom-up manner), the parent node must also be created.
if (create_parent_node) {
std::unique_ptr<EpNode> ep_node = nullptr;

// At this point, the EpGraph that contains the parent node hasn't been created yet.
// It's not needed to create that EpGraph here, so just pass nullptr.
ORT_RETURN_IF_ERROR(EpNode::Create(*parent_node, /*ep_graph*/ nullptr, parent_node_value_infos_map, ep_node));

// Note: Calling ep_parent_node.GetGraph() will return nullptr because
// ep_parent_node was created without an associated EpGraph pointer.
ep_parent_node = std::move(ep_node);
}

for (gsl::not_null<const NodeArg*> implicit_node_arg : parent_node->ImplicitInputDefs()) {
const std::string& implicit_name = implicit_node_arg->Name();
auto value_info_iter = value_infos_map.find(implicit_name);
Expand Down Expand Up @@ -741,6 +770,9 @@ Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer&
ep_graph->outer_scope_initializer_values_ = std::move(outer_scope_initializer_values);
ep_graph->inputs_ = std::move(graph_input_value_infos);
ep_graph->outputs_ = std::move(graph_output_value_infos);
ep_graph->parent_node_owned_ = std::move(ep_parent_node);
ep_graph->parent_node_ = ep_graph->parent_node_owned_ ? ep_graph->parent_node_owned_.get() : nullptr;
ep_graph->parent_node_value_infos_map_ = std::move(parent_node_value_infos_map);

result = std::move(ep_graph);

Expand Down Expand Up @@ -873,10 +905,15 @@ Status EpGraph::GetNodes(gsl::span<const OrtNode*> dst) const {

Status EpGraph::GetParentNode(const OrtNode*& result) const {
result = parent_node_ != nullptr ? parent_node_->ToExternal() : nullptr;

return Status::OK();
}

void EpGraph::SetParentNode(const EpNode* node) { parent_node_ = node; }
void EpGraph::SetParentNode(const EpNode* node) {
parent_node_ = node;
parent_node_owned_ = nullptr;
parent_node_value_infos_map_.clear();
}

const GraphViewer& EpGraph::GetGraphViewer() const { return graph_viewer_; }

Expand Down
27 changes: 24 additions & 3 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ struct EpNode : public OrtNode {
const char** opt_attribute_names) const override;

// Gets this node's parent graph, which is the graph that directly contains this node.
// Note: This call may return NULL if this node is obtained by calling GetParentNode()
// on an EpGraph that is a subgraph of a control-flow op, and the parent graph has not been created yet,
// for example during ORT's GetCapability() when processing the innermost subgraph.
Status GetGraph(const OrtGraph*& parent_graph) const override;

//
Expand Down Expand Up @@ -269,8 +272,16 @@ struct EpGraph : public OrtGraph {
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <param name="create_parent_node">If the `graph_viewer` is a subgraph of a control flow op,
/// e.g. Loop/If/Scan op, and `create_parent_node` is set to true,
/// then `result` EpGraph will create and own parent node's EpNode
/// instance. It's mainly used in EP's GetCapability() as it's
/// a bottom-up approach where inner-most subgraph will be constructed
/// first and by the time its parent node/graph hasn't be constructed yet.</param>
/// <returns></returns>
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
static Status Create(const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result,
bool create_parent_node = false);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
Expand Down Expand Up @@ -364,17 +375,27 @@ struct EpGraph : public OrtGraph {
private:
/// <summary>
/// The real implementation of creating an EpGraph instance.
/// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
/// Please use one of the above 'Create' functions that internally call this function,
/// and avoid calling this function directly.
/// </summary>
/// <param name="ep_graph"></param>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <param name="create_parent_node"></param>
/// <returns></returns>
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer,
/*out*/ std::unique_ptr<EpGraph>& result, bool create_parent_node = false);

const GraphViewer& graph_viewer_;

// Hold the parent node created and owned by this graph
std::unique_ptr<EpNode> parent_node_owned_ = nullptr;

// Holds either a pointer to a parent node not owned by this graph, a pointer to parent_node_owned_, or nullptr.
const EpNode* parent_node_ = nullptr;

std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> parent_node_value_infos_map_;

std::unique_ptr<GraphViewer> owned_graph_viewer_ = nullptr;
std::unique_ptr<IndexedSubGraph> owned_indexed_sub_graph_ = nullptr;

Expand Down
Loading
Loading