Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 215 additions & 42 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,35 @@ OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) {
return nullptr;
}

OrtStatus* EpContextKernel::Compute(OrtKernelContext* /*kernel_ctx*/) {
// This example EP does not fully support EPContext inference.
// A production EP would:
// 1. Deserialize state from ep_cache_context attribute during Compile
// 2. Use that state here to perform actual computation
//
// Session creation succeeds for metadata access and compatibility testing,
// but inference requires deserializing ep_cache_context (not implemented).
return ort_api.CreateStatus(
ORT_NOT_IMPLEMENTED,
"EPContext inference is not fully implemented in this example EP. "
"Session creation succeeds for metadata access and compatibility testing, "
"but inference requires deserializing ep_cache_context (not implemented). "
"A production EP would restore compiled state from the EPContext node's attributes.");
}

/// <summary>
/// Intermediate base class with virtual destructor for proper polymorphic deletion.
/// This allows ReleaseNodeComputeInfosImpl to delete any derived type correctly
/// without manual type dispatch.
/// </summary>
struct NodeComputeInfoBase : OrtNodeComputeInfo {
virtual ~NodeComputeInfoBase() = default;
};

/// <summary>
/// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph.
/// </summary>
struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
struct ExampleNodeComputeInfo : NodeComputeInfoBase {
explicit ExampleNodeComputeInfo(ExampleEp& ep);

static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr,
Expand All @@ -123,6 +148,22 @@ struct ExampleNodeComputeInfo : OrtNodeComputeInfo {
ExampleEp& ep;
};

/// <summary>
/// OrtNodeComputeInfo for EPContext nodes - delegates to EpContextKernel.
/// </summary>
struct EpContextNodeComputeInfo : NodeComputeInfoBase {
explicit EpContextNodeComputeInfo(ExampleEp& ep);

static OrtStatus* ORT_API_CALL CreateStateImpl(OrtNodeComputeInfo* this_ptr,
OrtNodeComputeContext* compute_context,
void** compute_state);
static OrtStatus* ORT_API_CALL ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state,
OrtKernelContext* kernel_context);
static void ORT_API_CALL ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state);

ExampleEp& ep;
};

ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const Config& config, const OrtLogger& logger)
: OrtEp{}, // explicitly call the struct ctor to ensure all optional values are default initialized
ApiPtrs{static_cast<const ApiPtrs&>(factory)},
Expand All @@ -137,8 +178,9 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
GetCapability = GetCapabilityImpl;
Compile = CompileImpl;
ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl;
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models

IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Expand Down Expand Up @@ -206,12 +248,32 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
return nullptr; // No nodes to process
}

// Single array for all supported node types.
// This EP only supports compiling one node at a time (a documented limitation).
std::vector<Ort::ConstNode> supported_nodes;

for (const auto& node : nodes) {
auto op_type = node.GetOperatorType();
auto domain = node.GetDomain();

// Check for EPContext nodes that belong to this EP (from compiled models).
// This is needed to handle loading pre-compiled models with EPContext nodes.
if (op_type == "EPContext" && domain == "com.microsoft") {
// Check if this EPContext node belongs to this EP via the "source" attribute
Ort::ConstOpAttr source_attr;
Ort::Status status = node.GetAttributeByName("source", source_attr);
if (status.IsOK()) {
std::string source_value;
status = source_attr.GetValue(source_value);
if (status.IsOK() && source_value == ep->name_) {
// This EPContext node was created by this EP - add to supported nodes
supported_nodes.push_back(node);
break; // Only support one node at a time
}
}
continue; // Don't process further, EPContext is a special case
}

if (op_type == "Mul") {
// Check that Mul has inputs/output of type float
std::vector<Ort::ConstValueInfo> inputs = node.GetInputs();
Expand Down Expand Up @@ -241,19 +303,29 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
}
}

supported_nodes.push_back(node); // Only support a single Mul for now.
break;
supported_nodes.push_back(node);
break; // Only support a single Mul for now.
} else if (op_type == "Custom_Mul" && domain == "test") {
supported_nodes.push_back(node);
break; // Only support one node at a time (consistent with Mul/EPContext handling).
}
}

// Return early if no supported nodes
if (supported_nodes.empty()) {
return nullptr;
}

