Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
661a530
init
chilo-ms Jun 26, 2025
2784498
update comments
chilo-ms Jun 27, 2025
e1eca15
address lintrunner issue
chilo-ms Jun 27, 2025
4db3002
update comment to better review
chilo-ms Jun 27, 2025
3370de7
clean up and fix a compile warning
chilo-ms Jun 27, 2025
3077677
update test
chilo-ms Jun 27, 2025
256d055
merge main
chilo-ms Jul 5, 2025
e039ac9
refactor the code and address reviewers' comments
chilo-ms Jul 5, 2025
010f51f
update API comment
chilo-ms Jul 5, 2025
2439718
address reviewer's comments
chilo-ms Jul 5, 2025
9232c85
fix to change the function name
chilo-ms Jul 5, 2025
f686ba8
add an option to construct the sub-graph as a standalone OrtGraph.
chilo-ms Jul 6, 2025
86d4779
address reviewer comments
chilo-ms Jul 7, 2025
0589766
comment out the debug code
chilo-ms Jul 7, 2025
6e4dbee
address lintrunner issue
chilo-ms Jul 7, 2025
5246851
Add ORT_UNUSED_PARAMETER to address the build issue in minimal build
chilo-ms Jul 7, 2025
211e305
address reviewer comment
chilo-ms Jul 7, 2025
d5ec60a
fix bug
chilo-ms Jul 7, 2025
ecbeffb
remove the option to create a standalone OrtGraph
chilo-ms Jul 8, 2025
004de71
update comment
chilo-ms Jul 8, 2025
46c5dca
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 8, 2025
517cf02
Add another edge case test for nother 3-layer nested graph
chilo-ms Jul 9, 2025
f58b4d5
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 9, 2025
c15f43d
remove file that accidentally uploaded
chilo-ms Jul 9, 2025
7896ea8
revert back that in unit test to use half of the nodes to create OrtG…
chilo-ms Jul 9, 2025
57f851e
address reviewer comment
chilo-ms Jul 9, 2025
2fa60e2
Merge branch 'main' into chi/add_graph_getsubgraph
chilo-ms Jul 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -952,9 +952,12 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
return const_cast<Graph*>(this)->GetNodeArg(name);
}

// search this and up through any parent_graph_ instance for a NodeArg
// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding mutable NodeArg
NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name);

// Searches for a NodeArg in the current graph and its parent graphs, and returns the corresponding const NodeArg
const NodeArg* GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const;

/** Gets a mutable NodeArg by name. Creates a new NodeArg that is owned by this Graph if not found.
@param name The NodeArg name.
@param[in] p_arg_type Optional TypeProto to use if the NodeArg needs to be created.
Expand Down
18 changes: 18 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -5748,6 +5748,24 @@ struct OrtApi {
*/
ORT_API2_STATUS(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);

/** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph.
*
* Note:
* The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference
* the same underlying graph.
*
* \param[in] src_graph The source OrtGraph instance.
* \param[in] nodes A subset of the nodes/OrtNodes in 'graph'.
* \param[in] num_nodes Number of nodes.
* \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23.
*/
ORT_API2_STATUS(Graph_GetGraphView, _In_ const OrtGraph* src_graph, _In_ const OrtNode** nodes,
_In_ size_t num_nodes, _Outptr_ OrtGraph** dst_graph);

/// @}

/// \name OrtNode
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/graph/ep_api_types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -505,10 +505,34 @@ void EpGraph::IndexToEpNodeMap::SetEpNode(NodeIndex node_index, EpNode* ep_node)
EpGraph::EpGraph(const GraphViewer& graph_viewer, PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi), graph_viewer_(graph_viewer) {}

