From 8b775c620ffe19a948b99423cfa882e604549c92 Mon Sep 17 00:00:00 2001 From: Aditya Rastogi Date: Tue, 20 Jan 2026 20:31:32 -0800 Subject: [PATCH 1/4] Initial draft --- .../autoep/library/example_plugin_ep/ep.cc | 152 ++++++++++++++---- .../autoep/library/example_plugin_ep/ep.h | 4 + .../library/example_plugin_ep/ep_factory.cc | 73 +++++++++ .../library/example_plugin_ep/ep_factory.h | 17 ++ onnxruntime/test/autoep/test_execution.cc | 108 +++++++++++++ 5 files changed, 327 insertions(+), 27 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 76b2502da5c3c..0223aca410eac 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -139,6 +139,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr + GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -207,11 +208,29 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG } std::vector supported_nodes; + std::vector ep_context_nodes; for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); auto domain = node.GetDomain(); + // Check for EPContext nodes that belong to this EP (from compiled models). + // This is needed to handle loading pre-compiled models with EPContext nodes. 
+ if (op_type == "EPContext" && domain == "com.microsoft") { + // Check if this EPContext node belongs to this EP via the "source" attribute + Ort::ConstOpAttr source_attr; + Ort::Status status = node.GetAttributeByName("source", source_attr); + if (status.IsOK()) { + std::string source_value; + status = source_attr.GetValue(source_value); + if (status.IsOK() && source_value == ep->name_) { + // This EPContext node was created by this EP - collect it for fusion + ep_context_nodes.push_back(node); + } + } + continue; // Don't process further, EPContext is a special case + } + if (op_type == "Mul") { // Check that Mul has inputs/output of type float std::vector inputs = node.GetInputs(); @@ -248,28 +267,45 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG } } - if (supported_nodes.empty()) { - return nullptr; - } - - if (supported_nodes[0].GetOperatorType() == "Mul") { - // Create (optional) fusion options for the supported nodes to fuse. + // Handle EPContext nodes first - these are from loading compiled models + // Each EPContext node is fused individually so it gets its own compiled node + for (const auto& ep_ctx_node : ep_context_nodes) { + std::vector single_node = {ep_ctx_node}; OrtNodeFusionOptions node_fusion_options = {}; node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. 
node_fusion_options.drop_constant_initializers = true; RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), + reinterpret_cast(single_node.data()), + single_node.size(), &node_fusion_options)); - } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { - // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, - // as CustomMul has the concrete kernel implementation. - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } + + // Return early if no supported nodes (but not if we have EPContext nodes) + if (supported_nodes.empty() && ep_context_nodes.empty()) { + return nullptr; + } + + // Handle regular nodes + if (!supported_nodes.empty()) { + if (supported_nodes[0].GetOperatorType() == "Mul") { + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { + // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, + // as CustomMul has the concrete kernel implementation. 
+ RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } } } catch (const Ort::Exception& ex) { @@ -305,22 +341,21 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const std::vector nodes = graph.GetNodes(); if (nodes.size() != 1) { - Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + Ort::Status status("Expected to compile a single node", ORT_EP_FAIL); return status.release(); } auto node_op_type = nodes[0].GetOperatorType(); - if (node_op_type != "Mul") { - Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL); + auto node_domain = nodes[0].GetDomain(); + + // Check if this is an EPContext node (from loading a pre-compiled model) + bool is_ep_context_node = (node_op_type == "EPContext" && node_domain == "com.microsoft"); + + if (node_op_type != "Mul" && !is_ep_context_node) { + Ort::Status status("Expected to compile a Mul node or EPContext node", ORT_EP_FAIL); return status.release(); } - // Now we know we're compiling a single Mul node. Create a computation kernel. - std::vector node_inputs = nodes[0].GetInputs(); - std::array node_input_names; - node_input_names[0] = node_inputs[0].GetName(); - node_input_names[1] = node_inputs[1].GetName(); - Ort::ConstNode fused_node{fused_nodes[0]}; auto ep_name = fused_node.GetEpName(); if (ep_name != ep->name_) { @@ -328,6 +363,38 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const return status.release(); } + // Get input names for the kernel + // For both EPContext and Mul nodes, we use the inner node's inputs from the graph + // Note: EPContext nodes from compiled models may have fewer inputs if constant initializers were dropped + std::array node_input_names = {"", ""}; + std::vector node_inputs = nodes[0].GetInputs(); + + if (is_ep_context_node) { + // This example EP does *not* fully support executing EPContext nodes. 
+ // + // When a model is compiled with this EP, constant initializers may be dropped from the EPContext + // node's inputs. A production EP would serialize initializer data and compiled state into the + // `ep_cache_context` attribute and deserialize it here. This example EP does not do that. + // + // As a result: + // - Session creation with a compiled model will succeed (for metadata access, compatibility testing) + // - Inference may fail at runtime if MulKernel::Compute cannot find expected inputs/initializers + // + // To fully support EPContext execution, deserialize `ep_cache_context` and restore initializers. + for (size_t i = 0; i < node_inputs.size() && i < 2; ++i) { + node_input_names[i] = node_inputs[i].GetName(); + } + } else { + // For Mul nodes during initial compilation, we need exactly 2 inputs + if (node_inputs.size() != 2) { + std::string err_msg = "Mul node should have 2 inputs, got " + std::to_string(node_inputs.size()); + Ort::Status status(err_msg.c_str(), ORT_EP_FAIL); + return status.release(); + } + node_input_names[0] = node_inputs[0].GetName(); + node_input_names[1] = node_inputs[1].GetName(); + } + // Associate the name of the fused node with our MulKernel. auto fused_node_name = fused_node.GetName(); ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, @@ -340,7 +407,8 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const node_compute_infos[0] = node_compute_info.release(); // Create EpContext nodes for the fused nodes we compiled. - if (ep->config_.enable_ep_context) { + // Don't create new EPContext nodes if we're already processing an EPContext node! 
+ if (ep->config_.enable_ep_context && !is_ep_context_node) { assert(ep_context_nodes != nullptr); RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), gsl::span(ep_context_nodes, count))); @@ -521,3 +589,33 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void (void)kernel; // Do nothing for this example. } + +// +// Implementation of GetCompiledModelCompatibilityInfo +// +/*static*/ +const char* ORT_API_CALL ExampleEp::GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr, + const OrtGraph* graph) noexcept { + // Suppress unused parameter warning. The ORT_UNUSED_PARAMETER macro is in internal headers + // (core/common/common.h) which are not available to plugin EPs using only public APIs. + // A real EP would inspect the graph for model-specific compatibility info. + (void)graph; + auto* ep = static_cast(this_ptr); + + // Generate a compatibility string that includes: + // - EP name + // - EP version (from factory) + // - ORT API version + // + // In a real EP, this might include driver versions, hardware IDs, etc. + // The string format is EP-defined and should be parseable by ValidateCompiledModelCompatibilityInfo. 
+ ep->compatibility_info_ = ep->name_ + ";version=" + ep->factory_.GetEpVersionString() + ";ort_api_version=" + + std::to_string(ORT_API_VERSION); + + IGNORE_ORTSTATUS(ep->ort_api.Logger_LogMessage(&ep->logger_, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + ("GetCompiledModelCompatibilityInfo returning: " + ep->compatibility_info_).c_str(), + ORT_FILE, __LINE__, __FUNCTION__)); + + return ep->compatibility_info_.c_str(); +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 5d4788ed76bf2..b70a58e2783ab 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -78,6 +78,9 @@ class ExampleEp : public OrtEp, public ApiPtrs { OrtNodeComputeInfo** node_compute_infos, size_t num_node_compute_infos) noexcept; + static const char* ORT_API_CALL GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr, + const OrtGraph* graph) noexcept; + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); @@ -89,4 +92,5 @@ class ExampleEp : public OrtEp, public ApiPtrs { const OrtLogger& logger_; std::unordered_map> kernels_; std::unordered_map float_initializers_; + std::string compatibility_info_; // Cached compatibility string returned by GetCompiledModelCompatibilityInfo }; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index c56f0f74ab74a..10bb5e054e2cf 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -42,6 +42,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; GetCustomOpDomains = GetCustomOpDomainsImpl; + ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl; // setup the OrtMemoryInfo 
instances required by the EP. // We pretend the device the EP is running on is GPU. @@ -417,3 +418,75 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateExternalResourceImporterForDevic return nullptr; } + +OrtStatus* ORT_API_CALL ExampleEpFactory::ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* /*devices*/, + size_t /*num_devices*/, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + auto& factory = *static_cast(this_ptr); + + if (model_compatibility == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "model_compatibility cannot be nullptr"); + } + + // Parse the compatibility info to check if it matches our current configuration. + // The expected format is "example_ep;version=0.1.0;ort_api_version=24" (EP name, EP version, ORT API version). + // This example checks the EP name prefix, then the EP version and (if present) the ORT API version. + + if (compatibility_info == nullptr || compatibility_info[0] == '\0') { + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + std::string info(compatibility_info); + std::string expected_prefix = factory.ep_name_ + ";"; + + if (info.find(expected_prefix) != 0) { + // The compatibility info doesn't match our EP + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Parse version parts: "EP_NAME;version=X;ort_api_version=Y" + // Look for "version=" and extract the value + size_t version_pos = info.find("version="); + size_t ort_version_pos = info.find("ort_api_version="); + + if (version_pos == std::string::npos) { + // Invalid format + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Extract EP version (between "version=" and the next ";") + size_t version_start = version_pos + 8; // length of "version=" + size_t version_end = info.find(';', version_start); + std::string ep_version = (version_end != 
std::string::npos) + ? info.substr(version_start, version_end - version_start) + : info.substr(version_start); + + // Check if the EP version matches our version + if (ep_version != factory.ep_version_) { + // Different EP version - might work but prefer recompilation + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + return nullptr; + } + + // Check ORT API version if present + if (ort_version_pos != std::string::npos) { + size_t ort_version_start = ort_version_pos + 16; // length of "ort_api_version=" + std::string ort_version = info.substr(ort_version_start); + std::string current_ort_version = std::to_string(ORT_API_VERSION); + if (ort_version != current_ort_version) { + // Different ORT version - might still work but prefer recompilation + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + return nullptr; + } + } + + // Everything matches - the compiled model is fully compatible + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 244051dd5e4d0..91478047afb0a 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -28,6 +28,16 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + // Get the EP version string. + const std::string& GetEpVersionString() const { + return ep_version_; + } + + // Get the vendor ID. 
+ uint32_t GetVendorIdValue() const { + return vendor_id_; + } + const OrtLogger& default_logger_; // default logger for the EP factory private: @@ -89,6 +99,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { _Outptr_result_maybenull_ OrtCustomOpDomain** domains, _Out_ size_t num_domains) noexcept; + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index a3cca42d81c6e..d3668b198fe56 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -10,6 +10,7 @@ #include "core/graph/constants.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test/autoep/test_autoep_utils.h" #include "test/shared_lib/utils.h" @@ -429,6 +430,113 @@ TEST(OrtEpLibrary, PluginEp_VirtGpu_GenEpContextModel) { } } +// Test that compatibility info is written to compiled model metadata +TEST(OrtEpLibrary, PluginEp_CompatibilityInfo_WrittenToMetadata) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_compat_test.onnx"); + std::filesystem::remove(output_model_file); + + // Compile the model + { + Ort::SessionOptions session_options; + std::unordered_map 
ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } + + // Load the compiled model and check metadata for compatibility info + { + Ort::SessionOptions session_options; + // Need to add the EP to handle EPContext nodes + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, output_model_file, session_options); + Ort::AllocatorWithDefaultOptions allocator; + + // Check that the model has EP compatibility metadata + Ort::ModelMetadata metadata = session.GetModelMetadata(); + auto custom_metadata_keys = metadata.GetCustomMetadataMapKeysAllocated(allocator); + + // Check for the exact metadata key for this EP: "ep_compatibility_info.example_ep" + const std::string expected_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + "example_ep"; + + bool found_compatibility_key = false; + for (const auto& key : custom_metadata_keys) { + std::string key_str(key.get()); + if (key_str == expected_key) { + found_compatibility_key = true; + break; + } + } + ASSERT_TRUE(found_compatibility_key) << "Expected metadata key '" << expected_key << "' in compiled model"; + + // Verify the compatibility value contains expected EP information + auto value = metadata.LookupCustomMetadataMapAllocated(expected_key.c_str(), allocator); + ASSERT_NE(value.get(), nullptr); + std::string compatibility_value = value.get(); + ASSERT_GT(compatibility_value.length(), 0) << "Compatibility info should not be empty"; + + // Validate the exact 
compatibility string format and values + // Format: "example_ep;version=0.1.0;ort_api_version=N", where N is ORT_API_VERSION + std::string expected_compatibility_info = "example_ep;version=0.1.0;ort_api_version=" + + std::to_string(ORT_API_VERSION); + EXPECT_EQ(compatibility_value, expected_compatibility_info); + } + + std::filesystem::remove(output_model_file); +} + +// Test loading a compiled model validates compatibility successfully +TEST(OrtEpLibrary, PluginEp_CompatibilityInfo_ValidatedOnLoad) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_compat_validate.onnx"); + std::filesystem::remove(compiled_model_file); + + // Step 1: Compile the model + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + } + + // Step 2: Load the compiled model with the same EP - should succeed + // The EP should validate compatibility and return OPTIMAL status + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + // This should not throw - EP should validate compatibility as OPTIMAL + ASSERT_NO_THROW(Ort::Session session(*ort_env, compiled_model_file, session_options)); + } + + 
std::filesystem::remove(compiled_model_file); +} + // Uses the original compiling approach with session option configs (instead of explicit compile API). // Test that ORT does not overwrite an output model if it already exists. // Notably, this tests the case in which ORT automatically generates the output model name. From 31714581da9da6d2b4c8466a1c7de5bdfbbb9efa Mon Sep 17 00:00:00 2001 From: adrastogi Date: Wed, 21 Jan 2026 13:47:39 -0800 Subject: [PATCH 2/4] Update onnxruntime/test/autoep/library/example_plugin_ep/ep.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/autoep/library/example_plugin_ep/ep.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 0223aca410eac..a0a286902b3fb 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -137,8 +137,8 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C GetCapability = GetCapabilityImpl; Compile = CompileImpl; ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; - CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr - CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr + CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. 
can be nullptr GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_, From 113690ecbf101a41b5dbe429181c669efedc74ba Mon Sep 17 00:00:00 2001 From: Aditya Rastogi Date: Mon, 2 Feb 2026 22:26:51 -0800 Subject: [PATCH 3/4] PR feedback --- .../autoep/library/example_plugin_ep/ep.cc | 245 ++++++++++++------ .../autoep/library/example_plugin_ep/ep.h | 28 +- onnxruntime/test/autoep/test_execution.cc | 118 +++++++++ 3 files changed, 303 insertions(+), 88 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 0223aca410eac..cd8555ca0af6e 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -107,10 +107,35 @@ OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { return nullptr; } +OrtStatus* EpContextKernel::Compute(OrtKernelContext* /*kernel_ctx*/) { + // This example EP does not fully support EPContext inference. + // A production EP would: + // 1. Deserialize state from ep_cache_context attribute during Compile + // 2. Use that state here to perform actual computation + // + // Session creation succeeds for metadata access and compatibility testing, + // but inference requires deserializing ep_cache_context (not implemented). + return ort_api.CreateStatus( + ORT_NOT_IMPLEMENTED, + "EPContext inference is not fully implemented in this example EP. " + "Session creation succeeds for metadata access and compatibility testing, " + "but inference requires deserializing ep_cache_context (not implemented). " + "A production EP would restore compiled state from the EPContext node's attributes."); +} + +/// +/// Intermediate base class with virtual destructor for proper polymorphic deletion. 
+/// This allows ReleaseNodeComputeInfosImpl to delete any derived type correctly +/// without manual type dispatch. +/// +struct NodeComputeInfoBase : OrtNodeComputeInfo { + virtual ~NodeComputeInfoBase() = default; +}; + /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. /// -struct ExampleNodeComputeInfo : OrtNodeComputeInfo { +struct ExampleNodeComputeInfo : NodeComputeInfoBase { explicit ExampleNodeComputeInfo(ExampleEp& ep); static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, @@ -123,6 +148,22 @@ struct ExampleNodeComputeInfo : OrtNodeComputeInfo { ExampleEp& ep; }; +/// +/// OrtNodeComputeInfo for EPContext nodes - delegates to EpContextKernel. +/// +struct EpContextNodeComputeInfo : NodeComputeInfoBase { + explicit EpContextNodeComputeInfo(ExampleEp& ep); + + static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state); + static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context); + static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state); + + ExampleEp& ep; +}; + ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger) : OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized ApiPtrs{static_cast(factory)}, @@ -207,8 +248,9 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; // No nodes to process } + // Single array for all supported node types. + // This EP only supports compiling one node at a time (a documented limitation). 
std::vector supported_nodes; - std::vector ep_context_nodes; for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); @@ -224,8 +266,9 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG std::string source_value; status = source_attr.GetValue(source_value); if (status.IsOK() && source_value == ep->name_) { - // This EPContext node was created by this EP - collect it for fusion - ep_context_nodes.push_back(node); + // This EPContext node was created by this EP - add to supported nodes + supported_nodes.push_back(node); + break; // Only support one node at a time } } continue; // Don't process further, EPContext is a special case @@ -260,52 +303,42 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG } } - supported_nodes.push_back(node); // Only support a single Mul for now. - break; + supported_nodes.push_back(node); + break; // Only support a single Mul for now. } else if (op_type == "Custom_Mul" && domain == "test") { supported_nodes.push_back(node); + break; // Only support one node at a time (consistent with Mul/EPContext handling). 
} } - // Handle EPContext nodes first - these are from loading compiled models - // Each EPContext node is fused individually so it gets its own compiled node - for (const auto& ep_ctx_node : ep_context_nodes) { - std::vector single_node = {ep_ctx_node}; - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(single_node.data()), - single_node.size(), - &node_fusion_options)); - } - - // Return early if no supported nodes (but not if we have EPContext nodes) - if (supported_nodes.empty() && ep_context_nodes.empty()) { + // Return early if no supported nodes + if (supported_nodes.empty()) { return nullptr; } - // Handle regular nodes - if (!supported_nodes.empty()) { - if (supported_nodes[0].GetOperatorType() == "Mul") { - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), - &node_fusion_options)); - } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { - // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, - // as CustomMul has the concrete kernel implementation. 
- RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); - } + // Unified dispatch based on node type + const auto& node = supported_nodes[0]; + auto op_type = node.GetOperatorType(); + + if (op_type == "Custom_Mul") { + // Custom_Mul has concrete kernel implementation - no fusion needed. + // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled. + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, node)); + } else { + // Both EPContext and Mul use AddNodesToFuse + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse( + graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); } } catch (const Ort::Exception& ex) { @@ -351,6 +384,18 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // Check if this is an EPContext node (from loading a pre-compiled model) bool is_ep_context_node = (node_op_type == "EPContext" && node_domain == "com.microsoft"); + // Validate configuration: cannot enable EPContext generation when loading a compiled model. + // This is a configuration error - you cannot re-compile an already compiled model. + if (ep->config_.enable_ep_context && is_ep_context_node) { + Ort::Status status( + "Invalid configuration: 'enable_ep_context' is true but model already contains " + "EPContext nodes. 
Cannot re-compile an already compiled model. Either:\n" + " 1. Use the original (uncompiled) model as input, or\n" + " 2. Disable ep_context generation when loading a compiled model.", + ORT_INVALID_ARGUMENT); + return status.release(); + } + if (node_op_type != "Mul" && !is_ep_context_node) { Ort::Status status("Expected to compile a Mul node or EPContext node", ORT_EP_FAIL); return status.release(); @@ -363,55 +408,42 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const return status.release(); } - // Get input names for the kernel - // For both EPContext and Mul nodes, we use the inner node's inputs from the graph - // Note: EPContext nodes from compiled models may have fewer inputs if constant initializers were dropped - std::array node_input_names = {"", ""}; - std::vector node_inputs = nodes[0].GetInputs(); + auto fused_node_name = fused_node.GetName(); if (is_ep_context_node) { - // This example EP does *not* fully support executing EPContext nodes. - // - // When a model is compiled with this EP, constant initializers may be dropped from the EPContext - // node's inputs. A production EP would serialize initializer data and compiled state into the - // `ep_cache_context` attribute and deserialize it here. This example EP does not do that. - // - // As a result: - // - Session creation with a compiled model will succeed (for metadata access, compatibility testing) - // - Inference may fail at runtime if MulKernel::Compute cannot find expected inputs/initializers - // - // To fully support EPContext execution, deserialize `ep_cache_context` and restore initializers. 
- for (size_t i = 0; i < node_inputs.size() && i < 2; ++i) { - node_input_names[i] = node_inputs[i].GetName(); - } + // Create EpContextKernel for EPContext nodes - clearly separates from MulKernel + ep->ep_context_kernels_.emplace(fused_node_name, + std::make_unique(ep->ort_api, ep->logger_)); + + // Use EpContextNodeComputeInfo for EPContext nodes + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); } else { // For Mul nodes during initial compilation, we need exactly 2 inputs + std::vector node_inputs = nodes[0].GetInputs(); if (node_inputs.size() != 2) { std::string err_msg = "Mul node should have 2 inputs, got " + std::to_string(node_inputs.size()); Ort::Status status(err_msg.c_str(), ORT_EP_FAIL); return status.release(); } - node_input_names[0] = node_inputs[0].GetName(); - node_input_names[1] = node_inputs[1].GetName(); - } - // Associate the name of the fused node with our MulKernel. - auto fused_node_name = fused_node.GetName(); - ep->kernels_.emplace(std::move(fused_node_name), std::make_unique(ep->ort_api, ep->logger_, - ep->float_initializers_, - node_input_names[0], - node_input_names[1])); - - // Update the OrtNodeComputeInfo associated with the graph. - auto node_compute_info = std::make_unique(*ep); - node_compute_infos[0] = node_compute_info.release(); - - // Create EpContext nodes for the fused nodes we compiled. - // Don't create new EPContext nodes if we're already processing an EPContext node! 
- if (ep->config_.enable_ep_context && !is_ep_context_node) { - assert(ep_context_nodes != nullptr); - RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), - gsl::span(ep_context_nodes, count))); + // Create MulKernel for Mul nodes + ep->mul_kernels_.emplace(fused_node_name, + std::make_unique(ep->ort_api, ep->logger_, + ep->float_initializers_, + node_inputs[0].GetName(), + node_inputs[1].GetName())); + + // Use ExampleNodeComputeInfo for Mul nodes + auto node_compute_info = std::make_unique(*ep); + node_compute_infos[0] = node_compute_info.release(); + + // Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext). + if (ep->config_.enable_ep_context) { + assert(ep_context_nodes != nullptr); + RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span(fused_nodes, count), + gsl::span(ep_context_nodes, count))); + } } } catch (const Ort::Exception& ex) { Ort::Status status(ex); @@ -430,7 +462,9 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, size_t num_node_compute_infos) noexcept { (void)this_ptr; for (size_t i = 0; i < num_node_compute_infos; i++) { - delete static_cast(node_compute_infos[i]); + // All node compute info types derive from NodeComputeInfoBase which has a virtual destructor. + // This ensures correct polymorphic deletion without manual type dispatch. 
+ delete static_cast(node_compute_infos[i]); } } @@ -565,9 +599,9 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, ExampleEp& ep = node_compute_info->ep; std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); - auto kernel_it = ep.Kernels().find(fused_node_name); - if (kernel_it == ep.Kernels().end()) { - std::string message = "Unable to get kernel for fused node with name " + fused_node_name; + auto kernel_it = ep.MulKernels().find(fused_node_name); + if (kernel_it == ep.MulKernels().end()) { + std::string message = "Unable to get MulKernel for fused node with name " + fused_node_name; return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); } @@ -590,6 +624,47 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void // Do nothing for this example. } +// +// Implementation of EpContextNodeComputeInfo +// +EpContextNodeComputeInfo::EpContextNodeComputeInfo(ExampleEp& ep) : ep(ep) { + ort_version_supported = ORT_API_VERSION; + CreateState = CreateStateImpl; + Compute = ComputeImpl; + ReleaseState = ReleaseStateImpl; +} + +OrtStatus* EpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr, + OrtNodeComputeContext* compute_context, + void** compute_state) { + auto* node_compute_info = static_cast(this_ptr); + ExampleEp& ep = node_compute_info->ep; + + std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context); + auto kernel_it = ep.EpContextKernels().find(fused_node_name); + if (kernel_it == ep.EpContextKernels().end()) { + std::string message = "Unable to get EpContextKernel for fused node with name " + fused_node_name; + return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str()); + } + + EpContextKernel& kernel = *kernel_it->second; + *compute_state = &kernel; + return nullptr; +} + +OrtStatus* EpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state, + OrtKernelContext* kernel_context) { + 
(void)this_ptr; + EpContextKernel& kernel = *reinterpret_cast(compute_state); + return kernel.Compute(kernel_context); +} + +void EpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) { + (void)this_ptr; + (void)compute_state; + // Do nothing for this example. +} + // // Implementation of GetCompiledModelCompatibilityInfo // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index b70a58e2783ab..3a12860921949 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -37,6 +37,23 @@ struct MulKernel { std::string input1_name; }; +/// +/// Kernel for EPContext nodes loaded from compiled models. +/// +/// This example EP does not support EPContext inference - Compute() returns NOT_IMPLEMENTED. +/// A production EP would deserialize the ep_cache_context attribute and restore compiled state. +/// This kernel exists to clearly separate EPContext handling from MulKernel. +/// +struct EpContextKernel { + EpContextKernel(const OrtApi& ort_api, const OrtLogger& logger) + : ort_api(ort_api), logger(logger) {} + + OrtStatus* Compute(OrtKernelContext* kernel_ctx); + + const OrtApi& ort_api; + const OrtLogger& logger; +}; + /// /// Example EP that can compile a single Mul operator. 
/// @@ -51,8 +68,12 @@ class ExampleEp : public OrtEp, public ApiPtrs { ~ExampleEp(); - std::unordered_map>& Kernels() { - return kernels_; + std::unordered_map>& MulKernels() { + return mul_kernels_; + } + + std::unordered_map>& EpContextKernels() { + return ep_context_kernels_; } private: @@ -90,7 +111,8 @@ class ExampleEp : public OrtEp, public ApiPtrs { std::string name_; Config config_{}; const OrtLogger& logger_; - std::unordered_map> kernels_; + std::unordered_map> mul_kernels_; + std::unordered_map> ep_context_kernels_; std::unordered_map float_initializers_; std::string compatibility_info_; // Cached compatibility string returned by GetCompiledModelCompatibilityInfo }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index d3668b198fe56..00bd5fdea9fce 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -537,6 +537,124 @@ TEST(OrtEpLibrary, PluginEp_CompatibilityInfo_ValidatedOnLoad) { std::filesystem::remove(compiled_model_file); } +// Test that loading a compiled model with ep_context_enable=1 returns an error. +// This is an invalid configuration: the user is asking to generate EP context from a model +// that already contains EPContext nodes. 
+TEST(OrtEpLibrary, PluginEp_Error_LoadCompiledModelWithEpContextEnabled) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_recompile_test.onnx"); + std::filesystem::remove(compiled_model_file); + + // Step 1: Compile the original model (CompileModel API implicitly generates EPContext nodes) + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + } + + // Step 2: Attempt to load the compiled model with ep.context_enable=1 - should fail + { + Ort::SessionOptions session_options; + session_options.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); // Request EP context generation + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + // Loading a compiled model with ep_context_enable=1 should fail + try { + Ort::Session session(*ort_env, compiled_model_file, session_options); + FAIL() << "Expected error when loading compiled model with ep_context_enable=1"; + } catch (const Ort::Exception& e) { + std::string error_msg = e.what(); + // Verify error message mentions the issue + EXPECT_TRUE(error_msg.find("EPContext") != std::string::npos || + error_msg.find("already") != std::string::npos || + 
error_msg.find("re-compile") != std::string::npos) + << "Error should mention EPContext or re-compilation: " << error_msg; + } + } + + std::filesystem::remove(compiled_model_file); +} + +// Test that EPContext inference returns expected "not implemented" error. +// This documents that the example EP does not fully support EPContext execution. +TEST(OrtEpLibrary, PluginEp_EpContextInference_NotImplemented) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_inference_test.onnx"); + std::filesystem::remove(compiled_model_file); + + // Step 1: Compile the model with EP context enabled + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(compiled_model_file); + + ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options)); + ASSERT_TRUE(std::filesystem::exists(compiled_model_file)); + } + + // Step 2: Load compiled model and attempt inference - should fail with clear error + { + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + Ort::Session session(*ort_env, compiled_model_file, session_options); + + // Prepare inputs - mul_1.onnx has input X of shape [3,2] + std::vector input_x = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector input_shape = {3, 2}; + + Ort::MemoryInfo memory_info = 
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + Ort::Value input_x_tensor = Ort::Value::CreateTensor( + memory_info, input_x.data(), input_x.size(), + input_shape.data(), input_shape.size()); + + const char* input_names[] = {"X"}; + const char* output_names[] = {"Y"}; + std::vector input_tensors; + input_tensors.push_back(std::move(input_x_tensor)); + + // Inference should fail with NOT_IMPLEMENTED - verify exception content + try { + auto outputs = session.Run(Ort::RunOptions{nullptr}, + input_names, input_tensors.data(), input_tensors.size(), + output_names, 1); + FAIL() << "Expected exception for EPContext inference, but Run() succeeded"; + } catch (const Ort::Exception& e) { + std::string msg = e.what(); + // Verify error mentions the limitation + EXPECT_TRUE(msg.find("not implemented") != std::string::npos || + msg.find("NOT_IMPLEMENTED") != std::string::npos || + msg.find("EPContext") != std::string::npos) + << "Expected NOT_IMPLEMENTED or EPContext in error, got: " << msg; + } + } + + std::filesystem::remove(compiled_model_file); +} + // Uses the original compiling approach with session option configs (instead of explicit compile API). // Test that ORT does not overwrite an output model if it already exists. // Notably, this tests the case in which ORT automatically generates the output model name. 
From 1b34b1f9856b91cb611aa073c4baf6634eb7a25f Mon Sep 17 00:00:00 2001 From: adrastogi Date: Wed, 4 Feb 2026 14:24:19 -0800 Subject: [PATCH 4/4] Update onnxruntime/test/autoep/library/example_plugin_ep/ep.cc Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- onnxruntime/test/autoep/library/example_plugin_ep/ep.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index 277f33d284662..3537a916310f8 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -430,9 +430,9 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const // Create MulKernel for Mul nodes ep->mul_kernels_.emplace(fused_node_name, std::make_unique(ep->ort_api, ep->logger_, - ep->float_initializers_, - node_inputs[0].GetName(), - node_inputs[1].GetName())); + ep->float_initializers_, + node_inputs[0].GetName(), + node_inputs[1].GetName())); // Use ExampleNodeComputeInfo for Mul nodes auto node_compute_info = std::make_unique(*ep);