Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions onnxruntime/core/framework/session_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1257,10 +1257,41 @@ using NodePlacementSet = std::unordered_set<std::string>;

static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_verbose,
NodePlacementMap& node_placements,
NodePlacementSet& node_placement_provider_set) {
NodePlacementSet& node_placement_provider_set,
const ExecutionProviders& providers) {
for (const auto& node : graph.Nodes()) {
const auto& node_provider = node.GetExecutionProviderType();
if (node_provider.empty()) {
// Provide a more descriptive error for EPContext nodes that were not assigned to an EP.
if (node.OpType() == "EPContext") {
// Get information about who generated the EPContext node from the 'source' attribute.
// Commonly, 'source' will be the name of the EP that generated the node, but that is not required.
// An EP may choose to use a different source identifier.
std::string source = "(unknown)";
const auto& attrs = node.GetAttributes();
auto it = attrs.find("source");

if (it != attrs.end() && it->second.has_s()) {
source = it->second.s();
}

const auto& ep_ids = providers.GetIds();
std::ostringstream session_ep_names;

for (size_t i = 0; i < ep_ids.size(); ++i) {
if (i > 0) {
session_ep_names << ", ";
}
session_ep_names << ep_ids[i];
}

return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"EPContext node generated by '", source, "' is not ",
"compatible with any execution provider added to the session. ",
"EPContext node name: '", node.Name(), "'. Available session execution providers: [",
session_ep_names.str(), "].");
}

return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"Could not find an implementation for ",
node.OpType(), "(", node.SinceVersion(), ") node with name '", node.Name(), "'");
Expand All @@ -1280,7 +1311,7 @@ static Status VerifyEachNodeIsAssignedToAnEpImpl(const Graph& graph, bool is_ver
const auto subgraphs = node.GetSubgraphs();
for (const auto& subgraph : subgraphs) {
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(*subgraph, is_verbose, node_placements,
node_placement_provider_set));
node_placement_provider_set, providers));
}
}
}
Expand All @@ -1299,7 +1330,8 @@ static Status VerifyEachNodeIsAssignedToAnEp(const Graph& graph, const logging::
const bool is_verbose_mode = false;
#endif // !defined(ORT_MINIMAL_BUILD)

ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements, node_placement_provider_set));
ORT_RETURN_IF_ERROR(VerifyEachNodeIsAssignedToAnEpImpl(graph, is_verbose_mode, node_placements,
node_placement_provider_set, providers));

#if !defined(ORT_MINIMAL_BUILD)
// print placement info
Expand Down
53 changes: 53 additions & 0 deletions onnxruntime/test/autoep/test_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include <filesystem>
#include <string_view>
#include <vector>
// #include <absl/base/config.h>
#include <gsl/gsl>
Expand Down Expand Up @@ -400,6 +401,58 @@ TEST(OrtEpLibrary, PluginEp_GenEpContextModel) {
}
}

// Test loading a compiled model without registering the required EP with the session.
// We expect to get an explicit error that says that an EPContext node generated by "example_ep"
// was not assigned to the appropriate EP.
TEST(OrtEpLibrary, PluginEp_ErrorWhenLoadEPContextModel_WithoutRequiredEp) {
RegisteredEpDeviceUniquePtr example_ep;
ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep));
Ort::ConstEpDevice plugin_ep_device(example_ep.get());

// Create a compiled model for the example EP.
const ORTCHAR_T* compiled_model_file = ORT_TSTR("plugin_ep_compiled_test_errorwhenloadwithoutep.onnx");
{
const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx");
std::filesystem::remove(compiled_model_file);

Comment thread
adrianlizarraga marked this conversation as resolved.
Ort::SessionOptions session_options;
std::unordered_map<std::string, std::string> ep_options;

session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);

Ort::ModelCompilationOptions compile_options(*ort_env, session_options);
compile_options.SetFlags(OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED);
compile_options.SetInputModelPath(input_model_file);
compile_options.SetOutputModelPath(compiled_model_file);

ASSERT_CXX_ORTSTATUS_OK(Ort::CompileModel(*ort_env, compile_options));
ASSERT_TRUE(std::filesystem::exists(compiled_model_file));
}

// Create a session without the plugin EP and expect an error.
{
Ort::SessionOptions session_options;

try {
Ort::Session session(*ort_env, compiled_model_file, session_options);
FAIL() << "Expected error when loading compiled model without the necessary EP";
} catch (const Ort::Exception& e) {
std::string error_msg = e.what();
std::string_view expected_msg_prefix =
"EPContext node generated by 'example_ep' is not compatible with any execution provider "
"added to the session.";
std::string_view expected_session_eps = "[CPUExecutionProvider]";
Comment thread
adrianlizarraga marked this conversation as resolved.

EXPECT_TRUE(error_msg.find(expected_msg_prefix) != std::string::npos &&
error_msg.find(expected_session_eps) != std::string::npos)
<< "Error should mention EPContext node's required EP and the available EPs:\n"
<< error_msg;
}
}

std::filesystem::remove(compiled_model_file);
}

// Generate an EPContext model with a plugin EP that uses a virtual GPU.
TEST(OrtEpLibrary, PluginEp_VirtGpu_GenEpContextModel) {
RegisteredEpDeviceUniquePtr example_ep;
Expand Down
Loading