EpGraph::EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
PrivateTag)
: OrtGraph(OrtGraphIrApi::kEpApi),
graph_viewer_(*graph_viewer.get()),
owned_graph_viewer_(std::move(graph_viewer)),
owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
auto ep_graph = std::make_unique<EpGraph>(graph_viewer, PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

// Static class function to create a std::unique_ptr<EpGraph>.
Status EpGraph::Create(std::unique_ptr<GraphViewer> src_graph_viewer,
std::unique_ptr<IndexedSubGraph> src_indexed_sub_graph,
/*out*/ std::unique_ptr<EpGraph>& result) {
auto& graph_viewer = *src_graph_viewer.get();
auto ep_graph = std::make_unique<EpGraph>(std::move(src_graph_viewer),
std::move(src_indexed_sub_graph),
PrivateTag{});

return CreateImpl(std::move(ep_graph), graph_viewer, result);
}

Status EpGraph::CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result) {
AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance();
std::unordered_map<std::string, std::unique_ptr<EpValueInfo>> value_infos_map;

Expand Down
30 changes: 30 additions & 0 deletions onnxruntime/core/graph/ep_api_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,32 @@ struct EpGraph : public OrtGraph {

public:
EpGraph(const GraphViewer& graph_viewer, PrivateTag);
EpGraph(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
PrivateTag);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
/// This call is used when creating an EpGraph from a GraphViewer instance. The GraphViewer instance is not onwed by this EpGraph.
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

/// <summary>
/// Creates an instance of EpGraph, which wraps a GraphViewer.
/// This call is used when creating an EpGraph from a subset of nodes in another EpGraph.
/// In this case, due to the implementation of OrtApis::Graph_GetGraphView, the new EpGraph instance
/// must take ownership of both the GraphViewer and IndexedSubGraph.
/// </summary>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status Create(std::unique_ptr<GraphViewer> graph_viewer,
std::unique_ptr<IndexedSubGraph> indexed_sub_graph,
/*out*/ std::unique_ptr<EpGraph>& result);

// Defines ToExternal() and ToInternal() functions to convert between OrtGraph and EpGraph.
DEFINE_ORT_GRAPH_IR_TO_EXTERNAL_INTERNAL_FUNCS(OrtGraph, EpGraph, OrtGraphIrApi::kEpApi)

Expand Down Expand Up @@ -331,9 +348,22 @@ struct EpGraph : public OrtGraph {
const OrtValue* GetInitializerValue(std::string_view name) const;

private:
/// <summary>
/// The real implementation of creating an EpGraph instance.
/// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly.
/// </summary>
/// <param name="ep_graph"></param>
/// <param name="graph_viewer"></param>
/// <param name="result"></param>
/// <returns></returns>
static Status CreateImpl(std::unique_ptr<EpGraph> ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr<EpGraph>& result);

const GraphViewer& graph_viewer_;
const EpNode* parent_node_ = nullptr;

std::unique_ptr<GraphViewer> owned_graph_viewer_ = nullptr;
std::unique_ptr<IndexedSubGraph> owned_indexed_sub_graph_ = nullptr;

std::vector<std::unique_ptr<EpNode>> nodes_;
IndexToEpNodeMap index_to_ep_node_;

Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1818,6 +1818,10 @@ NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name
return node_arg;
}

const NodeArg* Graph::GetNodeArgIncludingParentGraphs(const std::string& node_arg_name) const {
return const_cast<Graph*>(this)->GetNodeArgIncludingParentGraphs(node_arg_name);
}

void Graph::ReverseDFSFrom(gsl::span<NodeIndex const> from,
const std::function<void(const Node*)>& enter,
const std::function<void(const Node*)>& leave,
Expand Down
12 changes: 10 additions & 2 deletions onnxruntime/core/graph/graph_viewer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,15 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
filtered_node_inputs_including_initializers_.reserve(metadef->inputs.size());

for (const auto& input : metadef->inputs) {
const auto* nodearg = graph.GetNodeArg(input);
// NodeArgs from the current scope or any outer scopes should be handled correctly.
//
// There is an edge case where the model consists of a graph with subgraphs nested across three levels.
// In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
// When constructing a new GraphViewer for the second- and third-layer subgraphs,
// the second-layer graph may not have the corresponding value_info for that first-layer input,
// because the second-layer graph itself doesn't consume it.
// Therefore, when working within the second-layer graph, we need to search outer scopes for the missing value_info.
const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(input);
ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Input not found:", input);
filtered_node_inputs_including_initializers_.push_back(nodearg);
if (!graph.IsInitializedTensor(input)) {
Expand All @@ -177,7 +185,7 @@ GraphViewer::GraphViewer(const Graph& graph, const IndexedSubGraph* filter_info)
}

for (const auto& output : metadef->outputs) {
const auto* nodearg = graph.GetNodeArg(output);
const auto* nodearg = graph.GetNodeArgIncludingParentGraphs(output);
ORT_ENFORCE(nodearg, "Mismatch between Graph and IndexedSubGraph. Output not found:", output);
filtered_node_outputs_.push_back(nodearg);
}
Expand Down
86 changes: 86 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2714,6 +2714,91 @@
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph,
_In_ const OrtNode** nodes,
_In_ size_t num_nodes,
_Outptr_ OrtGraph** dst_graph) {
API_IMPL_BEGIN

if (num_nodes == 0) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_nodes' argument should be > 0");
}

const EpGraph* ep_graph = EpGraph::ToInternal(src_graph);
if (ep_graph == nullptr) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph.");
}
const Graph& graph = ep_graph->GetGraphViewer().GetGraph();

// Create a GraphViewer with filtered info
std::unique_ptr<IndexedSubGraph> indexed_sub_graph = std::make_unique<IndexedSubGraph>();
std::unique_ptr<IndexedSubGraph::MetaDef> metadef = std::make_unique<IndexedSubGraph::MetaDef>();
metadef->name = "sub_graph";
metadef->since_version = 1;
std::unordered_set<std::string> outputs;

Check warning on line 2738 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2738: Add #include <string> for string [build/include_what_you_use] [4]

Check warning on line 2738 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2738: Add #include <string> for string [build/include_what_you_use] [4]
std::unordered_set<const NodeArg*> initializers;

Check warning on line 2739 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2739: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

Check warning on line 2739 in onnxruntime/core/session/onnxruntime_c_api.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/session/onnxruntime_c_api.cc:2739: Add #include <unordered_set> for unordered_set<> [build/include_what_you_use] [4]

auto add_inputs = [&](ConstPointerContainer<std::vector<NodeArg*>> defs) {
for (const auto* def : defs) {
if (def->Exists()) {
// not the output of a previous node
if (outputs.count(def->Name()) == 0) {
metadef->inputs.push_back(def->Name());
} else {
// consumed by node so no longer subgraph output
// NOTE: Ignoring edge case where a node output is an overall graph output AND a node input
outputs.erase(def->Name());
}

if (graph.IsInitializedTensor(def->Name())) {
initializers.insert(def);
}
}
}
};

auto add_node = [&](const Node& node) {
indexed_sub_graph->nodes.push_back(node.Index());
add_inputs(node.InputDefs());
add_inputs(node.ImplicitInputDefs());

for (const auto* def : node.OutputDefs()) {
outputs.insert(def->Name());
}
};

// Add nodes
for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) {
const OrtNode* ort_node = nodes[node_idx];
const EpNode* ep_node = EpNode::ToInternal(ort_node);
if (ep_node == nullptr) {
return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph.");
}
add_node(ep_node->GetInternalNode());
}

