6 changes: 6 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -2100,6 +2100,8 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h"
@@ -2108,6 +2110,10 @@ if (onnxruntime_BUILD_SHARED_LIB AND
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc"
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h"
onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc
@@ -9,9 +9,11 @@
#include <memory>
#include <optional>
#include <string>
#include <unordered_set>
#include <vector>

#include "ep_factory.h"
#include "get_capability_utils.h"
#include "../plugin_ep_utils.h"

ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger)
@@ -60,12 +62,19 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons
}

// Collect candidate nodes that this EP may support.
std::vector<Ort::ConstNode> candidate_nodes;
std::vector<const OrtNode*> candidate_nodes;

for (const auto& node : all_nodes) {
std::string op_type = node.GetOperatorType();

if (op_type == "Relu" || op_type == "Squeeze") {
const OrtKernelDef* kernel_def = nullptr;
RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def));

if (kernel_def == nullptr) {
continue; // Does not have a registered kernel for this node.
}

if (op_type == "Relu" || op_type == "Squeeze" || op_type == "Shape" || op_type == "Reshape") {
candidate_nodes.push_back(node);
} else if (op_type == "Mul") {
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
@@ -86,12 +95,13 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons
}
}

// Mark candidate nodes as supported if we have a registered kernel.
for (const auto& node : candidate_nodes) {
const OrtKernelDef* kernel_def = nullptr;
RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def));
// Get subset of candidate nodes that would be better to offload to CPU.
std::unordered_set<const OrtNode*> cpu_nodes;
RETURN_IF_ERROR(GetCpuPreferredNodes(*ort_graph, *graph_support_info, ep->logger_, candidate_nodes, cpu_nodes));

if (kernel_def != nullptr) {
// Mark candidate nodes as supported.
for (const auto& node : candidate_nodes) {
if (cpu_nodes.count(node) == 0) {
RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddSingleNode(graph_support_info, node));
}
}
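
// Note: the kernel lookup now happens while collecting candidates, so only nodes with a kernel
// registered on this EP reach GetCpuPreferredNodes; the utility re-checks each candidate and
// returns an error if one has no registered kernel.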
onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc
@@ -19,6 +19,18 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = {
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Squeeze)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Squeeze)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Squeeze)>,

// Support Shape 21, 23, and 24.
// Note: end versions are inclusive.
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Shape)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Shape)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Shape)>,

// Support Reshape 21, 23, and 24.
// Note: end versions are inclusive.
BuildKernelCreateInfo<class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kOnnxDomain, 21, 22, Reshape)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 23, Reshape)>,
BuildKernelCreateInfo<class ONNX_OPERATOR_KERNEL_CLASS_NAME(kOnnxDomain, 24, Reshape)>,
};
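
// For context (an assumption, not shown in this diff): shape.cc and reshape.cc, added to the build in
// onnxruntime_unittests.cmake above, presumably register their kernel classes with the same pattern
// relu.cc uses further below, roughly:
//
//   ONNX_OPERATOR_KERNEL_EX(
//       Shape,
//       kOnnxDomain,
//       /*version*/ 23,
//       (Ort::KernelDefBuilder()
//            .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))),
//       Shape)
//
// together with a versioned variant of the macro for the 21-22 entries. The builder calls here are
// guesses; only the ONNX_OPERATOR_*_KERNEL_CLASS_NAME usage above comes from the diff.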

size_t GetNumKernels() {
onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc
@@ -0,0 +1,226 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "get_capability_utils.h"

#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

using NodeId = size_t;
constexpr int64_t kSmallInitializerThreshold = 100;
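// Initializers with at most this many elements are considered cheap to access from CPU and therefore
// do not prevent a node from being placed on the CPU EP (see the small-initializer check below).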

constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) {
return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput;
}

// Get all downstream nodes that consume any of the given node's outputs.
static OrtStatus* GetOutputNodes(gsl::span<Ort::ConstValueInfo const> node_outputs, std::vector<Ort::ConstNode>& result) {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
std::vector<Ort::ConstNode> output_nodes;
output_nodes.reserve(node_outputs.size());  // Lower bound: each output may have multiple consumers.

// Gather the OrtNode consumers of every output.
for (Ort::ConstValueInfo output : node_outputs) {
if (output == nullptr) continue; // Skip missing optional output

auto consumers_info = output.GetConsumers();
for (const auto& consumer : consumers_info) {
output_nodes.push_back(consumer.node);
}
}

result = std::move(output_nodes);
return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}

// Returns nodes that should be assigned to CPU EP instead of this example EP to avoid costly I/O copies.
// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
const OrtLogger& logger, gsl::span<const OrtNode* const> tentative_nodes,
/*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes) {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
const OrtApi& ort_api = Ort::GetApi();
const OrtEpApi& ep_api = Ort::GetEpApi();
Ort::ConstGraph graph{&ort_graph};
std::vector<Ort::ConstNode> ordered_nodes = graph.GetNodes();

if (ordered_nodes.empty()) {
return nullptr;
}

std::unordered_map<NodeId, Ort::ConstNode> node_id_to_node;
std::unordered_map<NodeId, size_t> node_id_to_order_map;
for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) {
NodeId node_id = ordered_nodes[i].GetId();
node_id_to_node[node_id] = ordered_nodes[i];
node_id_to_order_map[node_id] = i;
}

// Comparator for the priority queue: nodes with a smaller topological order index are popped first.
auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) {
return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2];
};
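// Example: with topological order indices {A: 0, B: 1, C: 2}, the queue pops A first, then B, then C,
// regardless of the order in which they were pushed.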
std::priority_queue<NodeId, std::vector<NodeId>, decltype(greater_order_comp)> candidates(greater_order_comp);
std::unordered_set<const OrtValueInfo*> cpu_output_args;

std::unordered_set<NodeId> provider_nodes;
provider_nodes.reserve(tentative_nodes.size());

std::unordered_map<NodeId, Ort::ConstKernelDef> node_to_kernel;
node_to_kernel.reserve(tentative_nodes.size());

for (const OrtNode* ort_node : tentative_nodes) {
Ort::ConstNode node(ort_node);
NodeId node_id = node.GetId();

provider_nodes.insert(node_id);

// Every tentative node is expected to have a kernel registered with the target EP.
const OrtKernelDef* ort_kernel_def = nullptr;
RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def));
RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP");

Ort::ConstKernelDef kernel_def(ort_kernel_def);
node_to_kernel.insert({node_id, kernel_def});

// Find all the direct consumers of CPU tensors.
std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
for (size_t out_index = 0; out_index < outputs.size(); out_index++) {
Ort::ConstValueInfo output = outputs[out_index];
if (output == nullptr) continue; // Skip missing optional output

bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index));
if (is_output_on_cpu) {
cpu_output_args.insert(output);

auto consumer_infos = output.GetConsumers();
for (const auto& consumer_info : consumer_infos) {
candidates.push(consumer_info.node.GetId());
ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n",
consumer_info.node.GetName().c_str());
}
}
}
}
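
// At this point `candidates` holds the IDs of nodes that directly consume an output that the target
// EP's kernel definition marks as a CPU output (for this example EP, presumably the consumers of
// Shape's output, assuming shape.cc marks that output as OrtMemTypeCPUOutput).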

std::unordered_set<NodeId> visited;
visited.reserve(candidates.size());

std::unordered_set<const OrtNode*> cpu_nodes;
cpu_nodes.reserve(candidates.size());

// The algorithm below tries to identify a subgraph that depends only on CPU tensors.
// Usually this is a subgraph that performs shape calculations based on a GPU tensor and then reshapes it back.
// In detail: for each candidate, if one of its inputs is a CPU tensor and the non-CPU kernel does not already
// mark it as a CPU input, force the node onto CPU to avoid a memory copy and add its outputs to the set of
// small CPU tensors.
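// Illustrative example (an assumption, not taken from the original comment): given a pattern like
//     X (device tensor) -> Shape -> Gather -> Concat -> Reshape(X, new_shape)
// Shape produces a small int64 tensor, typically declared as a CPU output. If Gather and Concat stayed
// on this EP, each hop would copy a tiny tensor between CPU and device memory; forcing them onto the
// CPU EP keeps the entire shape calculation on the CPU while the data input of Reshape remains a
// device tensor.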
while (!candidates.empty()) {
NodeId cur = candidates.top();
candidates.pop();

auto p = visited.insert(cur);
if (!p.second) {
continue;
}

auto node_iter = node_id_to_node.find(cur);
RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID");
Ort::ConstNode node = node_iter->second;

if (provider_nodes.find(cur) == provider_nodes.end()) {
// Nodes not in provider_nodes either already have an EP assigned or have no kernel registered with the target EP.
// We assume these nodes will fall back to CPU, so add all direct consumers of all of their outputs as candidates.
std::string ep_name = node.GetEpName();
if (ep_name.empty() || ep_name == "CPUExecutionProvider") {
std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();

for (Ort::ConstValueInfo output : outputs) {
if (output == nullptr) continue; // Skip missing optional output
cpu_output_args.insert(output);
}

std::vector<Ort::ConstNode> output_nodes;
RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));

for (Ort::ConstNode downstream_node : output_nodes) {
candidates.push(downstream_node.GetId());
}
}
continue;
}

std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
bool place_in_cpu = true;

for (size_t i = 0; i < inputs.size(); i++) {
Ort::ConstValueInfo input = inputs[i];
if (input == nullptr) continue; // Skip missing optional input

// Skip placing on CPU if the data type is float16, bfloat16,
// float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz, or float4e2m1.
Ort::ConstTypeInfo type_info = input.TypeInfo();
auto type_shape_info = type_info.GetTensorTypeAndShapeInfo();
auto elem_type = type_shape_info.GetElementType();
if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ ||
elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) {
place_in_cpu = false;
break;
}

bool is_small_initializer = input.IsConstantInitializer() &&
type_shape_info.GetElementCount() <= kSmallInitializerThreshold;

// Allow placing on CPU if it's a small initializer or graph input
if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) {
continue;
}

// the input is not a CPU tensor
if (cpu_output_args.find(input) == cpu_output_args.end()) {
place_in_cpu = false;
break;
}

// The input is a CPU tensor, but the target EP's kernel already consumes it as a CPU input,
// so falling back to CPU would not save a copy.
bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetInputMemType(i));
if (is_input_on_cpu) {
place_in_cpu = false;
break;
}
}

if (place_in_cpu) {
cpu_nodes.insert(node);
ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING,
"EP optimization: Force fallback to CPU execution for node %s because the CPU execution path "
"is deemed faster than overhead involved with execution on other EPs capable of executing "
"this node.\n",
node.GetName().c_str());

std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
for (Ort::ConstValueInfo output : outputs) {
if (output == nullptr) continue; // Skip missing optional output
cpu_output_args.insert(output);
}

std::vector<Ort::ConstNode> output_nodes;
RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));

for (Ort::ConstNode downstream_node : output_nodes) {
candidates.push(downstream_node.GetId());
}
}
}

cpu_preferred_nodes = std::move(cpu_nodes);

return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}
onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h
@@ -0,0 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <gsl/span>
#include <unordered_set>
#include "../plugin_ep_utils.h"

// Returns nodes that should be assigned to CPU EP instead of this example EP to avoid costly I/O copies.
// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
const OrtLogger& logger, gsl::span<const OrtNode* const> tentative_nodes,
/*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes);
onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc
@@ -15,7 +15,8 @@ ONNX_OPERATOR_KERNEL_EX(
kOnnxDomain,
/*version*/ 14, // Equivalent to start_version: 14, end_version: 14 (inclusive)
(Ort::KernelDefBuilder()
.AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
.AddTypeConstraint("T", GetTensorTypes({ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}))
.AddInputOutputMutableAlias(0, 0)),
Relu)

@@ -36,22 +37,39 @@ OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::uni
EXCEPTION_TO_RETURNED_STATUS_END
}

template <typename T>
static OrtStatus* ApplyRelu(Ort::KernelContext kernel_context) noexcept {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
gsl::span<const T> input0;
std::vector<int64_t> shape0;
RETURN_IF_ERROR(GetKernelInputDataAndShape<T>(kernel_context, 0, input0, shape0));

Ort::UnownedValue output = kernel_context.GetOutput(0, shape0);
T* output_data = output.GetTensorMutableData<T>();

for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = std::max(static_cast<T>(0), input0[i]);
}
return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}

/*static*/
OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
Relu* relu_kernel = static_cast<Relu*>(this_ptr);
Ort::KernelContext kernel_context(kernel_ctx);
static_cast<void>(relu_kernel->info_); // NOTE: Unused in this example.

gsl::span<const float> input0;
std::vector<int64_t> shape0;
RETURN_IF_ERROR(GetKernelInputDataAndShape<float>(kernel_context, 0, input0, shape0));
Ort::ConstValue input = kernel_context.GetInput(0);
auto type_shape = input.GetTensorTypeAndShapeInfo();
auto elem_type = type_shape.GetElementType();

Ort::UnownedValue output = kernel_context.GetOutput(0, shape0);
float* output_data = output.GetTensorMutableData<float>();

for (size_t i = 0; i < input0.size(); ++i) {
output_data[i] = std::max(0.0f, input0[i]);
if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) {
RETURN_IF_ERROR(ApplyRelu<float>(kernel_context));
} else {
assert(elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64);
RETURN_IF_ERROR(ApplyRelu<int64_t>(kernel_context));
}
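// Note: the assert relies on the kernel definition's "T" constraint (float and int64) to ensure no
// other element type reaches this kernel; adding more types to the constraint would require extending
// this dispatch.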

return nullptr;