Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -921,11 +921,18 @@ class Graph {
@returns Node with fused subgraph.
@remarks As a new Graph instance for the fused nodes is not created, a GraphViewer can be constructed with the
IndexedSubGraph information to provide a view of the subgraph. The original nodes are left in place
while this is in use.
while this is in use.
Call FinalizeFuseSubGraph to remove them once the fused replacement node is fully created.
*/
Node& BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::string& fused_node_name);

/**
Undo the changes made by BeginFuseSubGraph when an error occurs before the fusion is finalized,
e.g. when an EP fails to Compile the sub_graph.
@param fused_node The fused node to be removed from the graph, together with its function body.
@remarks Does nothing if fused_node is no longer in the graph or is not of type Node::Type::Fused.
*/
void CancelFuseSubGraph(const Node& fused_node);

/**
Complete a fusion started with BeginFuseSubGraph once the fused replacement node is fully created.
Removes the original nodes from the graph and wires in the fused node.
@param sub_graph The IndexedSubGraph being fused (presumably the one given to BeginFuseSubGraph - TODO confirm).
@param fused_node The fused node that replaces the original nodes.
*/
void FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node);
#endif

Expand Down
61 changes: 31 additions & 30 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -448,42 +448,43 @@ static Status PartitionOrtFormatModelImpl(Graph& graph, FuncManager& func_mgr,
nodes_and_viewers.push_back(IExecutionProvider::FusedNodeAndGraph{fused_node, *viewers.back()});
}

std::vector<NodeComputeInfo> node_compute_funcs;
node_compute_funcs.reserve(nodes_and_viewers.size());

ORT_RETURN_IF_ERROR(current_ep.Compile(nodes_and_viewers, node_compute_funcs));

if (node_compute_funcs.size() != nodes_and_viewers.size()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, type, " did not return correct number of compiled functions");
}

for (size_t j = 0, end = nodes_and_viewers.size(); j < end; j++) {
// We will compile the fused nodes one by one, and fuse the subgraph if successful.
// If a compilation fails we undo the fusion and leave the original nodes available for other EPs to take
for (size_t j = 0, end = nodes_and_viewers.size(); j < end; ++j) {
Node& node = nodes_and_viewers[j].fused_node;
std::vector<NodeComputeInfo> single_node_compute_func;
auto status = current_ep.Compile({nodes_and_viewers[j]}, single_node_compute_func);
if (!status.IsOK()) {
// Compiling nodes_and_viewers[j] failed; remove the fused_node and its function from the graph
LOGS_DEFAULT(ERROR) << "EP: " << current_ep.Type() << " has Compile error: " << status.ErrorMessage();
graph.CancelFuseSubGraph(node);
} else {
ORT_RETURN_IF(single_node_compute_func.empty(), "single_node_compute_func should have 1 elements");
ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(single_node_compute_func[0])));

ORT_RETURN_IF_ERROR(func_mgr.AddFuncInfo(node.Name(), std::move(node_compute_funcs[j])));

const auto& cur_capability = capabilities[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();
const auto& cur_capability = capabilities[j];
const IndexedSubGraph& indexed_sub_graph = *cur_capability->sub_graph;
const IndexedSubGraph::MetaDef& metadef = *indexed_sub_graph.GetMetaDef();

KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
auto kernel_def = builder.Build();
KernelDefBuilder builder;
BuildFusedKernelDef(builder, metadef, type);
auto kernel_def = builder.Build();

// save hash so SessionState can find the kernel. each kernel name should be unique
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
". Execution Provider must generate unique names across the entire model.");
}
// save hash so SessionState can find the kernel. each kernel name should be unique
if (compiled_kernel_hashes.insert({metadef.name, kernel_def->GetHash()}).second == false) {
ORT_THROW("Existing entry in compiled kernel hashes for ", metadef.name,
". Execution Provider must generate unique names across the entire model.");
}

ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
}))));
ORT_RETURN_IF_ERROR(fused_kernel_registry.Register(
KernelCreateInfo(std::move(kernel_def), static_cast<KernelCreatePtrFn>(
[](const OpKernelInfo& info) -> OpKernel* {
return new FunctionKernel(info);
}))));

// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
// now that we're done compiling we can remove the original nodes from the Graph and wire in the new one
graph.FinalizeFuseSubGraph(indexed_sub_graph, node);
}
}

return Status::OK();
Expand Down
29 changes: 26 additions & 3 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3389,6 +3389,31 @@ Node& Graph::BeginFuseSubGraph(const IndexedSubGraph& sub_graph, const std::stri
return node;
}

void Graph::CancelFuseSubGraph(const Node& fused_node) {
auto node_idx = fused_node.Index();
if (!GetNode(node_idx))
return;

if (fused_node.NodeType() != Node::Type::Fused)
return;

#if !defined(ORT_MINIMAL_BUILD)
// Remove the function body from function container
const auto* fused_node_func = fused_node.GetFunctionBody();
auto it = std::find_if(
function_container_.begin(), function_container_.end(),
[fused_node_func](const std::unique_ptr<onnxruntime::Function>& func) {
return func.get() == fused_node_func;
});
if (it != function_container_.end()) {
function_container_.erase(it);
}
#endif

// Remove the fused_node
RemoveNode(node_idx);
}

void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_node) {
const auto* func_meta_def = sub_graph.GetMetaDef();
ORT_ENFORCE(nullptr != func_meta_def);
Expand Down Expand Up @@ -3429,9 +3454,7 @@ void Graph::FinalizeFuseSubGraph(const IndexedSubGraph& sub_graph, Node& fused_n
if (it != input_indexes.cend()) {
AddEdge(producer_idx, new_node_idx, src_idx, it->second);
}
}
else
{
} else {
int dst_implicit_input_idx = dst_idx - (int)node->InputDefs().size();
ORT_ENFORCE(dst_implicit_input_idx < (int)node->ImplicitInputDefs().size());
auto it = input_indexes.find(node->ImplicitInputDefs()[dst_implicit_input_idx]->Name());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,5 +254,102 @@ TEST(InternalTestingEP, TestModelWithSubgraph) {
feeds);
}

// A custom extension of InternalTestingExecutionProvider.
// It is used to test that execution falls back to the CPU EP if Compile fails, for ORT format models.
// In addition to the supported ops, this EP takes a set of compile_failure_ops:
// if any node in a partition passed to Compile() has an op type that is in compile_failure_ops,
// Compile() will fail.
class CompileFailureTestExecutionProvider : public InternalTestingExecutionProvider {
public:
// supported_ops is forwarded to InternalTestingExecutionProvider;
// compile_failure_ops is the set of op types whose presence in a partition makes Compile() fail.
CompileFailureTestExecutionProvider(const std::unordered_set<std::string>& supported_ops,
const std::unordered_set<std::string>& compile_failure_ops);
virtual ~CompileFailureTestExecutionProvider() = default;

// Returns a failure Status if any node of any fused partition has an op type in compile_failure_ops_,
// otherwise delegates to InternalTestingExecutionProvider::Compile.
Status Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
std::vector<NodeComputeInfo>& node_compute_funcs) override;

private:
// Op types that trigger a Compile() failure (copied from the constructor argument).
std::unordered_set<std::string> compile_failure_ops_;
};

CompileFailureTestExecutionProvider::CompileFailureTestExecutionProvider(
const std::unordered_set<std::string>& supported_ops,
const std::unordered_set<std::string>& compile_failure_ops)
: InternalTestingExecutionProvider(supported_ops),
compile_failure_ops_(compile_failure_ops) {}