if (supported_nodes[0].GetOperatorType() == "Mul") {
// Create (optional) fusion options for the supported nodes to fuse.
// Unified dispatch based on node type
const auto& node = supported_nodes[0];
auto op_type = node.GetOperatorType();

if (op_type == "Custom_Mul") {
// Custom_Mul has concrete kernel implementation - no fusion needed.
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, node));
} else {
// Both EPContext and Mul use AddNodesToFuse
OrtNodeFusionOptions node_fusion_options = {};
node_fusion_options.ort_version_supported = ORT_API_VERSION;

Expand All @@ -262,14 +334,11 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
// This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use
// during inference.
node_fusion_options.drop_constant_initializers = true;
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
} else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") {
// Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled,
// as CustomMul has the concrete kernel implementation.
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0]));
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(
graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
supported_nodes.size(),
&node_fusion_options));
}

} catch (const Ort::Exception& ex) {
Expand Down Expand Up @@ -305,21 +374,32 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const

std::vector<Ort::ConstNode> nodes = graph.GetNodes();
if (nodes.size() != 1) {
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
Ort::Status status("Expected to compile a single node", ORT_EP_FAIL);
return status.release();
}

auto node_op_type = nodes[0].GetOperatorType();
if (node_op_type != "Mul") {
Ort::Status status("Expected to compile a single Mul node", ORT_EP_FAIL);
auto node_domain = nodes[0].GetDomain();

// Check if this is an EPContext node (from loading a pre-compiled model)
bool is_ep_context_node = (node_op_type == "EPContext" && node_domain == "com.microsoft");

// Validate configuration: cannot enable EPContext generation when loading a compiled model.
// This is a configuration error - you cannot re-compile an already compiled model.
if (ep->config_.enable_ep_context && is_ep_context_node) {
Ort::Status status(
"Invalid configuration: 'enable_ep_context' is true but model already contains "
"EPContext nodes. Cannot re-compile an already compiled model. Either:\n"
" 1. Use the original (uncompiled) model as input, or\n"
" 2. Disable ep_context generation when loading a compiled model.",
ORT_INVALID_ARGUMENT);
return status.release();
}

// Now we know we're compiling a single Mul node. Create a computation kernel.
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();
std::array<std::string, 2> node_input_names;
node_input_names[0] = node_inputs[0].GetName();
node_input_names[1] = node_inputs[1].GetName();
if (node_op_type != "Mul" && !is_ep_context_node) {
Ort::Status status("Expected to compile a Mul node or EPContext node", ORT_EP_FAIL);
return status.release();
}

Ort::ConstNode fused_node{fused_nodes[0]};
auto ep_name = fused_node.GetEpName();
Expand All @@ -328,22 +408,42 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
return status.release();
}

// Associate the name of the fused node with our MulKernel.
auto fused_node_name = fused_node.GetName();
ep->kernels_.emplace(std::move(fused_node_name), std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
ep->float_initializers_,
node_input_names[0],
node_input_names[1]));

// Update the OrtNodeComputeInfo associated with the graph.
auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
node_compute_infos[0] = node_compute_info.release();

// Create EpContext nodes for the fused nodes we compiled.
if (ep->config_.enable_ep_context) {
assert(ep_context_nodes != nullptr);
RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span<const OrtNode*>(fused_nodes, count),
gsl::span<OrtNode*>(ep_context_nodes, count)));

if (is_ep_context_node) {
// Create EpContextKernel for EPContext nodes - clearly separates from MulKernel
ep->ep_context_kernels_.emplace(fused_node_name,
std::make_unique<EpContextKernel>(ep->ort_api, ep->logger_));

// Use EpContextNodeComputeInfo for EPContext nodes
auto node_compute_info = std::make_unique<EpContextNodeComputeInfo>(*ep);
node_compute_infos[0] = node_compute_info.release();
} else {
// For Mul nodes during initial compilation, we need exactly 2 inputs
std::vector<Ort::ConstValueInfo> node_inputs = nodes[0].GetInputs();
if (node_inputs.size() != 2) {
std::string err_msg = "Mul node should have 2 inputs, got " + std::to_string(node_inputs.size());
Ort::Status status(err_msg.c_str(), ORT_EP_FAIL);
return status.release();
}

// Create MulKernel for Mul nodes
ep->mul_kernels_.emplace(fused_node_name,
std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
ep->float_initializers_,
node_inputs[0].GetName(),
node_inputs[1].GetName()));

// Use ExampleNodeComputeInfo for Mul nodes
auto node_compute_info = std::make_unique<ExampleNodeComputeInfo>(*ep);
node_compute_infos[0] = node_compute_info.release();