// Add initializers
for (auto& initializer : initializers) {
metadef->constant_initializers.push_back(initializer->Name());
}

// Add outputs
for (auto& output : outputs) {
metadef->outputs.push_back(output);
}

indexed_sub_graph->SetMetaDef(std::move(metadef));
auto graph_viewer = std::make_unique<GraphViewer>(graph, *indexed_sub_graph.get());

std::unique_ptr<EpGraph> result;
ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result));

*dst_graph = result.release();

return nullptr;
API_IMPL_END
}

//
// OrtNode
//
Expand Down Expand Up @@ -3629,6 +3714,7 @@
&OrtApis::Graph_GetNumNodes,
&OrtApis::Graph_GetNodes,
&OrtApis::Graph_GetParentNode,
&OrtApis::Graph_GetGraphView,
&OrtApis::Node_GetId,
&OrtApis::Node_GetName,
&OrtApis::Node_GetOperatorType,
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ ORT_API_STATUS_IMPL(Graph_GetNumNodes, _In_ const OrtGraph* graph, _Out_ size_t*
ORT_API_STATUS_IMPL(Graph_GetNodes, const OrtGraph* graph,
_Out_writes_(num_nodes) const OrtNode** nodes, _In_ size_t num_nodes);
ORT_API_STATUS_IMPL(Graph_GetParentNode, _In_ const OrtGraph* graph, _Outptr_result_maybenull_ const OrtNode** node);
ORT_API_STATUS_IMPL(Graph_GetGraphView, _In_ const OrtGraph* graph, _In_ const OrtNode** nodes, _In_ size_t num_nodes,
_Outptr_ OrtGraph** subgraph);

// OrtNode
ORT_API_STATUS_IMPL(Node_GetId, _In_ const OrtNode* node, _Out_ size_t* node_id);
Expand Down
59 changes: 59 additions & 0 deletions onnxruntime/test/ep_graph/test_ep_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,15 @@
#include <gsl/gsl>
#include <memory>
#include <vector>
#include <fstream>

#include "core/common/common.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/tensor_type_and_shape.h"
#include "core/framework/onnxruntime_typeinfo.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/graph/ep_api_types.h"
#include "core/graph/graph_proto_serializer.h"

#define ORT_EP_UTILS_ORT_GRAPH_TO_PROTO_IMPL
#include "core/providers/utils/ort_graph_to_proto.h"
Expand All @@ -31,6 +34,7 @@ namespace test {
// forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent
// to a graph represented by the internal ORT GraphViewer class.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph);
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph);

//
// Tests
Expand Down Expand Up @@ -73,6 +77,16 @@ TEST(EpGraphTest, Check3LayerNestedSubgraph) {
CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

TEST(EpGraphTest, Check3LayerNestedSubgraphV2) {
// The overall structure of this model is similar to the one used in "Check3LayerNestedSubgraph" test.
// The model consists of a graph with subgraphs nested across three levels.
// In this scenario, a third-layer subgraph consumes an input from the first-layer graph (not an initializer).
auto test_graph = TestGraph::Load(ORT_TSTR("testdata/three_layer_nested_subgraph_v2.onnx"));
ASSERT_NE(test_graph, nullptr) << "Failed to load test model";

CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph());
}

static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector<float>& output_data) {
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
Ort::SessionOptions sess_options;
Expand Down Expand Up @@ -474,6 +488,48 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span<const
}
}