// Fails with a FAIL status as soon as any node of any requested partition has an op type listed in
// compile_failure_ops_; otherwise defers to InternalTestingExecutionProvider::Compile.
Status CompileFailureTestExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes,
                                                    std::vector<NodeComputeInfo>& node_compute_funcs) {
  for (const auto& partition : fused_nodes) {
    const onnxruntime::GraphViewer& viewer = partition.filtered_graph;

    // Scan the nodes of this partition for an op type that is configured to fail.
    for (const auto& candidate : viewer.Nodes()) {
      const bool must_fail = compile_failure_ops_.count(candidate.OpType()) != 0;
      if (must_fail) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
                               "CompileFailureTestExecutionProvider::Compile failed for node: ", candidate.Name());
      }
    }
  }

  // No failing op type found: compile exactly as the base testing EP would.
  return InternalTestingExecutionProvider::Compile(fused_nodes, node_compute_funcs);
}

// Verifies that a Compile failure in an EP does not fail session initialization for an ORT format model:
// the partition that failed to compile is left as its original nodes, while the others are fused.
TEST(InternalTestingEP, TestOrtModelWithCompileFailure) {
  // The test model has 2 Conv nodes and 1 Gemm node, all disconnected from each other,
  // so 3 partitions should be taken by InternalTestingExecutionProvider/CompileFailureTestExecutionProvider.
  // CompileFailureTestExecutionProvider fails Compile for the partition that contains the "Gemm" node.
  // This tests that model initialization does not fail and that the Gemm node is not replaced by a fused node.
  const ORTCHAR_T* ort_model_path = ORT_TSTR("testdata/mnist.ort");

  // Plain values: there is no need to bind const references to lifetime-extended temporaries.
  const std::unordered_set<std::string> supported_ops{"Conv", "Gemm"};
  const std::unordered_set<std::string> compile_failure_ops{"Gemm"};

  // Use InternalTestingExecutionProvider.
  // We should have 3 partitions taken by the EP: 2 Conv and 1 Gemm.
  {
    InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
    ASSERT_STATUS_OK(session.RegisterExecutionProvider(
        onnxruntime::make_unique<InternalTestingExecutionProvider>(supported_ops)));
    ASSERT_STATUS_OK(session.Load(ort_model_path));
    ASSERT_STATUS_OK(session.Initialize());

    int num_replaced_nodes = CountAndValidateAssignedNodes(
        session.GetGraph(), supported_ops, session.GetSessionState().GetFuncMgr());

    ASSERT_EQ(num_replaced_nodes, 3);
  }

  // Use CompileFailureTestExecutionProvider, which fails Compile on "Gemm".
  // We should have 2 partitions taken by the EP: the 2 Conv nodes.
  {
    InferenceSessionWrapper session(SessionOptions(), GetEnvironment());
    ASSERT_STATUS_OK(session.RegisterExecutionProvider(
        onnxruntime::make_unique<CompileFailureTestExecutionProvider>(supported_ops, compile_failure_ops)));
    ASSERT_STATUS_OK(session.Load(ort_model_path));
    ASSERT_STATUS_OK(session.Initialize());

    // The 2 Conv nodes should be replaced with fused nodes
    const auto& graph = session.GetGraph();
    int num_replaced_nodes = CountAndValidateAssignedNodes(
        graph, {"Conv"}, session.GetSessionState().GetFuncMgr());

    ASSERT_EQ(num_replaced_nodes, 2);

    // The Gemm node should not have been replaced, i.e. it is still present in the graph
    int count_compile_failure_nodes = 0;
    for (const auto& node : graph.Nodes()) {
      if (compile_failure_ops.find(node.OpType()) != compile_failure_ops.end()) {
        ++count_compile_failure_nodes;
      }
    }
    ASSERT_EQ(count_compile_failure_nodes, 1);

    // Execute the session. The last node is Gemm and its input 0 is all 0s,
    // so the result should be the bias initializer of the Gemm node.
    ExecuteMnist(session, true /* enable_custom_ep */);
  }
}

} // namespace test
} // namespace onnxruntime