diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 54e03a31fceef..c18a42cc1bbc1 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi return const_cast(this)->GetNodeArg(name); } - // search this and up through any parent_graph_ instance for a NodeArg + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name); + // Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg + const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const; + /** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found. @param name The NodeArg name. @param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created. diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bf1dd6e20ce64..051a3f7283cbe 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -5748,6 +5748,24 @@ struct OrtApi { */ ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); + /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. + * + * Note: + * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * the same underlying graph. + * + * \param[in] src_graph The source OrtGraph instance. + * \param[in] nodes A subset of the nodes/OrtNodes in 'src_graph'. + * \param[in] num_nodes Number of nodes. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. 
Must be released by calling ReleaseGraph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes, + _In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph); + /// @} /// \name OrtNode diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 8583fac30cfbf..7f81ab3433911 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node) EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag) : OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {} +EpGraph::EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag) + : OrtGraph(OrtGraphIrApi::kEpApi), + graph_viewer_(*graph_viewer.get()), + owned_graph_viewer_(std::move(graph_viewer)), + owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} + // Static class function to create a std::unique_ptr. Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +// Static class function to create a std::unique_ptr. 
+Status EpGraph::Create(std::unique_ptr src_graph_viewer, + std::unique_ptr src_indexed_sub_graph, + /*out*/ std::unique_ptr& result) { + auto& graph_viewer = *src_graph_viewer.get(); + auto ep_graph = std::make_unique(std::move(src_graph_viewer), + std::move(src_indexed_sub_graph), + PrivateTag{}); + + return CreateImpl(std::move(ep_graph), graph_viewer, result); +} + +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 12fa082d3f354..7b67f21bf4eb4 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph { public: EpGraph(const GraphViewer& graph_viewer, PrivateTag); + EpGraph(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + PrivateTag); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not owned by this EpGraph. /// /// /// /// static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + /// + /// Creates an instance of EpGraph, which wraps a GraphViewer. + /// This call is used when creating an EpGraph from a subset of nodes in another EpGraph. + /// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance + /// must take ownership of both the GraphViewer and IndexedSubGraph. + /// + /// + /// + /// + static Status Create(std::unique_ptr graph_viewer, + std::unique_ptr indexed_sub_graph, + /*out*/ std::unique_ptr& result); + // Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph. 
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi) @@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph { const OrtValue* GetInitializerValue(std::string_view name) const; private: + /// + /// The real implementation of creating an EpGraph instance. + /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// + /// + /// + /// + /// + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + const GraphViewer& graph_viewer_; const EpNode* parent_node_ = nullptr; + std::unique_ptr owned_graph_viewer_ = nullptr; + std::unique_ptr owned_indexed_sub_graph_ = nullptr; + std::vector> nodes_; IndexToEpNodeMap index_to_ep_node_; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index ca40bad2b4250..4d3091520d876 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name return node_arg; } +const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const { + return const_cast(this)->GetNodeArgIncludingParentGraphs(node_arg_name); +} + void Graph::ReverseDFSFrom(gsl::span from, const std::function& enter, const std::function& leave, diff --git a/onnxruntime/core/graph/graph_viewer.cc b/onnxruntime/core/graph/graph_viewer.cc index 1842c2b4a0d1f..948ebaa5f7e15 100644 --- a/onnxruntime/core/graph/graph_viewer.cc +++ b/onnxruntime/core/graph/graph_viewer.cc @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size()); for (const auto& input : metadef->inputs) { - const auto* nodearg = graph.GetNodeArg(input); + // NodeArgs from the current scope or any outer scopes should be handled correctly. 
+ // + // There is an edge case where the model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). + // When constructing a new GraphViewer for the second- and third-layer subgraphs, + // the second-layer graph may not have the corresponding value_info for that first-layer input, + // because the second-layer graph itself doesn't consume it. + // Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info. + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input); filtered_node_inputs_including_initializers_.push_back(nodearg); if (!graph.IsInitializedTensor(input)) { @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info) } for (const auto& output : metadef->outputs) { - const auto* nodearg = graph.GetNodeArg(output); + const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output); ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. 
Output not found:", output); filtered_node_outputs_.push_back(nodearg); } diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 18b545483b38b..312ddd7e52e00 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2714,6 +2714,91 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetParentNode, _In_ const OrtGraph* graph, _O API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, + _In_ const OrtNode** nodes, + _In_ size_t num_nodes, + _Outptr_ OrtGraph** dst_graph) { + API_IMPL_BEGIN + + if (num_nodes == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0"); + } + + const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); + if (ep_graph == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetGraphView."); + } + const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + + // Create a GraphViewer with filtered info + std::unique_ptr indexed_sub_graph = std::make_unique(); + std::unique_ptr metadef = std::make_unique(); + metadef->name = "sub_graph"; + metadef->since_version = 1; + std::unordered_set outputs; + std::unordered_set initializers; + + auto add_inputs = [&](ConstPointerContainer> defs) { + for (const auto* def : defs) { + if (def->Exists()) { + // not the output of a previous node + if (outputs.count(def->Name()) == 0) { + metadef->inputs.push_back(def->Name()); + } else { + // consumed by node so no longer subgraph output + // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input + outputs.erase(def->Name()); + } + + if (graph.IsInitializedTensor(def->Name())) { + initializers.insert(def); + } + } + } + }; + + auto add_node = [&](const Node& node) { + indexed_sub_graph->nodes.push_back(node.Index()); + add_inputs(node.InputDefs()); + 
add_inputs(node.ImplicitInputDefs()); + + for (const auto* def : node.OutputDefs()) { + outputs.insert(def->Name()); + } + }; + + // Add nodes + for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { + const OrtNode* ort_node = nodes[node_idx]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node == nullptr) { + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetGraphView."); + } + add_node(ep_node->GetInternalNode()); + } + + // Add initializers + for (auto& initializer : initializers) { + metadef->constant_initializers.push_back(initializer->Name()); + } + + // Add outputs + for (auto& output : outputs) { + metadef->outputs.push_back(output); + } + + indexed_sub_graph->SetMetaDef(std::move(metadef)); + auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + + std::unique_ptr result; + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + + *dst_graph = result.release(); + + return nullptr; + API_IMPL_END +} + // // OrtNode // @@ -3629,6 +3714,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Graph_GetNumNodes, &OrtApis::Graph_GetNodes, &OrtApis::Graph_GetParentNode, + &OrtApis::Graph_GetGraphView, &OrtApis::Node_GetId, &OrtApis::Node_GetName, &OrtApis::Node_GetOperatorType, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 75db44cb9e9ff..b53863c02cfef 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t* ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph, _Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes); ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node); +ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const 
OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes, + _Outptr_ OrtGraph** subgraph); // OrtNode ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id); diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index e9bed3ac45529..17e829e37f729 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -7,12 +7,15 @@ #include #include #include +#include #include "core/common/common.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/session/onnxruntime_cxx_api.h" +#include "core/graph/ep_api_types.h" +#include "core/graph/graph_proto_serializer.h" #define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL #include "core/providers/utils/ort_graph_to_proto.h" @@ -31,6 +34,7 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); +static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); // // Tests @@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { + // The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test. + // The model consists of a graph with subgraphs nested across three levels. + // In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer). 
+ auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size())); + + // Select a half of nodes to create a OrtGraph + size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1); + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + OrtGraph* sub_graph; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph)); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. 
+ const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + const char* graph_name = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name)); + std::string name = graph_name; + name += "_half.onnx"; + + // Dump the graph for debugging + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + ort_api.ReleaseGraph(sub_graph); +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } } } + + // Check creating an OrtGraph from a subset of nodes in an OrtGraph + Check_Graph_GetSubgraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx new file mode 100644 index 0000000000000..d036541a70aa0 Binary files /dev/null and b/onnxruntime/test/testdata/three_layer_nested_subgraph_v2.onnx differ