// Checks the Graph_GetSubgraph C API
static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) {
const OrtApi& ort_api = Ort::GetApi();

// Get all the nodes
size_t num_nodes = 0;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&api_graph, &num_nodes));

std::vector<const OrtNode*> nodes(num_nodes);
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&api_graph, nodes.data(), nodes.size()));

// Select a half of nodes to create a OrtGraph
size_t num_selected_nodes = std::max((nodes.size() >> 1), (size_t)1);
std::vector<const OrtNode*> selected_nodes(num_selected_nodes);

for (size_t i = 0; i < num_selected_nodes; i++) {
selected_nodes[i] = nodes[i];
}

OrtGraph* sub_graph;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetGraphView(&api_graph, selected_nodes.data(), selected_nodes.size(), &sub_graph));

// Convert OrtGraph/GraphViewer to ModelProto and dump it to disk.
// If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw.
const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer();
std::unique_ptr<Model> model = std::make_unique<Model>(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger());
auto model_proto = std::make_unique<ONNX_NAMESPACE::ModelProto>(model->ToProto());
GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast<ExecutionOrder>(1));
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

const char* graph_name = nullptr;
ASSERT_ORTSTATUS_OK(ort_api.Graph_GetName(&api_graph, &graph_name));
std::string name = graph_name;
name += "_half.onnx";

// Dump the graph for debugging
// std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary);
// model_proto->SerializeToOstream(&dump);

ort_api.ReleaseGraph(sub_graph);
}

// Checks that the contents of the original GraphViewer matches the contents of the OrtGraph.
// Uses the public C APIs to traverse the OrtGraph.
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) {
Expand Down Expand Up @@ -682,6 +738,9 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_
}
}
}

// Check creating an OrtGraph from a subset of nodes in an OrtGraph
Check_Graph_GetSubgraph(api_graph);
}

} // namespace test
Expand Down
Binary file not shown.
Loading