From 96fe212e4b4aa7f216a3abf565407798429263b3 Mon Sep 17 00:00:00 2001 From: adrianlizarraga Date: Mon, 5 Jan 2026 05:04:37 -0800 Subject: [PATCH 1/3] Add Shape and Reshape kernels to example kernel EP --- cmake/onnxruntime_unittests.cmake | 4 + .../ep_kernel_registration.cc | 12 ++ .../kernels/reshape.cc | 149 ++++++++++++++++++ .../kernels/reshape.h | 24 +++ .../kernels/shape.cc | 97 ++++++++++++ .../kernels/shape.h | 23 +++ 6 files changed, 309 insertions(+) create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 385a32bddfdfd..91f6582310072 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -2108,6 +2108,10 @@ if (onnxruntime_BUILD_SHARED_LIB AND "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h" + "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/squeeze.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.h" diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc index b9518786f3a04..9de93c159f953 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep_kernel_registration.cc @@ -19,6 +19,18 @@ static const BuildKernelCreateInfoFn build_kernel_create_info_funcs[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Support Shape 21, 23, and 24. + // Note: end versions are inclusive. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + // Support Reshape 21, 23, and 24. + // Note: end versions are inclusive. + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, }; size_t GetNumKernels() { diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc new file mode 100644 index 0000000000000..4bdead1a2ca04 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc @@ -0,0 +1,149 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#include "reshape.h"
+
+#include <memory>
+#include <vector>
+#include "utils.h"
+
+// ONNX Reshape version 21
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    /*start_version*/ 21, /*end_version (inclusive)*/ 22,
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .AddInputOutputAlias(0, 0)
+         .SetInputMemType(1, OrtMemTypeCPU)),
+    Reshape)
+
+// ONNX Reshape version 23
+ONNX_OPERATOR_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    /*version*/ 23,  // Equivalent to start_version: 23, end_version: 23
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .AddInputOutputAlias(0, 0)
+         .SetInputMemType(1, OrtMemTypeCPU)),
+    Reshape)
+
+// ONNX Reshape version 24
+ONNX_OPERATOR_KERNEL_EX(
+    Reshape,
+    kOnnxDomain,
+    /*version*/ 24,  // Equivalent to start_version: 24, end_version: 24
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("shape", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .AddInputOutputAlias(0, 0)
+         .SetInputMemType(1, OrtMemTypeCPU)),
+    Reshape)
+
+Reshape::Reshape(const OrtKernelInfo* info, void* state, bool allow_zero, PrivateTag)
+    : OrtKernelImpl{},  // Initialize all OrtKernelImpl functions to NULL
+      info_{info},
+      data_transfer_impl_{reinterpret_cast<OrtDataTransferImpl*>(state)},
+      allow_zero_{allow_zero} {
+  ort_version_supported = ORT_API_VERSION;
+  Compute = ComputeImpl;
+  Release = ReleaseImpl;
+}
+
+/*static*/
+OrtStatus* Reshape::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<OrtKernelImpl>& kernel) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  Ort::ConstKernelInfo kernel_info(info);
+  bool allow_zero = kernel_info.GetAttribute<int64_t>("allowzero") == 1;
+
+  kernel = std::make_unique<Reshape>(info, state, allow_zero, PrivateTag{});
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+// Computes the requested shape for the reshape operation.
+// Implementation is based on ReshapeHelper in onnxruntime/core/providers/cpu/tensor/reshape_helper.h
+static OrtStatus* GetRequestedShape(gsl::span<const int64_t> input_shape, bool allow_zero,
+                                    /*out*/ std::vector<int64_t>& requested_shape) {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  const OrtApi& ort_api = Ort::GetApi();
+
+  int64_t num_input_elems = 1;
+  for (auto dim_val : input_shape) {
+    num_input_elems *= dim_val;
+  }
+  RETURN_IF(num_input_elems == -1, ort_api, "Input tensor must not have dynamic (-1) dimensions.");
+
+  size_t num_dims = requested_shape.size();
+  int64_t unknown_dim = -1;
+  int64_t size = 1;
+
+  for (size_t i = 0; i < num_dims; i++) {
+    RETURN_IF(requested_shape[i] < -1, ort_api, "A dimension cannot be less than -1");
+
+    if (requested_shape[i] == -1) {
+      RETURN_IF(unknown_dim != -1, ort_api, "At most one dimension can be -1");
+      unknown_dim = static_cast<int64_t>(i);
+    } else {
+      if (!allow_zero && requested_shape[i] == 0) {
+        RETURN_IF(i >= input_shape.size(), ort_api,
+                  "The dimension with value zero exceeds the dimension size of the input");
+        requested_shape[i] = input_shape[i];
+      }
+
+      size *= requested_shape[i];
+    }
+  }
+
+  if (unknown_dim != -1) {
+    // Calculate unknown dimension.
+    RETURN_IF(size == 0 || (num_input_elems % size) != 0, ort_api,
+              "The input cannot be reshaped to the requested shape");
+    requested_shape[unknown_dim] = num_input_elems / size;
+  } else {
+    // Check if the output shape is valid.
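+    // e.g. an input with 24 elements can be reshaped to [4, 6] but not to [5, 5]:
+    // with no -1 present, the requested dimensions must multiply to the input's element count.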
+    RETURN_IF(num_input_elems != size, ort_api, "The input cannot be reshaped to the requested shape");
+  }
+
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+/*static*/
+OrtStatus* ORT_API_CALL Reshape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  Reshape* reshape_kernel = static_cast<Reshape*>(this_ptr);
+  static_cast<void>(reshape_kernel->info_);  // NOTE: Unused in this example.
+
+  Ort::KernelContext kernel_context(kernel_ctx);
+
+  // Input[0] has the data to reshape.
+  Ort::ConstValue input = kernel_context.GetInput(0);
+  auto type_shape_info = input.GetTensorTypeAndShapeInfo();
+  std::vector<int64_t> input_shape = type_shape_info.GetShape();
+
+  // Input[1] has the requested shape for the reshape operation.
+  Ort::ConstValue shape_input = kernel_context.GetInput(1);
+  gsl::span<const int64_t> shape_input_data;
+  std::vector<int64_t> final_shape;
+
+  RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, final_shape));
+  RETURN_IF(final_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension");
+  RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, final_shape));
+
+  Ort::UnownedValue output = kernel_context.GetOutput(0, final_shape);
+
+  // This kernel aliases the input and output, so a copy is not really necessary.
+  // CopyTensor() will not do a copy if the source and destination buffers are the same.
+  RETURN_IF_ERROR(CopyTensor(*reshape_kernel->data_transfer_impl_, input, output));
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+/*static*/
+void ORT_API_CALL Reshape::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
+  delete static_cast<Reshape*>(this_ptr);
+}
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h
new file mode 100644
index 0000000000000..fd3350ca74d01
--- /dev/null
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "../../plugin_ep_utils.h"
+
+class Reshape : public OrtKernelImpl {
+ private:
+  struct PrivateTag {};
+
+ public:
+  static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<OrtKernelImpl>& kernel) noexcept;
+  Reshape(const OrtKernelInfo* info, void* state, bool allow_zero, PrivateTag);
+
+  // Static functions assigned to the OrtKernelImpl fields:
+  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
+  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
+
+ private:
+  const OrtKernelInfo* info_;
+  OrtDataTransferImpl* data_transfer_impl_;  // Custom state passed from OrtEp
+  bool allow_zero_;
+};
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc
new file mode 100644
index 0000000000000..3f0d4b0d83e66
--- /dev/null
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.cc
@@ -0,0 +1,97 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "shape.h"
+
+#include <memory>
+#include "utils.h"
+
+// ONNX Shape version 21
+ONNX_OPERATOR_VERSIONED_KERNEL_EX(
+    Shape,
+    kOnnxDomain,
+    /*start_version*/ 21, /*end_version (inclusive)*/ 22,
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .SetOutputMemType(0, OrtMemTypeCPU)),
+    Shape)
+
+// ONNX Shape version 23
+ONNX_OPERATOR_KERNEL_EX(
+    Shape,
+    kOnnxDomain,
+    /*version*/ 23,  // Equivalent to start_version: 23, end_version: 23
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .SetOutputMemType(0, OrtMemTypeCPU)),
+    Shape)
+
+// ONNX Shape version 24
+ONNX_OPERATOR_KERNEL_EX(
+    Shape,
+    kOnnxDomain,
+    /*version*/ 24,  // Equivalent to start_version: 24, end_version: 24
+    (Ort::KernelDefBuilder()
+         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("T1", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64))
+         .SetOutputMemType(0, OrtMemTypeCPU)),
+    Shape)
+
+Shape::Shape(const OrtKernelInfo* info, void* state, PrivateTag)
+    : OrtKernelImpl{},  // Initialize all OrtKernelImpl functions to NULL
+      info_{info},
+      data_transfer_impl_{reinterpret_cast<OrtDataTransferImpl*>(state)} {
+  ort_version_supported = ORT_API_VERSION;
+  Compute = ComputeImpl;
+  Release = ReleaseImpl;
+}
+
+/*static*/
+OrtStatus* Shape::Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<OrtKernelImpl>& kernel) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  Ort::ConstKernelInfo kernel_info(info);
+
+  int64_t start = kernel_info.GetAttribute<int64_t>("start");
+  int64_t end = 0;
+  Ort::Status status{Ort::GetApi().KernelInfoGetAttribute_int64(info, "end", &end)};
+
+  // This example kernel does not support shape slicing.
+  RETURN_IF(start != 0 || status.IsOK(), Ort::GetApi(),
+            "Example Shape kernel does not support non-default start/end attributes");
+
+  kernel = std::make_unique<Shape>(info, state, PrivateTag{});
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+/*static*/
+OrtStatus* ORT_API_CALL Shape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  Shape* shape_kernel = static_cast<Shape*>(this_ptr);
+  static_cast<void>(shape_kernel->info_);                // NOTE: Unused in this example.
+  static_cast<void>(shape_kernel->data_transfer_impl_);  // NOTE: Unused in this example.
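+
+  // Shape writes the input's dimensions into a 1-D int64 output of length rank(input);
+  // e.g. a float input of shape [3, 2] yields the int64 tensor [3, 2].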
+
+  Ort::KernelContext kernel_context(kernel_ctx);
+
+  Ort::ConstValue input = kernel_context.GetInput(0);
+  auto type_shape_info = input.GetTensorTypeAndShapeInfo();
+  std::vector<int64_t> input_shape = type_shape_info.GetShape();
+
+  std::vector<int64_t> output_shape = {static_cast<int64_t>(input_shape.size())};
+  Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape);
+  int64_t* output_data = output.GetTensorMutableData<int64_t>();
+
+  for (size_t i = 0; i < input_shape.size(); i++) {
+    output_data[i] = input_shape[i];
+  }
+
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+/*static*/
+void ORT_API_CALL Shape::ReleaseImpl(OrtKernelImpl* this_ptr) noexcept {
+  delete static_cast<Shape*>(this_ptr);
+}
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h
new file mode 100644
index 0000000000000..39b8d4004560e
--- /dev/null
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/shape.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "../../plugin_ep_utils.h"
+
+class Shape : public OrtKernelImpl {
+ private:
+  struct PrivateTag {};
+
+ public:
+  static OrtStatus* Create(const OrtKernelInfo* info, void* state, /*out*/ std::unique_ptr<OrtKernelImpl>& kernel) noexcept;
+  Shape(const OrtKernelInfo* info, void* state, PrivateTag);
+
+  // Static functions assigned to the OrtKernelImpl fields:
+  static OrtStatus* ORT_API_CALL ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept;
+  static void ORT_API_CALL ReleaseImpl(OrtKernelImpl* this_ptr) noexcept;
+
+ private:
+  const OrtKernelInfo* info_;
+  OrtDataTransferImpl* data_transfer_impl_;  // Custom state passed from OrtEp
+};

From 9884cf147b71c8644738dd7c21fdd6808887d1bd Mon Sep 17 00:00:00 2001
From: adrianlizarraga
Date: Mon, 5 Jan 2026 10:18:04 -0800
Subject: [PATCH 2/3] [EP ABI] GetCpuPreferredNodes for kernel-based plugin EPs

---
 cmake/onnxruntime_unittests.cmake             |   2 +
 .../example_plugin_ep_kernel_registry/ep.cc   |  24 +-
 .../get_capability_utils.cc                   | 226 ++++++++++++++++++
 .../get_capability_utils.h                    |  14 ++
 .../kernels/relu.cc                           |  36 ++-
 .../kernels/reshape.cc                        |  12 +-
 .../kernels/utils.h                           |  20 ++
 onnxruntime/test/autoep/test_execution.cc     |  59 +++++
 .../plugin_kernel_ep_cpu_preferred_nodes.onnx | Bin 0 -> 368 bytes
 .../plugin_kernel_ep_cpu_preferred_nodes.py   |  46 ++++
 10 files changed, 418 insertions(+), 21 deletions(-)
 create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc
 create mode 100644 onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h
 create mode 100644 onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.onnx
 create mode 100644 onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.py

diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index 91f6582310072..81286b420103e 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -2100,6 +2100,8 @@ if (onnxruntime_BUILD_SHARED_LIB AND
       "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_lib_entry.cc"
       "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.h"
       "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_factory.cc"
+      "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h"
+
"${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.h" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep.cc" "${TEST_SRC_DIR}/autoep/library/example_plugin_ep_kernel_registry/ep_allocator.h" diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc index 7b939c0685237..8f8ea9ff68110 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/ep.cc @@ -9,9 +9,11 @@ #include #include #include +#include #include #include "ep_factory.h" +#include "get_capability_utils.h" #include "../plugin_ep_utils.h" ExampleKernelEp::ExampleKernelEp(ExampleKernelEpFactory& factory, const OrtLogger& logger) @@ -60,12 +62,19 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons } // Collect candidate nodes that this EP may support. - std::vector candidate_nodes; + std::vector candidate_nodes; for (const auto& node : all_nodes) { std::string op_type = node.GetOperatorType(); - if (op_type == "Relu" || op_type == "Squeeze") { + const OrtKernelDef* kernel_def = nullptr; + RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + + if (kernel_def == nullptr) { + continue; // Does not have a registered kernel for this node. + } + + if (op_type == "Relu" || op_type == "Squeeze" || op_type == "Shape" || op_type == "Reshape") { candidate_nodes.push_back(node); } else if (op_type == "Mul") { std::vector inputs = node.GetInputs(); @@ -86,12 +95,13 @@ OrtStatus* ORT_API_CALL ExampleKernelEp::GetCapabilityImpl(OrtEp* this_ptr, cons } } - // Mark candidate nodes as supported if we have a registered kernel. - for (const auto& node : candidate_nodes) { - const OrtKernelDef* kernel_def = nullptr; - RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_LookUpKernel(graph_support_info, node, &kernel_def)); + // Get subset of candidate nodes that would be better to offload to CPU. + std::unordered_set cpu_nodes; + RETURN_IF_ERROR(GetCpuPreferredNodes(*ort_graph, *graph_support_info, ep->logger_, candidate_nodes, cpu_nodes)); - if (kernel_def != nullptr) { + // Mark candidate nodes as supported. + for (const auto& node : candidate_nodes) { + if (cpu_nodes.count(node) == 0) { RETURN_IF_ERROR(ep->ep_api_.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); } } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc new file mode 100644 index 0000000000000..36bfe59e4c2d6 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc @@ -0,0 +1,226 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "get_capability_utils.h" + +#include +#include +#include +#include + +using NodeId = size_t; +constexpr int64_t kSmallInitializerThreshold = 100; + +constexpr static inline bool MemTypeOnCpuExplicitly(OrtMemType mem_type) { + return mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput; +} + +// Get all output nodes that consume an output from the given node. 
+static OrtStatus* GetOutputNodes(gsl::span<const Ort::ConstValueInfo> node_outputs, std::vector<Ort::ConstNode>& result) {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  std::vector<Ort::ConstNode> output_nodes;
+  output_nodes.reserve(node_outputs.size());  // May end up larger if outputs have multiple consumers.
+
+  // Gather the OrtNode consumers of every output.
+  for (Ort::ConstValueInfo output : node_outputs) {
+    if (output == nullptr) continue;  // Skip missing optional output
+
+    auto consumers_info = output.GetConsumers();
+    for (const auto& consumer : consumers_info) {
+      output_nodes.push_back(consumer.node);
+    }
+  }
+
+  result = std::move(output_nodes);
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
+// Returns nodes that should be assigned to the CPU EP instead of this example EP to avoid costly I/O copies.
+// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
+OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
+                                const OrtLogger& logger, gsl::span<const Ort::ConstNode> tentative_nodes,
+                                /*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes) {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  const OrtApi& ort_api = Ort::GetApi();
+  const OrtEpApi& ep_api = Ort::GetEpApi();
+  Ort::ConstGraph graph{&ort_graph};
+  std::vector<Ort::ConstNode> ordered_nodes = graph.GetNodes();
+
+  if (ordered_nodes.empty()) {
+    return nullptr;
+  }
+
+  std::unordered_map<NodeId, Ort::ConstNode> node_id_to_node;
+  std::unordered_map<NodeId, size_t> node_id_to_order_map;
+  for (size_t i = 0, num_nodes = ordered_nodes.size(); i < num_nodes; i++) {
+    NodeId node_id = ordered_nodes[i].GetId();
+    node_id_to_node[node_id] = ordered_nodes[i];
+    node_id_to_order_map[node_id] = i;
+  }
+
+  // If the comparator returns false, n1 is popped first; if it returns true, n2 is popped first.
+  auto greater_order_comp = [&](const NodeId node_id1, const NodeId node_id2) {
+    return node_id_to_order_map[node_id1] > node_id_to_order_map[node_id2];
+  };
+  std::priority_queue<NodeId, std::vector<NodeId>, decltype(greater_order_comp)> candidates(greater_order_comp);
+  std::unordered_set<const OrtValueInfo*> cpu_output_args;
+
+  std::unordered_set<NodeId> provider_nodes;
+  provider_nodes.reserve(tentative_nodes.size());
+
+  std::unordered_map<NodeId, Ort::ConstKernelDef> node_to_kernel;
+  node_to_kernel.reserve(tentative_nodes.size());
+
+  for (const OrtNode* ort_node : tentative_nodes) {
+    Ort::ConstNode node(ort_node);
+    NodeId node_id = node.GetId();
+
+    provider_nodes.insert(node_id);
+
+    // Every tentative node is expected to have a kernel registered for the target EP.
+    const OrtKernelDef* ort_kernel_def = nullptr;
+    RETURN_IF_ERROR(ep_api.EpGraphSupportInfo_LookUpKernel(&graph_support_info, node, &ort_kernel_def));
+    RETURN_IF(ort_kernel_def == nullptr, ort_api, "Must have a registered kernel definition on the target EP");
+
+    Ort::ConstKernelDef kernel_def(ort_kernel_def);
+    node_to_kernel.insert({node_id, kernel_def});
+
+    // Find all the direct consumers of CPU tensors.
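+    // An output counts as a CPU tensor when this EP's kernel definition explicitly pins it
+    // to CPU memory (e.g. the example Shape kernel's output); every consumer of such an
+    // output becomes an initial candidate for CPU fallback.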
+    std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+    for (size_t out_index = 0; out_index < outputs.size(); out_index++) {
+      Ort::ConstValueInfo output = outputs[out_index];
+      if (output == nullptr) continue;  // Skip missing optional output
+
+      bool is_output_on_cpu = MemTypeOnCpuExplicitly(kernel_def.GetOutputMemType(out_index));
+      if (is_output_on_cpu) {
+        cpu_output_args.insert(output);
+
+        auto consumer_infos = output.GetConsumers();
+        for (const auto& consumer_info : consumer_infos) {
+          candidates.push(consumer_info.node.GetId());
+          ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_INFO, "Candidate for fallback CPU execution: %s\n",
+                       consumer_info.node.GetName().c_str());
+        }
+      }
+    }
+  }
+
+  std::unordered_set<NodeId> visited;
+  visited.reserve(candidates.size());
+
+  std::unordered_set<const OrtNode*> cpu_nodes;
+  cpu_nodes.reserve(candidates.size());
+
+  // The algorithm below tries to identify a subgraph that only depends on CPU tensors.
+  // Usually it is a subgraph that does shape calculations based on a GPU tensor and then reshapes it back.
+  // In detail:
+  // for each candidate, if one of its inputs is a CPU tensor and the non-CPU kernel doesn't mark it as a CPU
+  // input, force the node to CPU to avoid a memory copy and add its outputs to the set of small CPU tensors.
+  while (!candidates.empty()) {
+    NodeId cur = candidates.top();
+    candidates.pop();
+
+    auto p = visited.insert(cur);
+    if (!p.second) {
+      continue;
+    }
+
+    auto node_iter = node_id_to_node.find(cur);
+    RETURN_IF(node_iter == node_id_to_node.end(), ort_api, "Unable to get OrtNode for a given node ID");
+    Ort::ConstNode node = node_iter->second;
+
+    if (provider_nodes.find(cur) == provider_nodes.end()) {
+      // Nodes not in provider_nodes either already have an EP assigned or have no kernel on the target EP.
+      // We assume these nodes will fall back to CPU, so add all direct consumers of their outputs to the candidates.
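+      // (Only nodes that are unassigned or assigned to the CPU EP propagate CPU-tensor
+      // status to their consumers; nodes assigned to another EP stop the traversal here.)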
+      std::string ep_name = node.GetEpName();
+      if (ep_name.empty() || ep_name == "CPUExecutionProvider") {
+        std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+
+        for (Ort::ConstValueInfo output : outputs) {
+          if (output == nullptr) continue;  // Skip missing optional output
+          cpu_output_args.insert(output);
+        }
+
+        std::vector<Ort::ConstNode> output_nodes;
+        RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));
+
+        for (Ort::ConstNode downstream_node : output_nodes) {
+          candidates.push(downstream_node.GetId());
+        }
+      }
+      continue;
+    }
+
+    std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
+    bool place_in_cpu = true;
+
+    for (size_t i = 0; i < inputs.size(); i++) {
+      Ort::ConstValueInfo input = inputs[i];
+      if (input == nullptr) continue;  // Skip missing optional input
+
+      // skip placing on CPU if the data typs is float16 or bfloat16 or
+      // float8e4m3fn, float8e4m3fnuz, floate5m2, floate5m2fnuz or float4e2m1
+      Ort::ConstTypeInfo type_info = input.TypeInfo();
+      auto type_shape_info = type_info.GetTensorTypeAndShapeInfo();
+      auto elem_type = type_shape_info.GetElementType();
+      if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ ||
+          elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1) {
+        place_in_cpu = false;
+        break;
+      }
+
+      bool is_small_initializer = input.IsConstantInitializer() &&
+                                  type_shape_info.GetElementCount() <= kSmallInitializerThreshold;
+
+      // Allow placing on CPU if it's a small initializer or graph input
+      if (is_small_initializer || input.IsRequiredGraphInput() || input.IsOptionalGraphInput()) {
+        continue;
+      }
+
+      // the input is not a CPU tensor
+      if (cpu_output_args.find(input) == cpu_output_args.end()) {
+        place_in_cpu = false;
+        break;
+      }
+
+      // input is a CPU tensor, but it's intended to be consumed as CPU input by the target EP
+      bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetOutputMemType(i));
+      if (is_input_on_cpu) {
+        place_in_cpu = false;
+        break;
+      }
+    }
+
+    if (place_in_cpu) {
+      cpu_nodes.insert(node);
+      ORT_CXX_LOGF(Ort::Logger(&logger), ORT_LOGGING_LEVEL_WARNING,
+                   "EP optimization: Force fallback to CPU execution for node %s because the CPU execution path "
+                   "is deemed faster than overhead involved with execution on other EPs capable of executing "
+                   "this node.\n",
+                   node.GetName().c_str());
+
+      std::vector<Ort::ConstValueInfo> outputs = node.GetOutputs();
+      for (Ort::ConstValueInfo output : outputs) {
+        if (output == nullptr) continue;  // Skip missing optional output
+        cpu_output_args.insert(output);
+      }
+
+      std::vector<Ort::ConstNode> output_nodes;
+      RETURN_IF_ERROR(GetOutputNodes(outputs, output_nodes));
+
+      for (Ort::ConstNode downstream_node : output_nodes) {
+        candidates.push(downstream_node.GetId());
+      }
+    }
+  }
+
+  cpu_preferred_nodes = std::move(cpu_nodes);
+
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h
new file mode 100644
index 0000000000000..59df34c1d8779
--- /dev/null
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.h
@@ -0,0 +1,14 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <unordered_set>
+#include <gsl/span>
+#include "../plugin_ep_utils.h"
+
+// Returns nodes that should be assigned to the CPU EP instead of this example EP to avoid costly I/O copies.
+// Based on GetCpuPreferredNodes from onnxruntime/core/framework/fallback_cpu_capability.cc
+OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo& graph_support_info,
+                                const OrtLogger& logger, gsl::span<const Ort::ConstNode> tentative_nodes,
+                                /*out*/ std::unordered_set<const OrtNode*>& cpu_preferred_nodes);
diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc
index 89f52c4b53dc3..782393595f0ae 100644
--- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc
+++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc
@@ -15,7 +15,8 @@ ONNX_OPERATOR_KERNEL_EX(
     kOnnxDomain,
     /*version*/ 14,  // Equivalent to start_version: 14, end_version: 14 (inclusive)
     (Ort::KernelDefBuilder()
-         .AddTypeConstraint("T", GetTensorType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT))
+         .AddTypeConstraint("T", GetTensorTypes({ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
+                                                 ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64}))
          .AddInputOutputMutableAlias(0, 0)),
     Relu)
 
@@ -36,6 +37,23 @@ OrtStatus* Relu::Create(const OrtKernelInfo* info, void* state, /*out*/ std::uni
   EXCEPTION_TO_RETURNED_STATUS_END
 }
 
+template <typename T>
+static OrtStatus* ApplyRelu(Ort::KernelContext kernel_context) noexcept {
+  EXCEPTION_TO_RETURNED_STATUS_BEGIN
+  gsl::span<const T> input0;
+  std::vector<int64_t> shape0;
+  RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0));
+
+  Ort::UnownedValue output = kernel_context.GetOutput(0, shape0);
+  T* output_data = output.GetTensorMutableData<T>();
+
+  for (size_t i = 0; i < input0.size(); ++i) {
+    output_data[i] = std::max(static_cast<T>(0), input0[i]);
+  }
+  return nullptr;
+  EXCEPTION_TO_RETURNED_STATUS_END
+}
+
 /*static*/
 OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelContext* kernel_ctx) noexcept {
   EXCEPTION_TO_RETURNED_STATUS_BEGIN
@@ -43,15 +61,15 @@ OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelCont
   Ort::KernelContext kernel_context(kernel_ctx);
   static_cast<void>(relu_kernel->info_);  // NOTE: Unused in this example.
- gsl::span input0; - std::vector shape0; - RETURN_IF_ERROR(GetKernelInputDataAndShape(kernel_context, 0, input0, shape0)); + Ort::ConstValue input = kernel_context.GetInput(0); + auto type_shape = input.GetTensorTypeAndShapeInfo(); + auto elem_type = type_shape.GetElementType(); - Ort::UnownedValue output = kernel_context.GetOutput(0, shape0); - float* output_data = output.GetTensorMutableData(); - - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = std::max(0.0f, input0[i]); + if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { + ApplyRelu(kernel_context); + } else { + assert(elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); + ApplyRelu(kernel_context); } return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc index 4bdead1a2ca04..2487ca2ecda79 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/reshape.cc @@ -128,13 +128,15 @@ OrtStatus* ORT_API_CALL Reshape::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelC // Input[1] has the requested shape for the reshape operation. Ort::ConstValue shape_input = kernel_context.GetInput(1); gsl::span shape_input_data; - std::vector final_shape; + std::vector shape_input_shape; - RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, final_shape)); - RETURN_IF(final_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension"); - RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, final_shape)); + RETURN_IF_ERROR(GetValueDataAndShape(shape_input, shape_input_data, shape_input_shape)); + RETURN_IF(shape_input_shape.size() != 1, Ort::GetApi(), "A shape tensor must have one dimension"); - Ort::UnownedValue output = kernel_context.GetOutput(0, final_shape); + std::vector output_shape(shape_input_data.begin(), shape_input_data.end()); + RETURN_IF_ERROR(GetRequestedShape(input_shape, reshape_kernel->allow_zero_, output_shape)); + + Ort::UnownedValue output = kernel_context.GetOutput(0, output_shape); // This kernel aliases the input and output, so a copy is not really necessary. // CopyTensor() will not do a copy if the source and destination buffers are the same. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h index 506392abb6149..3c6a6c9b334a6 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h @@ -3,6 +3,7 @@ #pragma once +#include #include "../../plugin_ep_utils.h" /// @@ -18,6 +19,25 @@ inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) { return result; } +/// +/// Gets OrtDataTypes for the given tensor types. Throws on error. +/// +/// +/// +inline std::vector GetTensorTypes(const std::vector& elem_types) { + const OrtEpApi& ep_api = Ort::GetEpApi(); + std::vector result; + result.reserve(elem_types.size()); + + for (auto elem_type : elem_types) { + const OrtDataType* tensor_data_type = nullptr; + Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &tensor_data_type)); + result.push_back(tensor_data_type); + } + + return result; +} + /// /// Copy a tensor using a OrtDataTransferImpl instance. 
Used by kernel implementations to copy /// tensors that my reside on different devices. diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index bb391bb0bca23..7154aee549380 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -282,5 +282,64 @@ TEST(OrtEpLibrary, KernelPluginEp_Inference) { gsl::span output_span(output_data, 6); EXPECT_THAT(output_span, ::testing::ElementsAre(4, 0, 24, 0, 0, 84)); } + +TEST(OrtEpLibrary, KernelPluginEp_OffloadPreferredCpuNodes) { + RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_kernel_registry_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + { + Ort::SessionOptions session_options; + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL); + session_options.AddConfigEntry(kOrtSessionOptionsDisableCPUEPFallback, "1"); // Fail if any node assigned to CPU EP. + + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + // The example kernel EP supports all operator types in this model. + // However, this model has a subgraph with nodes that should be offloaded to CPU: + // Shape -> [Nodes to offload] -> Reshape + // Expect failure because we also disabled CPU EP fallback via session options. + try { + Ort::Session session(*ort_env, ORT_TSTR("testdata/plugin_kernel_ep_cpu_preferred_nodes.onnx"), session_options); + FAIL(); // Should not get here! + } catch (const Ort::Exception& excpt) { + ASSERT_EQ(excpt.GetOrtErrorCode(), ORT_FAIL); + ASSERT_THAT(excpt.what(), testing::HasSubstr("fallback to CPU EP has been explicitly disabled")); + } + } + + // Allow nodes to fallback to CPU and run inference. + // The example kernel EP will offload a subgraph that processes a shape to CPU. 
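+  // In this model that subgraph is just the Relu between Shape and Reshape: it operates on
+  // a CPU shape tensor, so it is expected to be placed on the CPU EP, while Mul, Shape, and
+  // Reshape remain on the example EP.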
+ { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/plugin_kernel_ep_cpu_preferred_nodes.onnx"), session_options); + // Create inputs + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::array a_shape = {3, 2}; + std::array a_data = {1.f, -2.f, 3.f, 4.f, -5.f, 6.f}; + + std::vector ort_inputs{}; + ort_inputs.emplace_back( + Ort::Value::CreateTensor(memory_info, a_data.data(), a_data.size(), a_shape.data(), a_shape.size())); + + std::array ort_input_names{"A"}; + + // Run session and get outputs + std::array output_names{"B"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2.f, -4.f, 6.f, 8.f, -10.f, 12.f)); + } +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.onnx b/onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a277021bdd9efd658e2192b2b795e839afb9f14a GIT binary patch literal 368 zcmZus!3u&v6x3Do^^%C~V$wx`s9A2H8ED}Ffeas7~X)f zdlB5Ad_XuBdEK0!1;q!;!h$s=Yb>qhuC4eiLk9(i8g&wtai7z>%eA#rE8`XiKAlx@ z#`^H&bGWt|a1mXx7R4ZyJ!@LUeYGBRl%gNq^YB(hW#TxX09nYkDM`})AH46d%aZS; ob2LDvfCLl{a^<2o7Spx((~oP7#vuaeRBh|h)o_7ed1PU}0eAjbnE(I) literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.py b/onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.py new file mode 100644 index 0000000000000..09506b727860e --- /dev/null +++ b/onnxruntime/test/testdata/plugin_kernel_ep_cpu_preferred_nodes.py @@ -0,0 +1,46 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +# A --> Mul --> Shape --> Relu --> Reshape(mul_output) --> B +graph_proto = helper.make_graph( + nodes=[ + helper.make_node( + "Mul", + inputs=["A", "ConstTwo"], + outputs=["mul_output"], + name="mul_0", + ), + helper.make_node( + "Shape", + inputs=["mul_output"], + outputs=["shape_output"], + name="shape_0", + ), + helper.make_node( + "Relu", + inputs=["shape_output"], + outputs=["relu_output"], + name="relu_0", + ), + helper.make_node( + "Reshape", + inputs=["mul_output", "relu_output"], + outputs=["B"], + name="reshape_0", + ), + ], + name="Main_graph", + inputs=[ + helper.make_tensor_value_info("A", TensorProto.FLOAT, [3, 2]), + ], + outputs=[ + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3, 2]), + ], + initializer=[ + helper.make_tensor("ConstTwo", TensorProto.FLOAT, [3, 2], [2.0] * 6), + ], +) + +model = helper.make_model(graph_proto) +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "plugin_kernel_ep_cpu_preferred_nodes.onnx") From 4ac370c351ad5a87cc33f4247ca5af1a80ca56b8 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Mon, 5 Jan 2026 10:55:16 -0800 Subject: [PATCH 3/3] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../get_capability_utils.cc | 6 +++--- .../example_plugin_ep_kernel_registry/kernels/relu.cc | 4 ++-- 
.../example_plugin_ep_kernel_registry/kernels/utils.h | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc index 36bfe59e4c2d6..20db41f3cf44d 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/get_capability_utils.cc @@ -158,8 +158,8 @@ OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo Ort::ConstValueInfo input = inputs[i]; if (input == nullptr) continue; // Skip missing optional input - // skip placing on CPU if the data typs is float16 or bfloat16 or - // float8e4m3fn, float8e4m3fnuz, floate5m2, floate5m2fnuz or float4e2m1 + // skip placing on CPU if the data types is float16 or bfloat16 or + // float8e4m3fn, float8e4m3fnuz, float8e5m2, float8e5m2fnuz or float4e2m1 Ort::ConstTypeInfo type_info = input.TypeInfo(); auto type_shape_info = type_info.GetTensorTypeAndShapeInfo(); auto elem_type = type_shape_info.GetElementType(); @@ -189,7 +189,7 @@ OrtStatus* GetCpuPreferredNodes(const OrtGraph& ort_graph, OrtEpGraphSupportInfo } // input is a CPU tensor, but it's intended to be consumed as CPU input by the target EP - bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetOutputMemType(i)); + bool is_input_on_cpu = MemTypeOnCpuExplicitly(node_to_kernel[cur].GetInputMemType(i)); if (is_input_on_cpu) { place_in_cpu = false; break; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc index 782393595f0ae..92b0e508ecdd6 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/relu.cc @@ -66,10 +66,10 @@ OrtStatus* ORT_API_CALL Relu::ComputeImpl(OrtKernelImpl* this_ptr, OrtKernelCont auto elem_type = type_shape.GetElementType(); if (elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) { - ApplyRelu(kernel_context); + RETURN_IF_ERROR(ApplyRelu(kernel_context)); } else { assert(elem_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64); - ApplyRelu(kernel_context); + RETURN_IF_ERROR(ApplyRelu(kernel_context)); } return nullptr; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h index 3c6a6c9b334a6..28a4e2ceca9f2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep_kernel_registry/kernels/utils.h @@ -40,7 +40,7 @@ inline std::vector GetTensorTypes(const std::vector /// Copy a tensor using a OrtDataTransferImpl instance. Used by kernel implementations to copy -/// tensors that my reside on different devices. +/// tensors that may reside on different devices. /// /// ///