onnxruntime/core/framework/graph_partitioner.cc (66 changes: 46 additions & 20 deletions)
@@ -55,6 +55,11 @@ KernelDefBuilder& BuildFusedKernelDef(KernelDefBuilder& builder, const onnxrunti
 }

 Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
+  // This is currently a greedy partitioning algorithm based on the provider preferences the user provided when calling ONNX Runtime:
+  // 1. Execution providers' capabilities are checked one by one.
+  // 2. All sub-graphs that an execution provider returns will be assigned to it if they are not assigned yet.
+  // 3. The CPU execution provider is expected to be able to run any node and is the last one in the execution provider preference.
+
   if (providers_.Empty()) {
     return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No provider specified.");
   }
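The new header comment describes the partitioning contract in prose. As a toy illustration of that greedy rule, here is a self-contained sketch (not ONNX Runtime's real API; Node, Provider, and can_run are invented stand-ins) in which providers are consulted in preference order, a node keeps the first assignment it receives, and a catch-all CPU provider comes last:

```cpp
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string op;
  std::string assigned_provider;  // empty until some provider claims it
};

struct Provider {
  std::string name;
  bool (*can_run)(const Node&);
};

int main() {
  std::vector<Node> graph = {{"Conv", ""}, {"Relu", ""}, {"CustomOp", ""}};

  // Preference order: the specialized provider first, CPU last. The CPU
  // provider claims anything, matching assumption 3 in the comment above.
  std::vector<Provider> providers = {
      {"CUDAExecutionProvider", [](const Node& n) { return n.op != "CustomOp"; }},
      {"CPUExecutionProvider", [](const Node&) { return true; }},
  };

  for (const auto& provider : providers) {
    for (auto& node : graph) {
      // Greedy rule: only a still-unassigned node can be claimed.
      if (node.assigned_provider.empty() && provider.can_run(node)) {
        node.assigned_provider = provider.name;
      }
    }
  }

  for (const auto& node : graph) {
    std::cout << node.op << " -> " << node.assigned_provider << "\n";
  }
  return 0;
}
```

Because a node is only claimed while assigned_provider is empty, a later provider can never override an earlier, more preferred one; the sub-graph check added later in this diff enforces the same invariant at sub-graph granularity.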
@@ -63,10 +68,17 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
   std::shared_ptr<KernelRegistry> fused_kernel_registry = std::make_shared<KernelRegistry>();
   // Partitioning <graph> based on provider preference and their capabilities.
   auto kernel_registries = kernel_registry_mgr_.GetAllKernelRegistries();
+
+  std::vector<std::vector<std::unique_ptr<ComputeCapability>>> capabilities_of_all_providers;
+  GraphViewer graph_viewer(graph);
+  for (auto& provider : providers_) {
+    capabilities_of_all_providers.push_back(provider->GetCapability(graph_viewer, kernel_registries));
+  }
+
+  int i = 0;
   for (auto& provider : providers_) {
-    auto capability_results = provider->GetCapability(GraphViewer(graph), kernel_registries);
     int count = 0;
-    for (auto& capability : capability_results) {
+    for (auto& capability : capabilities_of_all_providers[i++]) {
       if (nullptr == capability || nullptr == capability->sub_graph) {
         continue;
       }
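This hunk replaces the per-provider GraphViewer(graph) with a single pass that collects every provider's capabilities before any assignment or fusion mutates the graph, then replays them with the i++ index. A minimal sketch of that two-phase shape, under invented names (View, Provider, get_capability) rather than the real ORT types:

```cpp
#include <iostream>
#include <string>
#include <vector>

struct View { /* stands in for a read-only GraphViewer */ };

struct Provider {
  std::string name;
  std::vector<int> (*get_capability)(const View&);  // node ids it can run
};

int main() {
  View view;  // built once, before anything mutates the graph
  std::vector<Provider> providers = {
      {"ProviderA", [](const View&) { return std::vector<int>{0, 1}; }},
      {"ProviderB", [](const View&) { return std::vector<int>{1, 2}; }},
  };

  // Phase 1: ask every provider against the same view.
  std::vector<std::vector<int>> capabilities;
  for (const auto& p : providers) {
    capabilities.push_back(p.get_capability(view));
  }

  // Phase 2: replay the answers in preference order, indexing with i++
  // just like the diff; mutating the graph would be safe here because
  // nothing queries the view anymore.
  std::vector<std::string> assignment(3);
  size_t i = 0;
  for (const auto& p : providers) {
    for (int node : capabilities[i++]) {
      if (assignment[node].empty()) assignment[node] = p.name;
    }
  }

  for (size_t n = 0; n < assignment.size(); ++n) {
    std::cout << "node " << n << " -> " << assignment[n] << "\n";
  }
  return 0;
}
```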
@@ -78,30 +90,44 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {

         auto node = graph.GetNode(capability->sub_graph->nodes[0]);
         if (nullptr != node && node->GetExecutionProviderType().empty()) {
           // The node was not fused or assigned. Assign it to this <provider>.
           node->SetExecutionProviderType(provider->Type());
         }
       } else {
         // The <provider> can run a fused <sub_graph> in the <graph>.
         //
-        // Add fused node into <graph>
         ORT_ENFORCE(nullptr != capability->sub_graph->GetMetaDef());
-        std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
-        auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
-        fused_node.SetExecutionProviderType(provider->Type());
-        auto fused_kernel_func = capability->fuse_kernel_function;
-        if (fused_kernel_func != nullptr) {
-          // build the kernel definition on the fly, and register it to the fused_kernel_regisitry.
-          KernelDefBuilder builder;
-          BuildFusedKernelDef(builder, fused_node);
-          fused_kernel_registry->Register(builder, fused_kernel_func);
+
+        // Check whether any node in the <sub_graph> was already assigned.
+        bool sub_graph_available_for_assignment = true;
+        for (auto node_index : capability->sub_graph->nodes) {
+          auto node = graph.GetNode(node_index);
+          if (nullptr == node || !node->GetExecutionProviderType().empty()) {
+            // The node was fused or assigned, so the whole sub-graph will not be assigned to this <provider>.
+            // The assumption is that this <provider> can only run the sub-graph as a whole unit.
+            sub_graph_available_for_assignment = false;
+            break;
+          }
+        }
+
+        if (sub_graph_available_for_assignment) {
+          // Add fused node into <graph>
+          std::string node_name = provider->Type() + "_" + capability->sub_graph->GetMetaDef()->name + "_" + std::to_string(count++);
+          auto& fused_node = graph.FuseSubGraph(std::move(capability->sub_graph), node_name);
+          fused_node.SetExecutionProviderType(provider->Type());
+          auto fused_kernel_func = capability->fuse_kernel_function;
+          if (fused_kernel_func != nullptr) {
+            // Build the kernel definition on the fly and register it with the fused_kernel_registry.
+            KernelDefBuilder builder;
+            BuildFusedKernelDef(builder, fused_node);
+            fused_kernel_registry->Register(builder, fused_kernel_func);
+          }
         }
       }
     }
-    // all done with this provider, resolve the graph before we move on to the next provider.
-    // This is needed since we create a new GraphViewer() that we pass into the next provider's GetCapability().
-    ORT_ENFORCE(graph.Resolve().IsOK());
   }

+  ORT_ENFORCE(graph.Resolve().IsOK());
+
   // To see if a node with no provider can be inlined. If one such node can be
   // successfully inlined, we re-run the partitioner on the modified graph.
   bool inline_flag = false;
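The core fix in this hunk is the availability check: because capabilities are now collected up front, some nodes of a lower-preference provider's fused sub-graph may already be claimed by the time it is processed, and the sub-graph is assumed to run only as a whole unit. The sketch below restates that check outside ORT's types; SubGraphAvailable, NodeIndex, and the assignments map are hypothetical stand-ins for Graph::GetNode and Node::GetExecutionProviderType:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

using NodeIndex = size_t;

// True only if every node of the candidate sub-graph is still unassigned.
bool SubGraphAvailable(const std::vector<NodeIndex>& sub_graph_nodes,
                       const std::unordered_map<NodeIndex, std::string>& assignments) {
  for (NodeIndex idx : sub_graph_nodes) {
    auto it = assignments.find(idx);
    // A missing entry plays the role of GetNode() returning nullptr (the node
    // was fused away); a non-empty string means a preferred provider owns it.
    if (it == assignments.end() || !it->second.empty()) {
      return false;
    }
  }
  return true;
}

int main() {
  std::unordered_map<NodeIndex, std::string> assignments = {
      {0, ""}, {1, "CUDAExecutionProvider"}, {2, ""}};

  // Node 1 is already claimed, so {0, 1} must be rejected as a whole unit,
  // while {0, 2} is still free and can be fused.
  std::cout << std::boolalpha
            << SubGraphAvailable({0, 1}, assignments) << "\n"   // false
            << SubGraphAvailable({0, 2}, assignments) << "\n";  // true
  return 0;
}
```

Rejecting the whole sub-graph rather than a partial slice matches the comment in the diff: the provider is not assumed to have a kernel for an arbitrary fragment of its own fused region.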
@@ -126,10 +152,10 @@ Status GraphPartitioner::Partition(onnxruntime::Graph& graph) const {
     this->Partition(graph);
   }

-  //For some cases, like fp16 on cpu, right now we don't have any kernel support that.
-  //But we will insert cast op to run the model, so skip the error checking here.
-  //If after graph transform phase, the node still not assigned, we will report error
-  //during kernel creation phase.
+  // For some cases, like fp16 on CPU, right now we don't have any kernel that supports it,
+  // but we will insert a cast op to run the model, so skip the error checking here.
+  // If the node is still not assigned after the graph transform phase, we will report an
+  // error during the kernel creation phase.
 #ifdef COUNT_NON_CUDA_OPS
   for (auto& node : graph.Nodes()) {
     if (node.GetExecutionProviderType() != kCudaExecutionProvider &&
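The last hunk sits inside the inlining fallback visible around bool inline_flag: when a node with no provider can be inlined, the graph changes, so this->Partition(graph) is re-run on the modified graph. A small sketch of that retry shape, with invented names (FakeNode, TryInlineUnassigned) and the assumption that each inlining pass strictly shrinks the set of inlinable nodes, which is what makes the recursion terminate:

```cpp
#include <iostream>
#include <vector>

struct FakeNode {
  bool assigned = false;   // has some provider claimed this node?
  bool inlinable = false;  // can the node be expanded into its function body?
};

// Returns true if at least one unassigned node was inlined.
bool TryInlineUnassigned(std::vector<FakeNode>& nodes) {
  bool inlined = false;
  for (auto& n : nodes) {
    if (!n.assigned && n.inlinable) {
      n.inlinable = false;  // stand-in for expanding the node in the graph
      inlined = true;
    }
  }
  return inlined;
}

void Partition(std::vector<FakeNode>& nodes) {
  // ... greedy provider assignment would run here, as sketched earlier ...
  if (TryInlineUnassigned(nodes)) {
    // Inlining modified the graph, so partition again, mirroring
    // this->Partition(graph) in the diff. Terminates because each pass
    // strictly reduces the number of inlinable nodes.
    Partition(nodes);
  }
}

int main() {
  std::vector<FakeNode> nodes = {{false, true}, {true, false}, {false, false}};
  Partition(nodes);
  std::cout << "done after inlining fallback\n";
  return 0;
}
```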