From 25c1567b530cff0dcb062df513c5323ab94ade5c Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Sun, 15 Mar 2026 14:20:20 -0700 Subject: [PATCH] Plugin EP: Fix bug that incorrectly assigned duplicate MetaDef IDs to fused nodes that live in different GraphViews (e.g., different branches of an If node) --- .../ep_plugin_provider_interfaces.cc | 6 +- .../plugin_ep/ep_plugin_provider_interfaces.h | 8 +++ .../autoep/library/example_plugin_ep/ep.cc | 69 +++++++++---------- .../autoep/library/example_plugin_ep/ep.h | 7 +- onnxruntime/test/autoep/test_execution.cc | 31 +++++++++ 5 files changed, 78 insertions(+), 43 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index f8cba9435a6bc..25195afa8cfd4 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -11,7 +11,6 @@ #include #include "core/framework/compute_capability.h" #include "core/framework/error_code_helper.h" -#include "core/framework/model_metadef_id_generator.h" #include "core/framework/plugin_data_transfer.h" #include "core/framework/plugin_ep_stream.h" #include "core/graph/ep_api_types.h" @@ -227,8 +226,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie return {}; } - ModelMetadefIdGenerator generator; - // Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances. for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) { // Skip this node grouping if any node has already been assigned to another EP. @@ -278,8 +275,9 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie // TODO(adrianlizarraga): Do not use the heavy-weight CreateSupportedPartitions just to check if the user // provided a single partition. 
Use utils::MakeCapability() and create a new helper to check that there are no // unsupported nodes in any path between supported nodes. + auto metadef_gen_functor = PluginEpMetaDefNameFunctor(metadef_id_generator_, graph_viewer, this->Type()); std::vector> capabilities = utils::CreateSupportedPartitions( - graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()), + graph_viewer, node_set, /*stop_ops*/ {}, std::move(metadef_gen_functor), this->Type(), this->Type(), /*node_unit_map*/ nullptr, node_grouping.fusion_options.drop_constant_initializers); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 8d94607cdace8..815de95862951 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/inlined_containers.h" #include "core/framework/execution_provider.h" +#include "core/framework/model_metadef_id_generator.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" @@ -160,6 +161,13 @@ class PluginExecutionProvider : public IExecutionProvider { // so that it is not destroyed until the EP itself is destroyed. std::vector fused_node_states_; + // Generates a model's hash and a monotonically increasing ID that is unique per model hash. The + // ID is used in the MetaDef name for a fused node containing a compiling EP's supported subgraph. + // + // The same generator instance must be used across calls to GetCapability() to ensure that fused nodes that live in + // different GraphViews (e.g., different branches of an If node) obtain a unique ID. + ModelMetadefIdGenerator metadef_id_generator_; + // Stores the EPContext Nodes created from the OrtNode instances returned by the underlying plugin EP. 
// Need to store both the Node and NodeArg instances so that they are available when the GraphPartitioner // calls IExecutionProvider::GetEpContextNodes(). diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 518cd9b9c5ae3..09d0623b86ef1 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -196,44 +196,33 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept return ep->name_.c_str(); } -OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* ort_graph) { - Ort::ConstGraph graph{ort_graph}; - - try { - std::vector initializers = graph.GetInitializers(); +bool ExampleEp::CopiesConstantInitializers() const { + return !(config_.enable_ep_context && config_.enable_weightless_ep_context_nodes); +} - for (const auto& initializer : initializers) { - const bool is_constant = initializer.IsConstantInitializer(); +OrtStatus* ExampleEp::TrySaveConstantInitializer(Ort::ConstValueInfo maybe_initializer) { + EXCEPTION_TO_RETURNED_STATUS_BEGIN + const bool is_constant = maybe_initializer.IsConstantInitializer(); - if (is_constant) { - auto name = initializer.GetName(); - Ort::ConstValue value; - auto status = initializer.GetInitializer(value); - if (!status.IsOK()) - return status.release(); + if (is_constant) { + auto name = maybe_initializer.GetName(); + Ort::ConstValue value; + RETURN_IF_ERROR(maybe_initializer.GetInitializer(value)); - auto type_shape = value.GetTensorTypeAndShapeInfo(); - const size_t num_elems = type_shape.GetElementCount(); - const ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release(); + auto type_shape = value.GetTensorTypeAndShapeInfo(); + const size_t num_elems = type_shape.GetElementCount(); + const 
ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release(); - std::vector dims = type_shape.GetShape(); - const float* data = value.GetTensorData(); + std::vector dims = type_shape.GetShape(); + const float* data = value.GetTensorData(); - FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; - float_initializers_.emplace(std::move(name), std::move(ep_initializer)); - } - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); + FloatInitializer ep_initializer = {std::move(dims), std::vector(data, data + num_elems)}; + float_initializers_.emplace(std::move(name), std::move(ep_initializer)); } - return nullptr; + EXCEPTION_TO_RETURNED_STATUS_END } /*static*/ @@ -342,8 +331,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG // Refer to the "ep.enable_weightless_ep_context_nodes" // session configuration entry in onnxruntime_session_options_config_keys.h for more information about generating // weightless EPContext models. - node_fusion_options.drop_constant_initializers = !(ep->config_.enable_ep_context && - ep->config_.enable_weightless_ep_context_nodes); + node_fusion_options.drop_constant_initializers = ep->CopiesConstantInitializers(); RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse( graph_support_info, reinterpret_cast(supported_nodes.data()), @@ -377,11 +365,6 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const Ort::ConstGraph graph{ort_graphs[0]}; - // In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference. 
- // So, this EP saves constant initializers so that they're available during inference, but an actual EP - // implementation could transfer the weights to device memory. - ep->SaveConstantInitializers(graph); - std::vector nodes = graph.GetNodes(); if (nodes.size() != 1) { Ort::Status status("Expected to compile a single node", ORT_EP_FAIL); @@ -437,6 +420,16 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const return status.release(); } + // In GetCapability(), this EP may have specified that it doesn't need ORT to provide constant initializers + // during inference. If so, this EP saves copies of constant initializers so they're available during inference. + // + // We try to save each node input individually because graph.GetInitializers() does not return + // initializers defined in parent or sibling subgraphs. + if (ep->CopiesConstantInitializers()) { + RETURN_IF_ERROR(ep->TrySaveConstantInitializer(node_inputs[0])); + RETURN_IF_ERROR(ep->TrySaveConstantInitializer(node_inputs[1])); + } + // Create MulKernel for Mul nodes ep->mul_kernels_.emplace(fused_node_name, std::make_unique(ep->ort_api, ep->logger_, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 01a5c72cf2b44..c83e99b2ca604 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -106,7 +106,12 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); - OrtStatus* SaveConstantInitializers(const OrtGraph* graph); + // Returns true if the EP should save constant initializers so that they are available during inference. 
+ bool CopiesConstantInitializers() const; + + // If the given `OrtValueInfo` represents a constant initializer, this function saves a copy of the initializer data + // within this EP instance so that it is available during inference. + OrtStatus* TrySaveConstantInitializer(Ort::ConstValueInfo maybe_initializer); ExampleEpFactory& factory_; std::string name_; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 5f5836bad15d0..f3ae95961f505 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -1060,5 +1060,36 @@ TEST(OrtEpLibrary, PluginEp_GpuDevice_ReturnsInCompatible) { api->ReleaseDeviceEpIncompatibilityDetails(details); } + +TEST(OrtEpLibrary, CompilingPluginEp_MultiSubgraphs_DuplicateMetaDefIdBug) { + // Test a fix to a bug that incorrectly assigned duplicate MetaDef IDs to fused nodes + // that live in different GraphViews (e.g., in different branches of an If node). + // + // The test model graph does the following computation: + // if (A) { C = Mul(B, 2.0) } + // else { C = Mul(B, 3) } + // return C + // + // The example plugin EP should support and execute both Mul nodes (as compiled fused nodes). + // However, the bug (in PluginExecutionProvider::GetCapability) assigned the same MetaDef ID + // to both compiled Mul nodes, which caused session creation to fail with error: + // + // > Failed to add kernel for example_ep_9433721956998717990_0 example_ep example_ep: + // Conflicting with a registered kernel with op versions. the since version is: 1 + // + // The fix was to use the same instance of `ModelMetadefIdGenerator` across all calls to + // PluginExecutionProvider::GetCapability(). This ensures that the MetaDef IDs are unique. 
+ RegisteredEpDeviceUniquePtr example_kernel_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, + example_kernel_ep)); + Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get()); + + std::unordered_map ep_options; + Ort::SessionOptions session_options; + + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ true)); +} + } // namespace test } // namespace onnxruntime