// Create EpContext nodes for the fused nodes we compiled (only for Mul, not EPContext).
if (ep->config_.enable_ep_context) {
assert(ep_context_nodes != nullptr);
RETURN_IF_ERROR(ep->CreateEpContextNodes(gsl::span<const OrtNode*>(fused_nodes, count),
gsl::span<OrtNode*>(ep_context_nodes, count)));
}
}
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
Expand All @@ -362,7 +462,9 @@ void ORT_API_CALL ExampleEp::ReleaseNodeComputeInfosImpl(OrtEp* this_ptr,
size_t num_node_compute_infos) noexcept {
(void)this_ptr;
for (size_t i = 0; i < num_node_compute_infos; i++) {
delete static_cast<ExampleNodeComputeInfo*>(node_compute_infos[i]);
// All node compute info types derive from NodeComputeInfoBase which has a virtual destructor.
// This ensures correct polymorphic deletion without manual type dispatch.
delete static_cast<NodeComputeInfoBase*>(node_compute_infos[i]);
}
}

Expand Down Expand Up @@ -497,9 +599,9 @@ OrtStatus* ExampleNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr,
ExampleEp& ep = node_compute_info->ep;

std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
auto kernel_it = ep.Kernels().find(fused_node_name);
if (kernel_it == ep.Kernels().end()) {
std::string message = "Unable to get kernel for fused node with name " + fused_node_name;
auto kernel_it = ep.MulKernels().find(fused_node_name);
if (kernel_it == ep.MulKernels().end()) {
std::string message = "Unable to get MulKernel for fused node with name " + fused_node_name;
return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
}

Expand All @@ -521,3 +623,74 @@ void ExampleNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void
(void)kernel;
// Do nothing for this example.
}

//
// Implementation of EpContextNodeComputeInfo
//
EpContextNodeComputeInfo::EpContextNodeComputeInfo(ExampleEp& ep) : ep(ep) {
ort_version_supported = ORT_API_VERSION;
CreateState = CreateStateImpl;
Compute = ComputeImpl;
ReleaseState = ReleaseStateImpl;
}

OrtStatus* EpContextNodeComputeInfo::CreateStateImpl(OrtNodeComputeInfo* this_ptr,
OrtNodeComputeContext* compute_context,
void** compute_state) {
auto* node_compute_info = static_cast<EpContextNodeComputeInfo*>(this_ptr);
ExampleEp& ep = node_compute_info->ep;

std::string fused_node_name = ep.ep_api.NodeComputeContext_NodeName(compute_context);
auto kernel_it = ep.EpContextKernels().find(fused_node_name);
if (kernel_it == ep.EpContextKernels().end()) {
std::string message = "Unable to get EpContextKernel for fused node with name " + fused_node_name;
return ep.ort_api.CreateStatus(ORT_EP_FAIL, message.c_str());
}

EpContextKernel& kernel = *kernel_it->second;
*compute_state = &kernel;
return nullptr;
}

OrtStatus* EpContextNodeComputeInfo::ComputeImpl(OrtNodeComputeInfo* this_ptr, void* compute_state,
OrtKernelContext* kernel_context) {
(void)this_ptr;
EpContextKernel& kernel = *reinterpret_cast<EpContextKernel*>(compute_state);
return kernel.Compute(kernel_context);
}

void EpContextNodeComputeInfo::ReleaseStateImpl(OrtNodeComputeInfo* this_ptr, void* compute_state) {
(void)this_ptr;
(void)compute_state;
// Do nothing for this example.
}

//
// Implementation of GetCompiledModelCompatibilityInfo
//
/*static*/
const char* ORT_API_CALL ExampleEp::GetCompiledModelCompatibilityInfoImpl(OrtEp* this_ptr,
const OrtGraph* graph) noexcept {
// Suppress unused parameter warning. The ORT_UNUSED_PARAMETER macro is in internal headers
// (core/common/common.h) which are not available to plugin EPs using only public APIs.
// A real EP would inspect the graph for model-specific compatibility info.
(void)graph;
auto* ep = static_cast<ExampleEp*>(this_ptr);

// Generate a compatibility string that includes:
// - EP name
// - EP version (from factory)
// - ORT API version
//
// In a real EP, this might include driver versions, hardware IDs, etc.
// The string format is EP-defined and should be parseable by ValidateCompiledModelCompatibilityInfo.
ep->compatibility_info_ = ep->name_ + ";version=" + ep->factory_.GetEpVersionString() + ";ort_api_version=" +
std::to_string(ORT_API_VERSION);

IGNORE_ORTSTATUS(ep->ort_api.Logger_LogMessage(&ep->logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
("GetCompiledModelCompatibilityInfo returning: " + ep->compatibility_info_).c_str(),
ORT_FILE, __LINE__, __FUNCTION__));

return ep->compatibility_info_.c_str();
}
Loading
Loading