Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include <vector>
#include "core/framework/compute_capability.h"
#include "core/framework/error_code_helper.h"
#include "core/framework/model_metadef_id_generator.h"
#include "core/framework/plugin_data_transfer.h"
#include "core/framework/plugin_ep_stream.h"
#include "core/graph/ep_api_types.h"
Expand Down Expand Up @@ -227,8 +226,6 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
return {};
}

ModelMetadefIdGenerator generator;

// Create ComputeCapability instances from OrtEpGraphSupportInfo::NodeGrouping instances.
for (const OrtEpGraphSupportInfo::NodeGrouping& node_grouping : api_graph_support_info.node_groupings) {
// Skip this node grouping if any node has already been assigned to another EP.
Expand Down Expand Up @@ -278,8 +275,9 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
// TODO(adrianlizarraga): Do not use the heavy-weight CreateSupportedPartitions just to check if the user
// provided a single partition. Use utils::MakeCapability() and create a new helper to check that there are no
// unsupported nodes in any path between supported nodes.
auto metadef_gen_functor = PluginEpMetaDefNameFunctor(metadef_id_generator_, graph_viewer, this->Type());
std::vector<std::unique_ptr<ComputeCapability>> capabilities = utils::CreateSupportedPartitions(
graph_viewer, node_set, /*stop_ops*/ {}, PluginEpMetaDefNameFunctor(generator, graph_viewer, this->Type()),
graph_viewer, node_set, /*stop_ops*/ {}, std::move(metadef_gen_functor),
this->Type(), this->Type(), /*node_unit_map*/ nullptr,
node_grouping.fusion_options.drop_constant_initializers);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/common/common.h"
#include "core/common/inlined_containers.h"
#include "core/framework/execution_provider.h"
#include "core/framework/model_metadef_id_generator.h"
#include "core/providers/providers.h"
#include "core/session/onnxruntime_c_api.h"

Expand Down Expand Up @@ -160,6 +161,13 @@ class PluginExecutionProvider : public IExecutionProvider {
// so that it is not destroyed until the EP itself is destroyed.
std::vector<FusedNodeState> fused_node_states_;

// Generates a model's hash and a monotonically increasing ID that is unique per model hash. The
// ID is used in the MetaDef name for a fused node containing a compiling EP's supported subgraph.
//
// The same generator instance must be used across calls to GetCapability() to ensure that fused nodes that live in
// different GraphViews (e.g., different branches of an If node) obtain a unique ID.
ModelMetadefIdGenerator metadef_id_generator_;

// Stores the EPContext Nodes created from the OrtNode instances returned by the underlying plugin EP.
// Need to store both the Node and NodeArg instances so that they are available when the GraphPartitioner
// calls IExecutionProvider::GetEpContextNodes().
Expand Down
69 changes: 31 additions & 38 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,44 +196,33 @@ const char* ORT_API_CALL ExampleEp ::GetNameImpl(const OrtEp* this_ptr) noexcept
return ep->name_.c_str();
}

OrtStatus* ExampleEp::SaveConstantInitializers(const OrtGraph* ort_graph) {
Ort::ConstGraph graph{ort_graph};

try {
std::vector<Ort::ConstValueInfo> initializers = graph.GetInitializers();
bool ExampleEp::CopiesConstantInitializers() const {
return !(config_.enable_ep_context && config_.enable_weightless_ep_context_nodes);
}

for (const auto& initializer : initializers) {
const bool is_constant = initializer.IsConstantInitializer();
OrtStatus* ExampleEp::TrySaveConstantInitializer(Ort::ConstValueInfo maybe_initializer) {
EXCEPTION_TO_RETURNED_STATUS_BEGIN
const bool is_constant = maybe_initializer.IsConstantInitializer();

if (is_constant) {
auto name = initializer.GetName();
Ort::ConstValue value;
auto status = initializer.GetInitializer(value);
if (!status.IsOK())
return status.release();
if (is_constant) {
auto name = maybe_initializer.GetName();
Ort::ConstValue value;
RETURN_IF_ERROR(maybe_initializer.GetInitializer(value));

auto type_shape = value.GetTensorTypeAndShapeInfo();
const size_t num_elems = type_shape.GetElementCount();
const ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release();
auto type_shape = value.GetTensorTypeAndShapeInfo();
const size_t num_elems = type_shape.GetElementCount();
const ONNXTensorElementDataType elem_type = type_shape.GetElementType();
if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
return Ort::Status("Expected float32 initializers", ORT_INVALID_ARGUMENT).release();

std::vector<int64_t> dims = type_shape.GetShape();
const float* data = value.GetTensorData<float>();
std::vector<int64_t> dims = type_shape.GetShape();
const float* data = value.GetTensorData<float>();

FloatInitializer ep_initializer = {std::move(dims), std::vector<float>(data, data + num_elems)};
float_initializers_.emplace(std::move(name), std::move(ep_initializer));
}
}
} catch (const Ort::Exception& ex) {
Ort::Status status(ex);
return status.release();
} catch (const std::exception& ex) {
Ort::Status status(ex.what(), ORT_EP_FAIL);
return status.release();
FloatInitializer ep_initializer = {std::move(dims), std::vector<float>(data, data + num_elems)};
float_initializers_.emplace(std::move(name), std::move(ep_initializer));
}

return nullptr;
EXCEPTION_TO_RETURNED_STATUS_END
}

/*static*/
Expand Down Expand Up @@ -342,8 +331,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG
// Refer to the "ep.enable_weightless_ep_context_nodes"
// session configuration entry in onnxruntime_session_options_config_keys.h for more information about generating
// weightless EPContext models.
node_fusion_options.drop_constant_initializers = !(ep->config_.enable_ep_context &&
ep->config_.enable_weightless_ep_context_nodes);
node_fusion_options.drop_constant_initializers = ep->CopiesConstantInitializers();
RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(
graph_support_info,
reinterpret_cast<const OrtNode* const*>(supported_nodes.data()),
Expand Down Expand Up @@ -377,11 +365,6 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const

Ort::ConstGraph graph{ort_graphs[0]};

// In GetCapability(), this EP specified that it doesn't need ORT to provide constant initializers during inference.
// So, this EP saves constant initializers so that they're available during inference, but an actual EP
// implementation could transfer the weights to device memory.
ep->SaveConstantInitializers(graph);

std::vector<Ort::ConstNode> nodes = graph.GetNodes();
if (nodes.size() != 1) {
Ort::Status status("Expected to compile a single node", ORT_EP_FAIL);
Expand Down Expand Up @@ -437,6 +420,16 @@ OrtStatus* ORT_API_CALL ExampleEp::CompileImpl(_In_ OrtEp* this_ptr, _In_ const
return status.release();
}

// In GetCapability(), this EP may have specified that it doesn't need ORT to provide constant initializers
// during inference. If so, this EP saves copies of constant initializers so they're available during inference.
//
// We try to save each node input individually because graph.GetInitializers() does not return
// initializers defined in parent or sibling subgraphs.
if (ep->CopiesConstantInitializers()) {
RETURN_IF_ERROR(ep->TrySaveConstantInitializer(node_inputs[0]));
RETURN_IF_ERROR(ep->TrySaveConstantInitializer(node_inputs[1]));
}

// Create MulKernel for Mul nodes
ep->mul_kernels_.emplace(fused_node_name,
std::make_unique<MulKernel>(ep->ort_api, ep->logger_,
Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/test/autoep/library/example_plugin_ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,12 @@ class ExampleEp : public OrtEp, public ApiPtrs {
OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
/*out*/ gsl::span<OrtNode*> ep_context_nodes);

OrtStatus* SaveConstantInitializers(const OrtGraph* graph);
// Returns true if the EP should save constant initializers so that they are available during inference.
bool CopiesConstantInitializers() const;

// If the given `OrtValueInfo` represents a constant initializer, this function saves a copy of the initializer data
// within this EP instance so that it is available during inference.
OrtStatus* TrySaveConstantInitializer(Ort::ConstValueInfo maybe_initializer);

ExampleEpFactory& factory_;
std::string name_;
Expand Down
31 changes: 31 additions & 0 deletions onnxruntime/test/autoep/test_execution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1060,5 +1060,36 @@ TEST(OrtEpLibrary, PluginEp_GpuDevice_ReturnsInCompatible) {

api->ReleaseDeviceEpIncompatibilityDetails(details);
}

TEST(OrtEpLibrary, CompilingPluginEp_MultiSubgraphs_DuplicateMetaDefIdBug) {
// Test a fix to a bug that incorrectly assigned duplicate MetaDef IDs to fused nodes
// that live in different GraphViews (e.g., in different branches of an If node).
//
// The test model graph does the following computation:
// if (A) { C = Mul(B, 2.0) }
// else { C = Mul(B, 3) }
// return C
//
// The example plugin EP should support and execute both Mul nodes (as compiled fused nodes).
// However, the bug (in PluginExecutionProvider::GetCapability) assigned the same MetaDef ID
// to both compiled Mul nodes, which caused session creation to fail with error:
//
// > Failed to add kernel for example_ep_9433721956998717990_0 example_ep example_ep:
// Conflicting with a registered kernel with op versions. the since version is: 1
//
// The fix was to use the same instance of `ModelMetadefIdGenerator` across all calls to
// PluginExecutionProvider::GetCapability(). This ensures that the MetaDef IDs are unique.
RegisteredEpDeviceUniquePtr example_kernel_ep;
ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info,
example_kernel_ep));
Ort::ConstEpDevice plugin_ep_device(example_kernel_ep.get());

std::unordered_map<std::string, std::string> ep_options;
Ort::SessionOptions session_options;

session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options);
ASSERT_NO_FATAL_FAILURE(RunIfMulModel(session_options, /*if_condition*/ true));
}

} // namespace test
} // namespace onnxruntime
Loading