From 2578d1f213edd4313b2900f178f722c6eaa6b383 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 9 Apr 2026 16:40:45 -0700 Subject: [PATCH 01/16] Zero-copy I/O for plugin EPs with HOST_ACCESSIBLE memory Adds DevicesAreMemoryCompatible() to skip data copies between devices that share memory (CPU <-> HOST_ACCESSIBLE, or HOST_ACCESSIBLE <-> DEFAULT on the same physical device). Applied in feed/fetch copy planning and in BatchOrCopyMLValue. Overrides GetOrtDeviceByMemType() in PluginExecutionProvider so the allocation planner routes CPU-type I/O through the HOST_ACCESSIBLE allocator when the plugin EP has registered one. This enables the planner to place intermediate tensors (CPU EP -> plugin EP boundary) in HOST_ACCESSIBLE memory, avoiding copies at the partition boundary. Updates the in-place optimization check in the allocation planner to use UsesCpuMemory() so it recognises HOST_ACCESSIBLE outputs as CPU-memory-compatible. --- .../core/framework/allocation_planner.cc | 2 +- onnxruntime/core/framework/utils.cc | 84 ++++++++++++------- .../ep_plugin_provider_interfaces.cc | 12 +++ .../plugin_ep/ep_plugin_provider_interfaces.h | 2 + 4 files changed, 71 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 1c80d83f99feb..96077415656b4 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -918,7 +918,7 @@ class PlannerImpl { // We only do it for CPU based EPs. We are not likely to encounter // non CPU devices here since they are already taken care of by using MemCpy nodes earlier. // However, we still ignore them. - if (output_device.Type() == OrtDevice::CPU) { + if (output_device.UsesCpuMemory()) { const auto& output_name = node_output->Name(); const auto consumers = graph_viewer_.GetConsumerNodes(output_name); for (const auto* consumer : consumers) { diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index f1945ded10b07..60af176932465 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -50,6 +50,34 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) { return provider.GetDevice().Type() == OrtDevice::CPU; } +// Returns true if no data transfer is needed between the two devices. +// HOST_ACCESSIBLE memory is a superset — accessible by both host and device — so it can satisfy +// DEFAULT memory requirements on the same physical device without a copy. +static bool DevicesAreMemoryCompatible(const OrtDevice& a, const OrtDevice& b) { + const bool a_is_cpu_mem = a.UsesCpuMemory(); + const bool b_is_cpu_mem = b.UsesCpuMemory(); + + // Both CPU-accessible: compatible unless both are HOST_ACCESSIBLE on different physical devices. + if (a_is_cpu_mem && b_is_cpu_mem) { + if (a.Type() == OrtDevice::CPU || b.Type() == OrtDevice::CPU) { + return true; + } + return a.Type() == b.Type() && + a.Vendor() == b.Vendor() && + a.Id() == b.Id(); + } + + // HOST_ACCESSIBLE <-> DEFAULT: compatible only on the same physical device. + if ((a_is_cpu_mem != b_is_cpu_mem) && + a.Type() == b.Type() && + a.Vendor() == b.Vendor() && + a.Id() == b.Id()) { + return true; + } + + return false; +} + bool IsMemcpyNode(const Node& node) { return node.Domain() == kOnnxDomain && (node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost"); @@ -117,6 +145,24 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) return required_provider_type; } +// Populate device_fetches for the output-copy path. +// Reuses a pre-allocated user buffer when the memory is compatible (same device or HOST_ACCESSIBLE +// <-> DEFAULT on the same physical device); otherwise inserts an empty placeholder. +static void PopulateDeviceFetches(gsl::span fetch_copy_info, + const std::vector& fetches, + std::vector& device_fetches) { + device_fetches.reserve(fetches.size()); + for (size_t i = 0; i < fetches.size(); ++i) { + const auto& src = fetch_copy_info[i].source_device; + const auto& tgt = fetch_copy_info[i].target_device; + if ((src == tgt || DevicesAreMemoryCompatible(src, tgt)) && fetches[i].IsAllocated()) { + device_fetches.push_back(fetches[i]); + } else { + device_fetches.push_back({}); + } + } +} + // Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_tensor_pairs/copy_sparse_pairs is provided, // src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager // implementation for performance reasons. @@ -132,8 +178,10 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, std::vector* copy_tensor_pairs = nullptr) #endif { - // same device so direct copy - if (copy_info.source_device == copy_info.target_device) { + // No data transfer needed if devices are the same or memory-compatible + // (e.g. HOST_ACCESSIBLE <-> DEFAULT on the same physical device). + if (copy_info.source_device == copy_info.target_device || + DevicesAreMemoryCompatible(copy_info.source_device, copy_info.target_device)) { target_mlvalue = source_mlvalue; return Status::OK(); } @@ -324,7 +372,8 @@ static bool FinalizeCopyInfoForFeeds(gsl::span feed_locations, for (size_t i = 0, end = feed_locations.size(); i < end; ++i) { copy_info[i].source_device = feed_locations[i]; - if (copy_info[i].source_device != copy_info[i].target_device) { + if (copy_info[i].source_device != copy_info[i].target_device && + !DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -345,7 +394,8 @@ static bool FinalizeCopyInfoForFetches(gsl::span& fetch_ copy_info[i].target_device = *alloc_info; } - if (copy_info[i].source_device != copy_info[i].target_device) { + if (copy_info[i].source_device != copy_info[i].target_device && + !DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -656,18 +706,7 @@ ExecuteGraphImpl(const SessionState& session_state, const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { - // need intermediate fetches. use pre-allocated fetches where possible. - device_fetches.reserve(num_outputs); - - for (size_t i = 0; i < num_outputs; ++i) { - if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) { - device_fetches.push_back(fetches[i]); - } else { - // use temporary value - device_fetches.push_back({}); - } - } - + PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches); p_fetches = &device_fetches; } @@ -812,18 +851,7 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { - // need intermediate fetches. use pre-allocated fetches where possible. - device_fetches.reserve(num_outputs); - - for (size_t i = 0; i < num_outputs; ++i) { - if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) { - device_fetches.push_back(fetches[i]); - } else { - // use temporary value - device_fetches.push_back({}); - } - } - + PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches); p_fetches = &device_fetches; } diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index d32967f8b37e3..18f7e3131f22b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -215,6 +215,18 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio } } +OrtDevice PluginExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { + if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { + // Use the host-accessible allocator device if one was registered by the plugin. + // This avoids unnecessary copies between CPU and HOST_ACCESSIBLE memory. + if (!ep_devices_.empty() && ep_devices_[0]->host_accessible_memory_info != nullptr) { + return ep_devices_[0]->host_accessible_memory_info->device; + } + return OrtDevice(); + } + return GetDevice(); +} + PluginExecutionProvider::~PluginExecutionProvider() { if (ort_ep_ && !api_node_compute_infos_.empty() && ort_ep_->ReleaseNodeComputeInfos != nullptr) { ort_ep_->ReleaseNodeComputeInfos(ort_ep_.get(), api_node_compute_infos_.data(), diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 8218571a8b1fe..d0e213ba1ca3a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -149,6 +149,8 @@ class PluginExecutionProvider : public IExecutionProvider { common::Status ReplayGraph(int graph_annotation_id) override; OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override; + OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; + private: const logging::Logger& GetEpLoggerOrDefault() const; From 204dcd75191a28095ef11cf7a925290b7a54b8f1 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 16 Apr 2026 17:43:29 -0700 Subject: [PATCH 02/16] First pass addressing review comments --- .../core/framework/allocation_planner.cc | 5 - onnxruntime/core/framework/utils.cc | 84 ++++++++------ onnxruntime/core/framework/utils.h | 5 + .../ep_plugin_provider_interfaces.cc | 2 +- onnxruntime/test/framework/utils_test.cc | 107 ++++++++++++++++++ 5 files changed, 163 insertions(+), 40 deletions(-) create mode 100644 onnxruntime/test/framework/utils_test.cc diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 96077415656b4..006e6e0da8b56 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -913,11 +913,6 @@ class PlannerImpl { ProcessDef(index, node_output); OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Downstream nodes of certain providers may require a CPU accessible location override - // to make sure the EP does not incur an unnecessary copy. - // We only do it for CPU based EPs. We are not likely to encounter - // non CPU devices here since they are already taken care of by using MemCpy nodes earlier. - // However, we still ignore them. if (output_device.UsesCpuMemory()) { const auto& output_name = node_output->Name(); const auto consumers = graph_viewer_.GetConsumerNodes(output_name); diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 60af176932465..5e2c033fc727c 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -50,29 +50,48 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) { return provider.GetDevice().Type() == OrtDevice::CPU; } -// Returns true if no data transfer is needed between the two devices. -// HOST_ACCESSIBLE memory is a superset — accessible by both host and device — so it can satisfy -// DEFAULT memory requirements on the same physical device without a copy. -static bool DevicesAreMemoryCompatible(const OrtDevice& a, const OrtDevice& b) { - const bool a_is_cpu_mem = a.UsesCpuMemory(); - const bool b_is_cpu_mem = b.UsesCpuMemory(); - - // Both CPU-accessible: compatible unless both are HOST_ACCESSIBLE on different physical devices. - if (a_is_cpu_mem && b_is_cpu_mem) { - if (a.Type() == OrtDevice::CPU || b.Type() == OrtDevice::CPU) { - return true; +// Returns true if src memory can satisfy tgt's requirements without a data copy. +// +// HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly. +// DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT +// memory is device-only — the CPU cannot read it. +// +// For the mixed case, src alignment must meet tgt's minimum requirement. +// Alignment 0 means "unspecified" and is treated as compatible with any requirement. +bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) { + const bool src_is_cpu_mem = src.UsesCpuMemory(); + const bool tgt_is_cpu_mem = tgt.UsesCpuMemory(); + + // Identical devices are always compatible. + if (src == tgt) { + return true; + } + + // Alignment 0 means "unspecified" — treat as compatible with any alignment requirement. + const bool is_alignment_satisfied = src.GetAlignment() == 0 || tgt.GetAlignment() == 0 || + src.GetAlignment() >= tgt.GetAlignment(); + + // Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory). + if (src_is_cpu_mem && tgt_is_cpu_mem) { + // CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device + if (tgt.Type() == OrtDevice::CPU) { + return is_alignment_satisfied; } - return a.Type() == b.Type() && - a.Vendor() == b.Vendor() && - a.Id() == b.Id(); + // Both are HOST_ACCESSIBLE on some device: require the same physical device. + return src.Type() == tgt.Type() && + src.Vendor() == tgt.Vendor() && + src.Id() == tgt.Id() && is_alignment_satisfied; } - // HOST_ACCESSIBLE <-> DEFAULT: compatible only on the same physical device. - if ((a_is_cpu_mem != b_is_cpu_mem) && - a.Type() == b.Type() && - a.Vendor() == b.Vendor() && - a.Id() == b.Id()) { - return true; + // HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device — + // the device can DMA from HOST_ACCESSIBLE memory directly. + // The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers, + // and DEFAULT memory is device-only so the CPU cannot read it. + if (src_is_cpu_mem && !tgt_is_cpu_mem && + src.Type() == tgt.Type() && + src.Vendor() == tgt.Vendor() && + src.Id() == tgt.Id()) { + return is_alignment_satisfied; } return false; @@ -146,16 +165,19 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) } // Populate device_fetches for the output-copy path. -// Reuses a pre-allocated user buffer when the memory is compatible (same device or HOST_ACCESSIBLE -// <-> DEFAULT on the same physical device); otherwise inserts an empty placeholder. +// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if +// the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e., +// CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy. +// Otherwise inserts an empty placeholder for the EP to allocate into. static void PopulateDeviceFetches(gsl::span fetch_copy_info, const std::vector& fetches, std::vector& device_fetches) { + ORT_ENFORCE(fetch_copy_info.size() >= fetches.size()); device_fetches.reserve(fetches.size()); for (size_t i = 0; i < fetches.size(); ++i) { const auto& src = fetch_copy_info[i].source_device; const auto& tgt = fetch_copy_info[i].target_device; - if ((src == tgt || DevicesAreMemoryCompatible(src, tgt)) && fetches[i].IsAllocated()) { + if (CanSourceSatisfyTarget(tgt, src) && fetches[i].IsAllocated()) { device_fetches.push_back(fetches[i]); } else { device_fetches.push_back({}); @@ -178,10 +200,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state, std::vector* copy_tensor_pairs = nullptr) #endif { - // No data transfer needed if devices are the same or memory-compatible - // (e.g. HOST_ACCESSIBLE <-> DEFAULT on the same physical device). - if (copy_info.source_device == copy_info.target_device || - DevicesAreMemoryCompatible(copy_info.source_device, copy_info.target_device)) { + // No data transfer needed if devices are identical, or the source can satisfy the target + // (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device). + if (CanSourceSatisfyTarget(copy_info.source_device, copy_info.target_device)) { target_mlvalue = source_mlvalue; return Status::OK(); } @@ -372,8 +393,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span feed_locations, for (size_t i = 0, end = feed_locations.size(); i < end; ++i) { copy_info[i].source_device = feed_locations[i]; - if (copy_info[i].source_device != copy_info[i].target_device && - !DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) { + if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -394,8 +414,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span& fetch_ copy_info[i].target_device = *alloc_info; } - if (copy_info[i].source_device != copy_info[i].target_device && - !DevicesAreMemoryCompatible(copy_info[i].source_device, copy_info[i].target_device)) { + if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) { copy_needed = true; } } @@ -702,9 +721,7 @@ ExecuteGraphImpl(const SessionState& session_state, feeds_to_use = device_feeds; } - auto num_outputs = fetches.size(); const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); - if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches); p_fetches = &device_fetches; @@ -847,7 +864,6 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF p_feeds = device_feeds; } - auto num_outputs = fetches.size(); const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo(); if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) { diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index aa6c9746e6d5b..ca3cca860ed8c 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -57,6 +57,11 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider); bool IsMemcpyNode(const Node& node); +// Returns true if src memory can satisfy tgt's requirements without a data copy. +// HOST_ACCESSIBLE -> DEFAULT is valid (device can access HOST_ACCESSIBLE memory directly). +// DEFAULT -> HOST_ACCESSIBLE is NOT valid (CPU cannot read device-only memory). +bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt); + common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name, const OrtValue& orig_mlvalue, OrtValue& new_mlvalue); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 18f7e3131f22b..8c6a303251322 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -217,7 +217,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio OrtDevice PluginExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - // Use the host-accessible allocator device if one was registered by the plugin. + // Use the first host-accessble allocator device if one was registered by the plugin. // This avoids unnecessary copies between CPU and HOST_ACCESSIBLE memory. if (!ep_devices_.empty() && ep_devices_[0]->host_accessible_memory_info != nullptr) { return ep_devices_[0]->host_accessible_memory_info->device; diff --git a/onnxruntime/test/framework/utils_test.cc b/onnxruntime/test/framework/utils_test.cc new file mode 100644 index 0000000000000..56625de544561 --- /dev/null +++ b/onnxruntime/test/framework/utils_test.cc @@ -0,0 +1,107 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "core/framework/utils.h" + +namespace onnxruntime { +namespace test { + +constexpr OrtDevice::VendorId kTestVendor1 = 0x1234; +constexpr OrtDevice::VendorId kTestVendor2 = 0x5678; + +static OrtDevice Cpu() { + return OrtDevice{OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}; +} + +static OrtDevice HostAccessible(OrtDevice::VendorId vendor, OrtDevice::DeviceId id, + OrtDevice::Alignment align = 0) { + return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE, vendor, id, align}; +} + +static OrtDevice Default(OrtDevice::VendorId vendor, OrtDevice::DeviceId id, + OrtDevice::Alignment align = 0) { + return OrtDevice{OrtDevice::NPU, OrtDevice::MemType::DEFAULT, vendor, id, align}; +} + +TEST(CanSourceSatisfyTargetTest, CpuSourceHostAccessibleTarget) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget(Cpu(), HostAccessible(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleSourceCpuTarget) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(HostAccessible(kTestVendor1, 0), Cpu())); +} + +// src == tgt early return: identical devices are always compatible +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDevice) { + auto dev = HostAccessible(kTestVendor1, 0, 16); + EXPECT_TRUE(utils::CanSourceSatisfyTarget(dev, dev)); +} + +// Branch 3: both HOST_ACCESSIBLE, different physical device +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentId) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor1, 1))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentVendor) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor2, 0))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentAlignment) { + // Different alignment => OrtDevice::operator== returns false + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), HostAccessible(kTestVendor1, 0, 32))); +} + +// Branch 4: HOST_ACCESSIBLE (src) -> DEFAULT (tgt), same physical device +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSameDevice) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentSatisfied) { + // src alignment >= tgt alignment: compatible + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 64), Default(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentInsufficient) { + // src alignment < tgt alignment: incompatible + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 64))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSrcAlignmentZero) { + // 0 = unspecified, treated as wildcard + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 0), Default(kTestVendor1, 0, 64))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultTgtAlignmentZero) { + // 0 = unspecified, treated as wildcard + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), Default(kTestVendor1, 0, 0))); +} + +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentDeviceId) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 1))); +} + +// Branch 5: incompatible cases + +TEST(CanSourceSatisfyTargetTest, DefaultToHostAccessibleRejected) { + // Reversed direction: CPU cannot read DEFAULT (device-only) memory + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + Default(kTestVendor1, 0), HostAccessible(kTestVendor1, 0))); +} + +TEST(CanSourceSatisfyTargetTest, DefaultToDefaultRejected) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + Default(kTestVendor1, 0), Default(kTestVendor2, 0))); +} + +} // namespace test +} // namespace onnxruntime From 40b78f53f15e8cb563eacf9f0e76ad5b6c2e04ea Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 23 Apr 2026 13:51:21 -0700 Subject: [PATCH 03/16] Simplify CanSourceSatisfyTarget and add test cases --- onnxruntime/core/framework/utils.cc | 19 +++++----- onnxruntime/test/framework/utils_test.cc | 46 +++++++++++++++++------- 2 files changed, 43 insertions(+), 22 deletions(-) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 5e2c033fc727c..91c6cb1e3d34b 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -67,10 +67,14 @@ bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) { return true; } - // Alignment 0 means "unspecified" — treat as compatible with any alignment requirement. - const bool is_alignment_satisfied = src.GetAlignment() == 0 || tgt.GetAlignment() == 0 || + // Alignment 0 means "unspecified" — treat tgt as compatible with any alignment requirement. + const bool is_alignment_satisfied = tgt.GetAlignment() == 0 || src.GetAlignment() >= tgt.GetAlignment(); + const bool is_same_physical_device = src.Type() == tgt.Type() && + src.Vendor() == tgt.Vendor() && + src.Id() == tgt.Id(); + // Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory). if (src_is_cpu_mem && tgt_is_cpu_mem) { // CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device @@ -78,20 +82,15 @@ bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) { return is_alignment_satisfied; } // Both are HOST_ACCESSIBLE on some device: require the same physical device. - return src.Type() == tgt.Type() && - src.Vendor() == tgt.Vendor() && - src.Id() == tgt.Id() && is_alignment_satisfied; + return is_same_physical_device && is_alignment_satisfied; } // HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device — // the device can DMA from HOST_ACCESSIBLE memory directly. // The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers, // and DEFAULT memory is device-only so the CPU cannot read it. - if (src_is_cpu_mem && !tgt_is_cpu_mem && - src.Type() == tgt.Type() && - src.Vendor() == tgt.Vendor() && - src.Id() == tgt.Id()) { - return is_alignment_satisfied; + if (src_is_cpu_mem && !tgt_is_cpu_mem) { + return is_same_physical_device && is_alignment_satisfied; } return false; diff --git a/onnxruntime/test/framework/utils_test.cc b/onnxruntime/test/framework/utils_test.cc index 56625de544561..2b2f2ec9376ba 100644 --- a/onnxruntime/test/framework/utils_test.cc +++ b/onnxruntime/test/framework/utils_test.cc @@ -38,7 +38,6 @@ TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDevice) { EXPECT_TRUE(utils::CanSourceSatisfyTarget(dev, dev)); } -// Branch 3: both HOST_ACCESSIBLE, different physical device TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentId) { EXPECT_FALSE(utils::CanSourceSatisfyTarget( HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor1, 1))); @@ -49,13 +48,6 @@ TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentVendor) { HostAccessible(kTestVendor1, 0), HostAccessible(kTestVendor2, 0))); } -TEST(CanSourceSatisfyTargetTest, BothHostAccessibleDifferentAlignment) { - // Different alignment => OrtDevice::operator== returns false - EXPECT_FALSE(utils::CanSourceSatisfyTarget( - HostAccessible(kTestVendor1, 0, 16), HostAccessible(kTestVendor1, 0, 32))); -} - -// Branch 4: HOST_ACCESSIBLE (src) -> DEFAULT (tgt), same physical device TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSameDevice) { EXPECT_TRUE(utils::CanSourceSatisfyTarget( HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 0))); @@ -74,8 +66,7 @@ TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultAlignmentInsufficient) { } TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultSrcAlignmentZero) { - // 0 = unspecified, treated as wildcard - EXPECT_TRUE(utils::CanSourceSatisfyTarget( + EXPECT_FALSE(utils::CanSourceSatisfyTarget( HostAccessible(kTestVendor1, 0, 0), Default(kTestVendor1, 0, 64))); } @@ -90,8 +81,6 @@ TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentDeviceId) { HostAccessible(kTestVendor1, 0), Default(kTestVendor1, 1))); } -// Branch 5: incompatible cases - TEST(CanSourceSatisfyTargetTest, DefaultToHostAccessibleRejected) { // Reversed direction: CPU cannot read DEFAULT (device-only) memory EXPECT_FALSE(utils::CanSourceSatisfyTarget( @@ -103,5 +92,38 @@ TEST(CanSourceSatisfyTargetTest, DefaultToDefaultRejected) { Default(kTestVendor1, 0), Default(kTestVendor2, 0))); } +// Early return: identical CPU devices are always compatible. +TEST(CanSourceSatisfyTargetTest, CpuToCpuIdentical) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(Cpu(), Cpu())); +} + +// Early return: identical DEFAULT devices on the same physical device are compatible. +TEST(CanSourceSatisfyTargetTest, DefaultToDefaultSameDevice) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget(Default(kTestVendor1, 0), Default(kTestVendor1, 0))); +} + +// Both HOST_ACCESSIBLE, same physical device — alignment variations (not the early-return path +// because src and tgt differ in alignment, so src != tgt). +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceAlignmentSatisfied) { + EXPECT_TRUE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 64), HostAccessible(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceAlignmentInsufficient) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 16), HostAccessible(kTestVendor1, 0, 32))); +} + +TEST(CanSourceSatisfyTargetTest, BothHostAccessibleSameDeviceSrcAlignmentZero) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0, 0), HostAccessible(kTestVendor1, 0, 32))); +} + +// HOST_ACCESSIBLE → DEFAULT: same device id but different vendor fails is_same_physical_device. +TEST(CanSourceSatisfyTargetTest, HostAccessibleToDefaultDifferentVendor) { + EXPECT_FALSE(utils::CanSourceSatisfyTarget( + HostAccessible(kTestVendor1, 0), Default(kTestVendor2, 0))); +} + } // namespace test } // namespace onnxruntime From 5a833606b22e21cadcc83d691ce9c296c03bd4e7 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 30 Apr 2026 11:50:47 -0700 Subject: [PATCH 04/16] Add OrtEp::GetMemoryInfoByMemType for plugin EP memory placement Co-authored-by: Copilot --- .../core/session/onnxruntime_ep_c_api.h | 26 +++++++++++++ .../ep_plugin_provider_interfaces.cc | 37 +++++++++++++------ .../plugin_ep/ep_plugin_provider_interfaces.h | 3 +- .../test/framework/ep_plugin_provider_test.cc | 3 +- 4 files changed, 56 insertions(+), 13 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 8ff15e5c35ed5..8dffe70f245de 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2462,6 +2462,32 @@ struct OrtEp { */ ORT_API_T(OrtGraphCaptureNodeAssignmentPolicy, GetGraphCaptureNodeAssignmentPolicy, _In_ const OrtEp* this_ptr); + + /** \brief Returns the OrtMemoryInfo the EP wants used for the given OrtMemType. + * + * Lets an EP declare, per OrtMemType, the memory the runtime should associate with that + * role (default device memory, CPU-side inputs, CPU-side outputs). ORT may consult this + * any time it needs to resolve placement for the EP. + * + * Implementations should be deterministic: a given OrtMemType should always map to the + * same OrtMemoryInfo for the lifetime of the OrtEp. Caching the answer up front is the + * recommended pattern; returned pointers must remain valid while the OrtEp is alive. + * + * Return nullptr for any OrtMemType to defer to ORT's built-in resolution for that type. + * Plugins may opt in selectively. + * + * \param[in] this_ptr The OrtEp instance. + * \param[in] mem_type The memory type to query. + * \return The OrtMemoryInfo the EP wants associated with the given mem_type, or nullptr + * to defer to ORT. + * + * \note Implementation of this function is optional. If set to NULL, ORT applies its + * built-in resolution for every OrtMemType. + * + * \since Version 1.26. + */ + ORT_API_T(const OrtMemoryInfo*, GetMemoryInfoByMemType, _In_ const OrtEp* this_ptr, + _In_ OrtMemType mem_type); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 8c6a303251322..d37a76d9a1098 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -28,6 +28,18 @@ namespace onnxruntime { +static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices); + +// Single source of truth for the OrtEp::GetMemoryInfoByMemType callback (added in EP API +// version 26): version-gated, null-checked. Returns nullptr if the EP did not opt in or if +// the EP returned nullptr to defer to ORT's built-in fallback. +static const OrtMemoryInfo* TryGetEpMemoryInfo(const OrtEp& ep, OrtMemType mem_type) { + if (ep.ort_version_supported >= 26 && ep.GetMemoryInfoByMemType != nullptr) { + return ep.GetMemoryInfoByMemType(&ep, mem_type); + } + return nullptr; +} + // // PluginExecutionProviderFactory // @@ -84,10 +96,17 @@ Status PluginExecutionProviderFactory::CreatePluginExecutionProvider( std::shared_ptr kernel_registry; ORT_RETURN_IF_ERROR(GetPluginEpKernelRegistry(*ort_ep, kernel_registry)); + // The EP is the single source of truth for its default device when it opts in; otherwise + // GetOrtDeviceForPluginEp throws on heterogeneous ep_devices. Resolved here once and + // forwarded to PluginExecutionProvider's IExecutionProvider base. + const OrtMemoryInfo* default_info = TryGetEpMemoryInfo(*ort_ep, OrtMemTypeDefault); + OrtDevice default_device = default_info ? default_info->device : GetOrtDeviceForPluginEp(devices_); + plugin_ep = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), session_options, ep_factory_, devices_, kernel_registry, - *logger.ToInternal()); + *logger.ToInternal(), + default_device); return Status::OK(); } @@ -164,8 +183,9 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio OrtEpFactory& ep_factory, gsl::span ep_devices, std::shared_ptr kernel_registry, - const logging::Logger& logger) - : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), + const logging::Logger& logger, + OrtDevice default_device) + : IExecutionProvider(ep->GetName(ep.get()), default_device, std::vector(ep_devices.begin(), ep_devices.end()), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), @@ -216,15 +236,10 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio } OrtDevice PluginExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { - // Use the first host-accessble allocator device if one was registered by the plugin. - // This avoids unnecessary copies between CPU and HOST_ACCESSIBLE memory. - if (!ep_devices_.empty() && ep_devices_[0]->host_accessible_memory_info != nullptr) { - return ep_devices_[0]->host_accessible_memory_info->device; - } - return OrtDevice(); + if (const OrtMemoryInfo* info = TryGetEpMemoryInfo(*ort_ep_, mem_type)) { + return info->device; } - return GetDevice(); + return IExecutionProvider::GetOrtDeviceByMemType(mem_type); } PluginExecutionProvider::~PluginExecutionProvider() { diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index d0e213ba1ca3a..229419b55fde2 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -91,7 +91,8 @@ class PluginExecutionProvider : public IExecutionProvider { explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, std::shared_ptr kernel_registry, - const logging::Logger& logger); + const logging::Logger& logger, + OrtDevice default_device); ~PluginExecutionProvider(); std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index e046ee7067e4b..4683a2740a0dd 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -138,7 +138,8 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { g_test_ort_ep_factory, ep_devices, /*kernel_registry*/ nullptr, - logging_manager.DefaultLogger()); + logging_manager.DefaultLogger(), + OrtDevice()); auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; return result; From e19213b5177de56fc292996e1b0e8ee455b6c96f Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 30 Apr 2026 15:17:16 -0700 Subject: [PATCH 05/16] Fix existing tests Co-authored-by: Copilot --- .../ep_plugin_provider_interfaces.cc | 44 ++++++++----------- .../plugin_ep/ep_plugin_provider_interfaces.h | 3 +- .../test/framework/ep_plugin_provider_test.cc | 3 +- 3 files changed, 21 insertions(+), 29 deletions(-) diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 225be3605a1ac..795ccab4d5999 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -32,18 +32,6 @@ namespace onnxruntime { -static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices); - -// Single source of truth for the OrtEp::GetMemoryInfoByMemType callback (added in EP API -// version 26): version-gated, null-checked. Returns nullptr if the EP did not opt in or if -// the EP returned nullptr to defer to ORT's built-in fallback. -static const OrtMemoryInfo* TryGetEpMemoryInfo(const OrtEp& ep, OrtMemType mem_type) { - if (ep.ort_version_supported >= 26 && ep.GetMemoryInfoByMemType != nullptr) { - return ep.GetMemoryInfoByMemType(&ep, mem_type); - } - return nullptr; -} - // // PluginExecutionProviderFactory // @@ -100,17 +88,10 @@ Status PluginExecutionProviderFactory::CreatePluginExecutionProvider( std::shared_ptr kernel_registry; ORT_RETURN_IF_ERROR(GetPluginEpKernelRegistry(*ort_ep, kernel_registry)); - // The EP is the single source of truth for its default device when it opts in; otherwise - // GetOrtDeviceForPluginEp throws on heterogeneous ep_devices. Resolved here once and - // forwarded to PluginExecutionProvider's IExecutionProvider base. - const OrtMemoryInfo* default_info = TryGetEpMemoryInfo(*ort_ep, OrtMemTypeDefault); - OrtDevice default_device = default_info ? default_info->device : GetOrtDeviceForPluginEp(devices_); - plugin_ep = std::make_unique(UniqueOrtEp(ort_ep, OrtEpDeleter(ep_factory_)), session_options, ep_factory_, devices_, kernel_registry, - *logger.ToInternal(), - default_device); + *logger.ToInternal()); return Status::OK(); } @@ -140,12 +121,26 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // -static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices) { - // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU. +// Single source of truth for the OrtEp::GetMemoryInfoByMemType callback (added in EP API +// version 26): version-gated, null-checked. Returns nullptr if the EP did not opt in or if +// the EP returned nullptr to defer to ORT's built-in fallback. +static const OrtMemoryInfo* TryGetEpMemoryInfo(const OrtEp& ep, OrtMemType mem_type) { + if (ep.ort_version_supported >= 26 && ep.GetMemoryInfoByMemType != nullptr) { + return ep.GetMemoryInfoByMemType(&ep, mem_type); + } + return nullptr; +} + +static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span ep_devices) { + // Get the OrtDevice from the Ep's default memory info. Otherwise, we set it to CPU. // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all. ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. + if (const OrtMemoryInfo* info = TryGetEpMemoryInfo(ep, OrtMemTypeDefault)) { + return info->device; + } + const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info; // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos @@ -187,9 +182,8 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio OrtEpFactory& ep_factory, gsl::span ep_devices, std::shared_ptr kernel_registry, - const logging::Logger& logger, - OrtDevice default_device) - : IExecutionProvider(ep->GetName(ep.get()), default_device, + const logging::Logger& logger) + : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices), std::vector(ep_devices.begin(), ep_devices.end()), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 229419b55fde2..d0e213ba1ca3a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -91,8 +91,7 @@ class PluginExecutionProvider : public IExecutionProvider { explicit PluginExecutionProvider(UniqueOrtEp ep, const OrtSessionOptions& session_options, OrtEpFactory& ep_factory, gsl::span ep_devices, std::shared_ptr kernel_registry, - const logging::Logger& logger, - OrtDevice default_device); + const logging::Logger& logger); ~PluginExecutionProvider(); std::shared_ptr GetKernelRegistry() const override; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 3c2d4f62b8430..5fe77d8c62e09 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -142,8 +142,7 @@ MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = { g_test_ort_ep_factory, ep_devices, /*kernel_registry*/ nullptr, - logging_manager.DefaultLogger(), - OrtDevice()); + logging_manager.DefaultLogger()); auto result = MakeTestOrtEpResult{std::move(ep), ort_ep_raw}; return result; From 24256a4449222a7c0d59b165c1c902c7a127b609 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 30 Apr 2026 15:50:02 -0700 Subject: [PATCH 06/16] Add additional tests for GetMemoryInfoByMemType --- .../test/framework/ep_plugin_provider_test.cc | 153 +++++++++++++++++- 1 file changed, 152 insertions(+), 1 deletion(-) diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 5fe77d8c62e09..457a6a84cb57e 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -4,8 +4,10 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include +#include #include #include +#include #include #include "gsl/gsl" #include "gtest/gtest.h" @@ -73,6 +75,31 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } + + // Optional mem_infos returned by GetMemoryInfoByMemTypeImpl below. nullptr means "defer + // to ORT's built-in fallback" for that mem_type. Tests set these directly. + const OrtMemoryInfo* test_mem_info_default = nullptr; + const OrtMemoryInfo* test_mem_info_cpu_input = nullptr; + const OrtMemoryInfo* test_mem_info_cpu_output = nullptr; + // Counter incremented every time GetMemoryInfoByMemTypeImpl is invoked (used by tests + // that assert the version gate prevents the callback from firing). + mutable std::atomic get_memory_info_by_mem_type_call_count{0}; + + static const OrtMemoryInfo* ORT_API_CALL GetMemoryInfoByMemTypeImpl(const OrtEp* this_ptr, + OrtMemType mem_type) noexcept { + const auto* test_ep = static_cast(this_ptr); + test_ep->get_memory_info_by_mem_type_call_count.fetch_add(1, std::memory_order_relaxed); + switch (mem_type) { + case OrtMemTypeDefault: + return test_ep->test_mem_info_default; + case OrtMemTypeCPUInput: + return test_ep->test_mem_info_cpu_input; + case OrtMemTypeCPUOutput: + return test_ep->test_mem_info_cpu_output; + default: + return nullptr; + } + } }; // This factory doesn't do anything other than implement ReleaseEp(). @@ -116,6 +143,11 @@ OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::Memory return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16); } +OrtMemoryInfo MakeTestOrtMemoryInfo(const char* name, const OrtDevice& device, + OrtMemType mem_type = OrtMemTypeDefault) { + return OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, device, mem_type); +} + struct MakeTestOrtEpResult { std::unique_ptr ep; // the IExecutionProvider wrapping the TestOrtEp gsl::not_null ort_ep; // the wrapped TestOrtEp, owned by `ep` @@ -123,13 +155,20 @@ struct MakeTestOrtEpResult { // Creates an IExecutionProvider that wraps a TestOrtEp. // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. -MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}) { +// `setup` runs on the raw TestOrtEp before the PluginExecutionProvider is constructed -- +// callbacks consulted at construction time (e.g., GetMemoryInfoByMemType seeding +// default_device_) must be configured here. +MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}, + std::function setup = nullptr) { // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. static std::unique_ptr ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); static std::unique_ptr ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); auto ort_ep_raw = std::make_unique().release(); auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); + if (setup) { + setup(*ort_ep_raw); + } auto ort_session_options = Ort::SessionOptions{}; if (ep_devices.empty()) { @@ -365,6 +404,118 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } +// Callback wiring: when an EP implements GetMemoryInfoByMemType, the result seeds +// default_device_ at construction and is returned by GetOrtDeviceByMemType at runtime. +// This is the only test that proves the new API is actually consulted. +TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_SeedsDefaultDevice) { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); + auto mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU HA", ort_device); + + // ep_device intentionally has no device_memory_info -- the legacy path would yield + // OrtDevice() (plain CPU). The callback must override that. + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; + test_ep.test_mem_info_default = &mem_info; + }); + + // Construction-time seeding (default_device_ via GetDevice()) and runtime query must + // both return the callback's answer, not OrtDevice(). + ASSERT_EQ(ep->GetDevice(), ort_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + // At minimum: one construction-time call + one runtime call. + ASSERT_GE(ort_ep->get_memory_info_by_mem_type_call_count.load(), 2); +} + +// Version gate: ort_version_supported < 26 must bypass the callback at both call sites. +// Without this guard ORT would call into a function pointer the EP didn't claim to support. +TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_VersionGateBypassesCallback) { + auto callback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto callback_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU", callback_device); + + // Distinct device_memory_info on the ep_device -- the legacy fallback should win. + auto fallback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto fallback_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp NPU", fallback_device); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), &fallback_mem_info); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.ort_version_supported = 25; // older than the GetMemoryInfoByMemType API version + test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; + test_ep.test_mem_info_default = &callback_mem_info; + }); + + // The callback was set but must not be consulted -- fallback drives default_device_. + ASSERT_EQ(ep->GetDevice(), fallback_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), fallback_device); + ASSERT_EQ(ort_ep->get_memory_info_by_mem_type_call_count.load(), 0); +} + +// Heterogeneous ep_devices: today's GetOrtDeviceForPluginEp throws when ep_devices have +// inconsistent device_memory_info. The callback unblocks that case for plugins that +// natively compose physical devices (e.g., HETERO/AUTO). +TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_HeterogeneousEpDevicesUnblocked) { + auto gpu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto gpu_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU", gpu_device); + auto npu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto npu_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp NPU", npu_device); + + auto representative_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, + OrtDevice::MemType::HOST_ACCESSIBLE); + auto representative_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp HETERO", + representative_device); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), &gpu_mem_info); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), &npu_mem_info); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + // Without the callback, this same ep_devices list throws (covered by the + // InferOrtDeviceFromDeviceMemoryInfo heterogeneity case). With the callback returning a + // representative mem_info, construction succeeds and default_device_ is the callback's + // answer. + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; + test_ep.test_mem_info_default = &representative_mem_info; + }); + + ASSERT_EQ(ep->GetDevice(), representative_device); +} + +// Per-OrtMemType routing: the runtime path threads `mem_type` through to the callback. +// Each OrtMemType must resolve independently, not always to the OrtMemTypeDefault answer. +TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_PerMemTypeRouting) { + auto default_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); + auto default_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp Default", default_device); + auto cpu_input_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE); + auto cpu_input_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp CPUInput", cpu_input_device, + OrtMemTypeCPUInput); + auto cpu_output_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE); + auto cpu_output_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp CPUOutput", cpu_output_device, + OrtMemTypeCPUOutput); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; + test_ep.test_mem_info_default = &default_mem_info; + test_ep.test_mem_info_cpu_input = &cpu_input_mem_info; + test_ep.test_mem_info_cpu_output = &cpu_output_mem_info; + }); + + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), default_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeCPUInput), cpu_input_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeCPUOutput), cpu_output_device); +} + static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, const char* ep_name, const std::unordered_set& ep_node_names, From dc9fa91b45a9f0537c5dfe0a0d61d6dc83a700a4 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Tue, 5 May 2026 18:37:21 -0700 Subject: [PATCH 07/16] Bump since api version to 1.27 for GetMemoryInfoByMemType --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 2 +- .../core/session/plugin_ep/ep_plugin_provider_interfaces.cc | 4 ++-- onnxruntime/test/framework/ep_plugin_provider_test.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 415762c93dbb0..48c342031025d 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2573,7 +2573,7 @@ struct OrtEp { * \note Implementation of this function is optional. If set to NULL, ORT applies its * built-in resolution for every OrtMemType. * - * \since Version 1.26. + * \since Version 1.27. */ ORT_API_T(const OrtMemoryInfo*, GetMemoryInfoByMemType, _In_ const OrtEp* this_ptr, _In_ OrtMemType mem_type); diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 795ccab4d5999..9fdc277b82a19 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -122,10 +122,10 @@ struct PluginEpMetaDefNameFunctor { // // Single source of truth for the OrtEp::GetMemoryInfoByMemType callback (added in EP API -// version 26): version-gated, null-checked. Returns nullptr if the EP did not opt in or if +// version 27): version-gated, null-checked. Returns nullptr if the EP did not opt in or if // the EP returned nullptr to defer to ORT's built-in fallback. static const OrtMemoryInfo* TryGetEpMemoryInfo(const OrtEp& ep, OrtMemType mem_type) { - if (ep.ort_version_supported >= 26 && ep.GetMemoryInfoByMemType != nullptr) { + if (ep.ort_version_supported >= 27 && ep.GetMemoryInfoByMemType != nullptr) { return ep.GetMemoryInfoByMemType(&ep, mem_type); } return nullptr; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 457a6a84cb57e..326695fb006ff 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -430,7 +430,7 @@ TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_SeedsDefaultDevice) { ASSERT_GE(ort_ep->get_memory_info_by_mem_type_call_count.load(), 2); } -// Version gate: ort_version_supported < 26 must bypass the callback at both call sites. +// Version gate: ort_version_supported < 27 must bypass the callback at both call sites. // Without this guard ORT would call into a function pointer the EP didn't claim to support. TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_VersionGateBypassesCallback) { auto callback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); @@ -445,7 +445,7 @@ TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_VersionGateBypassesCall std::vector ep_devices{ort_ep_device.get()}; auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.ort_version_supported = 25; // older than the GetMemoryInfoByMemType API version + test_ep.ort_version_supported = 26; // older than the GetMemoryInfoByMemType API version test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; test_ep.test_mem_info_default = &callback_mem_info; }); From 363d98c5d85af026264bee3941b7fb4496015b53 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 11 May 2026 15:39:06 -0700 Subject: [PATCH 08/16] Remove OrtEp::GetMemoryInfoByMemType --- .../core/session/onnxruntime_ep_c_api.h | 26 --- .../ep_plugin_provider_interfaces.cc | 27 +--- .../plugin_ep/ep_plugin_provider_interfaces.h | 2 - .../test/framework/ep_plugin_provider_test.cc | 153 +----------------- 4 files changed, 4 insertions(+), 204 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index ec1bdc9c39897..76fb7ce93b600 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2552,32 +2552,6 @@ struct OrtEp { */ ORT_API2_STATUS(GetAvailableResource, _In_ const OrtEp* this_ptr, _Out_ OrtResourceCount* available); - /** \brief Returns the OrtMemoryInfo the EP wants used for the given OrtMemType. - * - * Lets an EP declare, per OrtMemType, the memory the runtime should associate with that - * role (default device memory, CPU-side inputs, CPU-side outputs). ORT may consult this - * any time it needs to resolve placement for the EP. - * - * Implementations should be deterministic: a given OrtMemType should always map to the - * same OrtMemoryInfo for the lifetime of the OrtEp. Caching the answer up front is the - * recommended pattern; returned pointers must remain valid while the OrtEp is alive. - * - * Return nullptr for any OrtMemType to defer to ORT's built-in resolution for that type. - * Plugins may opt in selectively. - * - * \param[in] this_ptr The OrtEp instance. - * \param[in] mem_type The memory type to query. - * \return The OrtMemoryInfo the EP wants associated with the given mem_type, or nullptr - * to defer to ORT. - * - * \note Implementation of this function is optional. If set to NULL, ORT applies its - * built-in resolution for every OrtMemType. - * - * \since Version 1.27. - */ - ORT_API_T(const OrtMemoryInfo*, GetMemoryInfoByMemType, _In_ const OrtEp* this_ptr, - _In_ OrtMemType mem_type); - /** \brief Called by ORT when session initialization is complete. * * This provides an opportunity for execution providers to optionally synchronize and diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 17b4828712e68..d8094fe68ea53 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -121,26 +121,12 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // -// Single source of truth for the OrtEp::GetMemoryInfoByMemType callback (added in EP API -// version 27): version-gated, null-checked. Returns nullptr if the EP did not opt in or if -// the EP returned nullptr to defer to ORT's built-in fallback. -static const OrtMemoryInfo* TryGetEpMemoryInfo(const OrtEp& ep, OrtMemType mem_type) { - if (ep.ort_version_supported >= 27 && ep.GetMemoryInfoByMemType != nullptr) { - return ep.GetMemoryInfoByMemType(&ep, mem_type); - } - return nullptr; -} - -static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span ep_devices) { - // Get the OrtDevice from the Ep's default memory info. Otherwise, we set it to CPU. +static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices) { + // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU. // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all. ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. - if (const OrtMemoryInfo* info = TryGetEpMemoryInfo(ep, OrtMemTypeDefault)) { - return info->device; - } - const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info; // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos @@ -183,7 +169,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio gsl::span ep_devices, std::shared_ptr kernel_registry, const logging::Logger& logger) - : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices), + : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), std::vector(ep_devices.begin(), ep_devices.end()), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), @@ -233,13 +219,6 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio } } -OrtDevice PluginExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const { - if (const OrtMemoryInfo* info = TryGetEpMemoryInfo(*ort_ep_, mem_type)) { - return info->device; - } - return IExecutionProvider::GetOrtDeviceByMemType(mem_type); -} - PluginExecutionProvider::~PluginExecutionProvider() { if (ort_ep_ && !api_node_compute_infos_.empty() && ort_ep_->ReleaseNodeComputeInfos != nullptr) { ort_ep_->ReleaseNodeComputeInfos(ort_ep_.get(), api_node_compute_infos_.data(), diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h index 3d2e0baa641f5..ba84403dec8aa 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h @@ -151,8 +151,6 @@ class PluginExecutionProvider : public IExecutionProvider { common::Status ReplayGraph(int graph_annotation_id) override; OrtGraphCaptureNodeAssignmentPolicy GetGraphCaptureNodeAssignmentPolicy() const override; - OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; - private: const logging::Logger& GetEpLoggerOrDefault() const; diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 348a806c7141c..80b638314bad9 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -4,10 +4,8 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include -#include #include #include -#include #include #include "gsl/gsl" #include "gtest/gtest.h" @@ -75,31 +73,6 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } - - // Optional mem_infos returned by GetMemoryInfoByMemTypeImpl below. nullptr means "defer - // to ORT's built-in fallback" for that mem_type. Tests set these directly. - const OrtMemoryInfo* test_mem_info_default = nullptr; - const OrtMemoryInfo* test_mem_info_cpu_input = nullptr; - const OrtMemoryInfo* test_mem_info_cpu_output = nullptr; - // Counter incremented every time GetMemoryInfoByMemTypeImpl is invoked (used by tests - // that assert the version gate prevents the callback from firing). - mutable std::atomic get_memory_info_by_mem_type_call_count{0}; - - static const OrtMemoryInfo* ORT_API_CALL GetMemoryInfoByMemTypeImpl(const OrtEp* this_ptr, - OrtMemType mem_type) noexcept { - const auto* test_ep = static_cast(this_ptr); - test_ep->get_memory_info_by_mem_type_call_count.fetch_add(1, std::memory_order_relaxed); - switch (mem_type) { - case OrtMemTypeDefault: - return test_ep->test_mem_info_default; - case OrtMemTypeCPUInput: - return test_ep->test_mem_info_cpu_input; - case OrtMemTypeCPUOutput: - return test_ep->test_mem_info_cpu_output; - default: - return nullptr; - } - } }; // This factory doesn't do anything other than implement ReleaseEp(). @@ -143,11 +116,6 @@ OrtDevice MakeTestOrtDevice(OrtDevice::DeviceType device_type, OrtDevice::Memory return OrtDevice(device_type, memory_type, /*vendor_id*/ 0xBE57, /*device_id*/ 0, /*alignment*/ 16); } -OrtMemoryInfo MakeTestOrtMemoryInfo(const char* name, const OrtDevice& device, - OrtMemType mem_type = OrtMemTypeDefault) { - return OrtMemoryInfo(name, OrtAllocatorType::OrtDeviceAllocator, device, mem_type); -} - struct MakeTestOrtEpResult { std::unique_ptr ep; // the IExecutionProvider wrapping the TestOrtEp gsl::not_null ort_ep; // the wrapped TestOrtEp, owned by `ep` @@ -155,20 +123,13 @@ struct MakeTestOrtEpResult { // Creates an IExecutionProvider that wraps a TestOrtEp. // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. -// `setup` runs on the raw TestOrtEp before the PluginExecutionProvider is constructed -- -// callbacks consulted at construction time (e.g., GetMemoryInfoByMemType seeding -// default_device_) must be configured here. -MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}, - std::function setup = nullptr) { +MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}) { // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. static std::unique_ptr ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); static std::unique_ptr ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); auto ort_ep_raw = std::make_unique().release(); auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); - if (setup) { - setup(*ort_ep_raw); - } auto ort_session_options = Ort::SessionOptions{}; if (ep_devices.empty()) { @@ -404,118 +365,6 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } -// Callback wiring: when an EP implements GetMemoryInfoByMemType, the result seeds -// default_device_ at construction and is returned by GetOrtDeviceByMemType at runtime. -// This is the only test that proves the new API is actually consulted. -TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_SeedsDefaultDevice) { - auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); - auto mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU HA", ort_device); - - // ep_device intentionally has no device_memory_info -- the legacy path would yield - // OrtDevice() (plain CPU). The callback must override that. - auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); - auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); - std::vector ep_devices{ort_ep_device.get()}; - - auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; - test_ep.test_mem_info_default = &mem_info; - }); - - // Construction-time seeding (default_device_ via GetDevice()) and runtime query must - // both return the callback's answer, not OrtDevice(). - ASSERT_EQ(ep->GetDevice(), ort_device); - ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); - // At minimum: one construction-time call + one runtime call. - ASSERT_GE(ort_ep->get_memory_info_by_mem_type_call_count.load(), 2); -} - -// Version gate: ort_version_supported < 27 must bypass the callback at both call sites. -// Without this guard ORT would call into a function pointer the EP didn't claim to support. -TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_VersionGateBypassesCallback) { - auto callback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); - auto callback_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU", callback_device); - - // Distinct device_memory_info on the ep_device -- the legacy fallback should win. - auto fallback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); - auto fallback_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp NPU", fallback_device); - - auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); - auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), &fallback_mem_info); - std::vector ep_devices{ort_ep_device.get()}; - - auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.ort_version_supported = 26; // older than the GetMemoryInfoByMemType API version - test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; - test_ep.test_mem_info_default = &callback_mem_info; - }); - - // The callback was set but must not be consulted -- fallback drives default_device_. - ASSERT_EQ(ep->GetDevice(), fallback_device); - ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), fallback_device); - ASSERT_EQ(ort_ep->get_memory_info_by_mem_type_call_count.load(), 0); -} - -// Heterogeneous ep_devices: today's GetOrtDeviceForPluginEp throws when ep_devices have -// inconsistent device_memory_info. The callback unblocks that case for plugins that -// natively compose physical devices (e.g., HETERO/AUTO). -TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_HeterogeneousEpDevicesUnblocked) { - auto gpu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); - auto gpu_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp GPU", gpu_device); - auto npu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); - auto npu_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp NPU", npu_device); - - auto representative_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, - OrtDevice::MemType::HOST_ACCESSIBLE); - auto representative_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp HETERO", - representative_device); - - auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); - auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); - auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), &gpu_mem_info); - auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), &npu_mem_info); - std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; - - // Without the callback, this same ep_devices list throws (covered by the - // InferOrtDeviceFromDeviceMemoryInfo heterogeneity case). With the callback returning a - // representative mem_info, construction succeeds and default_device_ is the callback's - // answer. - auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; - test_ep.test_mem_info_default = &representative_mem_info; - }); - - ASSERT_EQ(ep->GetDevice(), representative_device); -} - -// Per-OrtMemType routing: the runtime path threads `mem_type` through to the callback. -// Each OrtMemType must resolve independently, not always to the OrtMemTypeDefault answer. -TEST(PluginExecutionProviderTest, GetMemoryInfoByMemType_PerMemTypeRouting) { - auto default_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); - auto default_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp Default", default_device); - auto cpu_input_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::CPU, OrtDevice::MemType::HOST_ACCESSIBLE); - auto cpu_input_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp CPUInput", cpu_input_device, - OrtMemTypeCPUInput); - auto cpu_output_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::HOST_ACCESSIBLE); - auto cpu_output_mem_info = test_plugin_ep::MakeTestOrtMemoryInfo("TestOrtEp CPUOutput", cpu_output_device, - OrtMemTypeCPUOutput); - - auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); - auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); - std::vector ep_devices{ort_ep_device.get()}; - - auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.GetMemoryInfoByMemType = test_plugin_ep::TestOrtEp::GetMemoryInfoByMemTypeImpl; - test_ep.test_mem_info_default = &default_mem_info; - test_ep.test_mem_info_cpu_input = &cpu_input_mem_info; - test_ep.test_mem_info_cpu_output = &cpu_output_mem_info; - }); - - ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), default_device); - ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeCPUInput), cpu_input_device); - ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeCPUOutput), cpu_output_device); -} - static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, const char* ep_name, const std::unordered_set& ep_node_names, From 8316547e3cbd3722389a696943bf0a714e9336e4 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 11 May 2026 16:05:54 -0700 Subject: [PATCH 09/16] Add OrtEp::GetDefaultMemoryDevice --- .../core/session/onnxruntime_ep_c_api.h | 36 ++++++ .../ep_plugin_provider_interfaces.cc | 17 ++- .../autoep/library/example_plugin_ep/ep.cc | 9 ++ .../autoep/library/example_plugin_ep/ep.h | 3 + .../library/example_plugin_ep/ep_factory.h | 4 + .../test/framework/ep_plugin_provider_test.cc | 122 +++++++++++++++++- 6 files changed, 186 insertions(+), 5 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 76fb7ce93b600..811502b78c27b 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2567,6 +2567,42 @@ struct OrtEp { * \since Version 1.27. */ ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr); + + /** \brief Get the EP's default memory device. + * + * The EP's default memory device identifies the hardware the EP operates on. ORT uses it to: + * - Determine if data copies are needed between EPs (inserting memcpy nodes at EP boundaries) + * - Determine if the EP is CPU-based (which affects synchronization and data transfer decisions) + * - Bind execution streams to the correct device + * + * An OrtMemoryDevice is obtained from an OrtMemoryInfo via `OrtEpApi::MemoryInfo_GetMemoryDevice()`. + * Typically, an EP creates OrtMemoryInfo instances and registers them with its OrtEpDevice(s) via + * `OrtEpApi::EpDevice_AddAllocatorInfo()`. The OrtMemoryDevice returned here must correspond to an + * OrtMemoryInfo registered as an `OrtDeviceAllocator` entry (either `OrtDeviceMemoryType_DEFAULT` or + * `OrtDeviceMemoryType_HOST_ACCESSIBLE`). An OrtMemoryDevice from an `OrtReadOnlyAllocator` entry is + * not accepted as the EP's default/identity device. + * + * The returned pointer must remain valid for the lifetime of the OrtEp instance + * (typically by storing the parent OrtMemoryInfo as a member of the EP). + * + * If this function is not implemented (NULL), or if it sets `device` to NULL, ORT infers + * the default memory device from the `OrtDeviceAllocator` entry with `OrtDeviceMemoryType_DEFAULT` + * registered via `EpDevice_AddAllocatorInfo`. In this fallback case, all OrtEpDevice instances must + * use the same `OrtDeviceMemoryType_DEFAULT` OrtMemoryInfo (or ORT cannot determine which device to + * use). If no such entry is registered, the EP defaults to a CPU memory device. + * + * \param[in] this_ptr The OrtEp instance. + * \param[out] device Set to the EP's default OrtMemoryDevice, or NULL to use the default behavior (described above). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \note Implementation of this function is optional. If set to NULL (not implemented), ORT + * infers the default memory device using the default behavior described above. + * + * \since Version 1.27. + */ + ORT_API2_STATUS(GetDefaultMemoryDevice, _In_ const OrtEp* this_ptr, + _Outptr_result_maybenull_ const OrtMemoryDevice** device); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index d8094fe68ea53..f7691ed5dc02a 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -121,12 +121,21 @@ struct PluginEpMetaDefNameFunctor { // PluginExecutionProvider // -static OrtDevice GetOrtDeviceForPluginEp(gsl::span ep_devices) { - // Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU. - // If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all. +static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span ep_devices) { + // Resolve the EP's default device. If the EP implements GetDefaultMemoryDevice, use its + // answer directly. Otherwise infer from OrtEpDevice.device_memory_info (and enforce that + // all OrtEpDevice instances agree). ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. + if (ep.ort_version_supported >= 27 && ep.GetDefaultMemoryDevice != nullptr) { + const OrtMemoryDevice* memory_device = nullptr; + Ort::ThrowOnError(ep.GetDefaultMemoryDevice(&ep, &memory_device)); + if (memory_device != nullptr) { + return *static_cast(memory_device); + } + } + const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info; // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos @@ -169,7 +178,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio gsl::span ep_devices, std::shared_ptr kernel_registry, const logging::Logger& logger) - : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices), + : IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices), std::vector(ep_devices.begin(), ep_devices.end()), logger), ort_ep_(std::move(ep)), ep_factory_(ep_factory), diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index ca9a296501b04..fecf7ac9a4038 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -185,6 +185,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models Sync = SyncImpl; // optional. can be nullptr + GetDefaultMemoryDevice = GetDefaultMemoryDeviceImpl; // optional. can be nullptr IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -602,6 +603,14 @@ OrtStatus* ORT_API_CALL ExampleEp::SyncImpl(_In_ OrtEp* this_ptr) noexcept { return nullptr; } +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, + _Outptr_ const OrtMemoryDevice** device) noexcept { + const auto* ep = static_cast(this_ptr); + *device = ep->ep_api.MemoryInfo_GetMemoryDevice(ep->factory_.GetDefaultMemoryInfo()); + return nullptr; +} + // // Implementation of ExampleNodeComputeInfo // diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 2ba13658c3364..5dcd9f07bef1f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -105,6 +105,9 @@ class ExampleEp : public OrtEp, public ApiPtrs { static OrtStatus* ORT_API_CALL SyncImpl(_In_ OrtEp* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr, + _Outptr_ const OrtMemoryDevice** device) noexcept; + OrtStatus* CreateEpContextNodes(gsl::span fused_nodes, /*out*/ gsl::span ep_context_nodes); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 91478047afb0a..4bb23f1bddace 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -38,6 +38,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return vendor_id_; } + const OrtMemoryInfo* GetDefaultMemoryInfo() const { + return default_memory_info_; + } + const OrtLogger& default_logger_; // default logger for the EP factory private: diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 80b638314bad9..ebc3583a8556b 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -4,8 +4,10 @@ #include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include +#include #include #include +#include #include #include "gsl/gsl" #include "gtest/gtest.h" @@ -73,6 +75,24 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { constexpr const char* ep_name = "TestOrtEp"; return ep_name; } + + // OrtMemoryDevice returned by GetDefaultMemoryDeviceImpl. nullptr means "defer to ORT". + const OrtMemoryDevice* test_default_memory_device = nullptr; + // If set, the impl returns this status without writing to *device. Used to verify the + // ThrowOnError path in PluginExecutionProvider construction. + OrtStatus* test_default_memory_device_status = nullptr; + mutable std::atomic get_default_memory_device_call_count{0}; + + static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(const OrtEp* this_ptr, + const OrtMemoryDevice** device) noexcept { + const auto* test_ep = static_cast(this_ptr); + test_ep->get_default_memory_device_call_count.fetch_add(1, std::memory_order_relaxed); + if (test_ep->test_default_memory_device_status != nullptr) { + return test_ep->test_default_memory_device_status; + } + *device = test_ep->test_default_memory_device; + return nullptr; + } }; // This factory doesn't do anything other than implement ReleaseEp(). @@ -123,13 +143,20 @@ struct MakeTestOrtEpResult { // Creates an IExecutionProvider that wraps a TestOrtEp. // The TestOrtEp is also exposed so that tests can manipulate its function pointers directly. -MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}) { +// `setup` runs on the raw TestOrtEp before PluginExecutionProvider is constructed -- +// callbacks consulted at construction time (e.g., GetDefaultMemoryDevice seeding +// default_device_) must be configured here. +MakeTestOrtEpResult MakeTestOrtEp(std::vector ep_devices = {}, + std::function setup = nullptr) { // Default OrtHardwareDevice and OrtEpDevice used if the caller does not explicitly provide ep_devices. static std::unique_ptr ort_hw_device = MakeTestOrtHardwareDevice(OrtHardwareDeviceType_CPU); static std::unique_ptr ort_ep_device = MakeTestOrtEpDevice(ort_hw_device.get()); auto ort_ep_raw = std::make_unique().release(); auto ort_ep = UniqueOrtEp(ort_ep_raw, OrtEpDeleter{g_test_ort_ep_factory}); + if (setup) { + setup(*ort_ep_raw); + } auto ort_session_options = Ort::SessionOptions{}; if (ep_devices.empty()) { @@ -365,6 +392,99 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { #endif // !defined(ORT_NO_EXCEPTIONS) } +// When the EP implements GetDefaultMemoryDevice, the result seeds default_device_ at +// construction and is returned by GetOrtDeviceByMemType(OrtMemTypeDefault) at runtime. +TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_SeedsDefaultDevice) { + auto ort_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::HOST_ACCESSIBLE); + + // ep_device intentionally has no device_memory_info -- the legacy path would yield + // OrtDevice() (plain CPU). The callback must override that. + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&ort_device); + }); + + ASSERT_EQ(ep->GetDevice(), ort_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device); + ASSERT_GE(ort_ep->get_default_memory_device_call_count.load(), 1); +} + +// Version gate: ort_version_supported < 27 must bypass the callback. Without this guard +// ORT would call into a function pointer the EP didn't claim to support. +TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_VersionGateBypassesCallback) { + auto callback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + + auto fallback_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto fallback_mem_info = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, + fallback_device, OrtMemTypeDefault); + + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get(), fallback_mem_info.get()); + std::vector ep_devices{ort_ep_device.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.ort_version_supported = 26; // older than the GetDefaultMemoryDevice API version + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&callback_device); + }); + + ASSERT_EQ(ep->GetDevice(), fallback_device); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), fallback_device); + ASSERT_EQ(ort_ep->get_default_memory_device_call_count.load(), 0); +} + +// Heterogeneous ep_devices: GetOrtDeviceForPluginEp throws when ep_devices have +// inconsistent device_memory_info. The callback unblocks that case by letting the EP +// name a representative device directly. +TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_HeterogeneousEpDevicesUnblocked) { + auto gpu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); + auto gpu_mem_info = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, + gpu_device, OrtMemTypeDefault); + auto npu_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT); + auto npu_mem_info = std::make_unique("TestOrtEp NPU", OrtAllocatorType::OrtDeviceAllocator, + npu_device, OrtMemTypeDefault); + + auto representative_device = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, + OrtDevice::MemType::HOST_ACCESSIBLE); + + auto ort_hw_device_gpu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_hw_device_npu = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_NPU); + auto ort_ep_device_gpu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_gpu.get(), gpu_mem_info.get()); + auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), npu_mem_info.get()); + std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; + + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device = static_cast(&representative_device); + }); + + ASSERT_EQ(ep->GetDevice(), representative_device); +} + +#if !defined(ORT_NO_EXCEPTIONS) +// A non-OK status from the callback must propagate out of PluginExecutionProvider +// construction via Ort::ThrowOnError. +TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_StatusErrorThrows) { + auto ort_hw_device = test_plugin_ep::MakeTestOrtHardwareDevice(OrtHardwareDeviceType_GPU); + auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); + std::vector ep_devices{ort_ep_device.get()}; + + const auto& ort_api = *::OrtGetApiBase()->GetApi(ORT_API_VERSION); + OrtStatus* injected_status = ort_api.CreateStatus(ORT_FAIL, "injected failure"); + + ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { + test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; + test_ep.test_default_memory_device_status = injected_status; + }), + Ort::Exception); + // Ort::ThrowOnError releases the status it threw on, so we don't release it here. +} +#endif // !defined(ORT_NO_EXCEPTIONS) + static void LoadModelAndAssignNodesToEp(const ORTCHAR_T* model_path, const char* ep_name, const std::unordered_set& ep_node_names, From 154969b8ccfc1552084890f441c457f9d875aa4f Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 14 May 2026 11:08:38 -0700 Subject: [PATCH 10/16] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- onnxruntime/core/framework/utils.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index 91c6cb1e3d34b..f375bc134ea66 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -57,7 +57,8 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) { // memory is device-only — the CPU cannot read it. // // For the mixed case, src alignment must meet tgt's minimum requirement. -// Alignment 0 means "unspecified" and is treated as compatible with any requirement. +// Alignment 0 on tgt means "no alignment requirement". Alignment 0 on src means "unknown" and +// does not satisfy a non-zero tgt alignment requirement. bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) { const bool src_is_cpu_mem = src.UsesCpuMemory(); const bool tgt_is_cpu_mem = tgt.UsesCpuMemory(); @@ -172,6 +173,7 @@ static void PopulateDeviceFetches(gsl::span fetch_copy_in const std::vector& fetches, std::vector& device_fetches) { ORT_ENFORCE(fetch_copy_info.size() >= fetches.size()); + device_fetches.clear(); device_fetches.reserve(fetches.size()); for (size_t i = 0; i < fetches.size(); ++i) { const auto& src = fetch_copy_info[i].source_device; From 096b76a56f621cec3098d360342ff18002546d2f Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Thu, 14 May 2026 14:44:04 -0700 Subject: [PATCH 11/16] simplify GetDefaultMemoryDevice_StatusErrorThrows --- .../test/framework/ep_plugin_provider_test.cc | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index ebc3583a8556b..0908527721824 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -78,18 +78,12 @@ struct TestOrtEp : ::OrtEp, ApiPtrs { // OrtMemoryDevice returned by GetDefaultMemoryDeviceImpl. nullptr means "defer to ORT". const OrtMemoryDevice* test_default_memory_device = nullptr; - // If set, the impl returns this status without writing to *device. Used to verify the - // ThrowOnError path in PluginExecutionProvider construction. - OrtStatus* test_default_memory_device_status = nullptr; mutable std::atomic get_default_memory_device_call_count{0}; static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(const OrtEp* this_ptr, const OrtMemoryDevice** device) noexcept { const auto* test_ep = static_cast(this_ptr); test_ep->get_default_memory_device_call_count.fetch_add(1, std::memory_order_relaxed); - if (test_ep->test_default_memory_device_status != nullptr) { - return test_ep->test_default_memory_device_status; - } *device = test_ep->test_default_memory_device; return nullptr; } @@ -473,15 +467,12 @@ TEST(PluginExecutionProviderTest, GetDefaultMemoryDevice_StatusErrorThrows) { auto ort_ep_device = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device.get()); std::vector ep_devices{ort_ep_device.get()}; - const auto& ort_api = *::OrtGetApiBase()->GetApi(ORT_API_VERSION); - OrtStatus* injected_status = ort_api.CreateStatus(ORT_FAIL, "injected failure"); - ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices, [&](test_plugin_ep::TestOrtEp& test_ep) { - test_ep.GetDefaultMemoryDevice = test_plugin_ep::TestOrtEp::GetDefaultMemoryDeviceImpl; - test_ep.test_default_memory_device_status = injected_status; - }), - Ort::Exception); - // Ort::ThrowOnError releases the status it threw on, so we don't release it here. + test_ep.GetDefaultMemoryDevice = [](const OrtEp* /*this_ptr*/, const OrtMemoryDevice** /*device*/) noexcept { + return Ort::Status("injected failure", ORT_FAIL).release(); + }; + }), + Ort::Exception); } #endif // !defined(ORT_NO_EXCEPTIONS) From 6aa5b94f5b6a591c2b723e30f124868c2d33f80e Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Tue, 26 May 2026 14:50:52 -0700 Subject: [PATCH 12/16] Add note to GetDefaultMemoryDevice comment header. --- include/onnxruntime/core/session/onnxruntime_ep_c_api.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 811502b78c27b..6c345ea5299e7 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2575,6 +2575,9 @@ struct OrtEp { * - Determine if the EP is CPU-based (which affects synchronization and data transfer decisions) * - Bind execution streams to the correct device * + * If the implementation allows an EP to be created with multiple EpDevices this should return the OrtMemoryDevice + * that ORT should consider as default for this EP instance. + * * An OrtMemoryDevice is obtained from an OrtMemoryInfo via `OrtEpApi::MemoryInfo_GetMemoryDevice()`. * Typically, an EP creates OrtMemoryInfo instances and registers them with its OrtEpDevice(s) via * `OrtEpApi::EpDevice_AddAllocatorInfo()`. The OrtMemoryDevice returned here must correspond to an From b65cf3313ec9e30b43afa25029ed3496f1f1cab5 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Wed, 27 May 2026 16:44:21 -0700 Subject: [PATCH 13/16] Change default device fallback behavior for plugin eps. --- .../core/session/onnxruntime_ep_c_api.h | 10 +++++--- .../ep_plugin_provider_interfaces.cc | 25 ++----------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 6c345ea5299e7..eaa232f684b35 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -2589,10 +2589,12 @@ struct OrtEp { * (typically by storing the parent OrtMemoryInfo as a member of the EP). * * If this function is not implemented (NULL), or if it sets `device` to NULL, ORT infers - * the default memory device from the `OrtDeviceAllocator` entry with `OrtDeviceMemoryType_DEFAULT` - * registered via `EpDevice_AddAllocatorInfo`. In this fallback case, all OrtEpDevice instances must - * use the same `OrtDeviceMemoryType_DEFAULT` OrtMemoryInfo (or ORT cannot determine which device to - * use). If no such entry is registered, the EP defaults to a CPU memory device. + * the default memory device from the first OrtEpDevice's `OrtDeviceAllocator` entry with + * `OrtDeviceMemoryType_DEFAULT` registered via `EpDevice_AddAllocatorInfo`. EPs created against + * multiple OrtEpDevices whose default memory devices differ should implement this function to + * disambiguate; otherwise the first OrtEpDevice's default memory device is used and the others + * are ignored for identity purposes. If no such allocator entry is registered, the EP defaults + * to a CPU memory device. * * \param[in] this_ptr The OrtEp instance. * \param[out] device Set to the EP's default OrtMemoryDevice, or NULL to use the default behavior (described above). diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index f7691ed5dc02a..7f7da56b9df7b 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -123,8 +123,7 @@ struct PluginEpMetaDefNameFunctor { static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span ep_devices) { // Resolve the EP's default device. If the EP implements GetDefaultMemoryDevice, use its - // answer directly. Otherwise infer from OrtEpDevice.device_memory_info (and enforce that - // all OrtEpDevice instances agree). + // answer directly. Otherwise fall back to the first OrtEpDevice's default memory info. ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices. @@ -136,29 +135,9 @@ static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::spandevice_memory_info; - // Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos - bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(), - [mem_a = device_memory_info](const OrtEpDevice* ep_device) { - const OrtMemoryInfo* mem_b = ep_device->device_memory_info; - - if (mem_a == mem_b) { - return true; // Point to the same OrtMemoryInfo instance. - } - - if (mem_a == nullptr || mem_b == nullptr) { - return false; // One is nullptr and the other is not. - } - - // Both non-null but point to different instances. Use operator==. - return *mem_a == *mem_b; - }); - if (!all_match) { - ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name, - "': expected all OrtEpDevice instances to use the same device_memory_info."); - } - return device_memory_info != nullptr ? device_memory_info->device : OrtDevice(); } From 4a4d3610fc63c50ae30315f4cf3a3ad8804a972e Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Fri, 29 May 2026 11:49:43 -0700 Subject: [PATCH 14/16] Update InferOrtDeviceFromDeviceMemoryInfo to reflect new fallback behavior --- onnxruntime/test/framework/ep_plugin_provider_test.cc | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index e0c205542f425..fef185e20f341 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -509,9 +509,8 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), OrtDevice()); } -#if !defined(ORT_NO_EXCEPTIONS) - // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances. - // Should throw an exception on construction of PluginExecutionProvider. + // 2 OrtEpDevice instances with DIFFERENT device_memory_info instances, no GetDefaultMemoryDevice. + // PluginExecutionProvider falls back to the first OrtEpDevice's device_memory_info. { auto ort_device_gpu = test_plugin_ep::MakeTestOrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT); auto ort_memory_info_gpu = std::make_unique("TestOrtEp GPU", OrtAllocatorType::OrtDeviceAllocator, @@ -527,9 +526,9 @@ TEST(PluginExecutionProviderTest, InferOrtDeviceFromDeviceMemoryInfo) { auto ort_ep_device_npu = test_plugin_ep::MakeTestOrtEpDevice(ort_hw_device_npu.get(), ort_memory_info_npu.get()); std::vector ep_devices{ort_ep_device_gpu.get(), ort_ep_device_npu.get()}; - ASSERT_THROW(test_plugin_ep::MakeTestOrtEp(ep_devices), OnnxRuntimeException); + auto [ep, ort_ep] = test_plugin_ep::MakeTestOrtEp(ep_devices); + ASSERT_EQ(ep->GetOrtDeviceByMemType(OrtMemTypeDefault), ort_device_gpu); } -#endif // !defined(ORT_NO_EXCEPTIONS) } // When the EP implements GetDefaultMemoryDevice, the result seeds default_device_ at From 7219dbea62bb8e0578f5405c84d97d216e2f2010 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Fri, 29 May 2026 15:17:08 -0700 Subject: [PATCH 15/16] Address review comments --- onnxruntime/core/framework/utils.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index f375bc134ea66..fe6b9b668b99e 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -172,12 +172,17 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) static void PopulateDeviceFetches(gsl::span fetch_copy_info, const std::vector& fetches, std::vector& device_fetches) { - ORT_ENFORCE(fetch_copy_info.size() >= fetches.size()); + ORT_ENFORCE(fetch_copy_info.size() == fetches.size()); device_fetches.clear(); device_fetches.reserve(fetches.size()); for (size_t i = 0; i < fetches.size(); ++i) { const auto& src = fetch_copy_info[i].source_device; const auto& tgt = fetch_copy_info[i].target_device; + + // The swapped order is intentional. We're checking if a user's fetch buffer (tgt) + // can be reused for EP's output (src) buffer — i.e. CanSourceSatisfyTarget(tgt, src). + // Example: A user provided CPU buffer cannot satisfy a non-CPU host accessible EP output device so a copy should be + // inserted. if (CanSourceSatisfyTarget(tgt, src) && fetches[i].IsAllocated()) { device_fetches.push_back(fetches[i]); } else { From 5313639892c34e1e86c8c535df3abb0e50bc2149 Mon Sep 17 00:00:00 2001 From: Eric Crawford Date: Mon, 1 Jun 2026 14:15:51 -0700 Subject: [PATCH 16/16] Slightly relax assert to account for control flow cases --- onnxruntime/core/framework/utils.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/framework/utils.cc b/onnxruntime/core/framework/utils.cc index fe6b9b668b99e..905cbc077af7c 100644 --- a/onnxruntime/core/framework/utils.cc +++ b/onnxruntime/core/framework/utils.cc @@ -172,7 +172,10 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info) static void PopulateDeviceFetches(gsl::span fetch_copy_info, const std::vector& fetches, std::vector& device_fetches) { - ORT_ENFORCE(fetch_copy_info.size() == fetches.size()); + // fetches is empty on the subgraph path, where control-flow nodes such as Loop pass an empty fetches vector + // and let ExecuteThePlan allocate the outputs; in that case device_fetches stays empty. A partially sized + // fetches vector indicates a bug. + ORT_ENFORCE(fetches.empty() || fetch_copy_info.size() == fetches.size()); device_fetches.clear(); device_fetches.reserve(fetches.size()); for (size_t i = 0; i < fetches.size(); ++i) {