Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
2578d1f
Zero-copy I/O for plugin EPs with HOST_ACCESSIBLE memory
ericcraw Apr 9, 2026
204dcd7
First pass addressing review comments
ericcraw Apr 17, 2026
40b78f5
Simplify CanSourceSatisfyTarget and add test cases
ericcraw Apr 23, 2026
5a83360
Add OrtEp::GetMemoryInfoByMemType for plugin EP memory placement
ericcraw Apr 30, 2026
34421c2
Merge remote-tracking branch 'upstream/main' into host-accessible-all…
ericcraw Apr 30, 2026
e19213b
Fix existing tests
ericcraw Apr 30, 2026
24256a4
Add additional tests for GetMemoryInfoByMemType
ericcraw Apr 30, 2026
627377e
Merge remote-tracking branch 'upstream/main' into host-accessible-all…
ericcraw May 6, 2026
dc9fa91
Bump since api version to 1.27 for GetMemoryInfoByMemType
ericcraw May 6, 2026
c60c3ee
Merge remote-tracking branch 'upstream/main' into host-accessible-all…
ericcraw May 11, 2026
363d98c
Remove OrtEp::GetMemoryInfoByMemType
ericcraw May 11, 2026
8316547
Add OrtEp::GetDefaultMemoryDevice
ericcraw May 11, 2026
154969b
Apply suggestions from code review
ericcraw May 14, 2026
096b76a
simplify GetDefaultMemoryDevice_StatusErrorThrows
ericcraw May 14, 2026
6aa5b94
Add note to GetDefaultMemoryDevice comment header.
ericcraw May 26, 2026
b65cf33
Change default device fallback behavior for plugin eps.
ericcraw May 27, 2026
ee5c5aa
Merge remote-tracking branch 'upstream/main' into host-accessible-all…
ericcraw May 28, 2026
4a4d361
Update InferOrtDeviceFromDeviceMemoryInfo to reflect new fallback beh…
ericcraw May 29, 2026
7219dbe
Address review comments
ericcraw May 29, 2026
5313639
Slightly relax assert to account for control flow cases
ericcraw Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_ep_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2568,6 +2568,47 @@ struct OrtEp {
*/
ORT_API2_STATUS(OnSessionInitializationEnd, _In_ OrtEp* this_ptr);

/** \brief Get the EP's default memory device.
*
* The EP's default memory device identifies the hardware the EP operates on. ORT uses it to:
* - Determine if data copies are needed between EPs (inserting memcpy nodes at EP boundaries)
* - Determine if the EP is CPU-based (which affects synchronization and data transfer decisions)
* - Bind execution streams to the correct device
Comment thread
ericcraw marked this conversation as resolved.
*
* If the implementation allows an EP to be created with multiple EpDevices this should return the OrtMemoryDevice
* that ORT should consider as default for this EP instance.
*
* An OrtMemoryDevice is obtained from an OrtMemoryInfo via `OrtEpApi::MemoryInfo_GetMemoryDevice()`.
* Typically, an EP creates OrtMemoryInfo instances and registers them with its OrtEpDevice(s) via
* `OrtEpApi::EpDevice_AddAllocatorInfo()`. The OrtMemoryDevice returned here must correspond to an
* OrtMemoryInfo registered as an `OrtDeviceAllocator` entry (either `OrtDeviceMemoryType_DEFAULT` or
* `OrtDeviceMemoryType_HOST_ACCESSIBLE`). An OrtMemoryDevice from an `OrtReadOnlyAllocator` entry is
* not accepted as the EP's default/identity device.
*
* The returned pointer must remain valid for the lifetime of the OrtEp instance
* (typically by storing the parent OrtMemoryInfo as a member of the EP).
*
* If this function is not implemented (NULL), or if it sets `device` to NULL, ORT infers
* the default memory device from the first OrtEpDevice's `OrtDeviceAllocator` entry with
* `OrtDeviceMemoryType_DEFAULT` registered via `EpDevice_AddAllocatorInfo`. EPs created against
* multiple OrtEpDevices whose default memory devices differ should implement this function to
* disambiguate; otherwise the first OrtEpDevice's default memory device is used and the others
* are ignored for identity purposes. If no such allocator entry is registered, the EP defaults
* to a CPU memory device.
*
* \param[in] this_ptr The OrtEp instance.
* \param[out] device Set to the EP's default OrtMemoryDevice, or NULL to use the default behavior (described above).
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \note Implementation of this function is optional. If set to NULL (not implemented), ORT
* infers the default memory device using the default behavior described above.
*
* \since Version 1.27.
*/
ORT_API2_STATUS(GetDefaultMemoryDevice, _In_ const OrtEp* this_ptr,
_Outptr_result_maybenull_ const OrtMemoryDevice** device);

/** \brief Release a previously captured graph and its associated resources.
*
* Called when the caller no longer needs the captured graph for the given annotation ID.
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/framework/allocation_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -913,12 +913,7 @@ class PlannerImpl {
ProcessDef(index, node_output);
OrtDevice output_device = exec_provider->GetOrtDeviceByMemType(p_kernel_def->OutputMemoryType(i));
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
// Downstream nodes of certain providers may require a CPU accessible location override
// to make sure the EP does not incur an unnecessary copy.
// We only do it for CPU based EPs. We are not likely to encounter
// non CPU devices here since they are already taken care of by using MemCpy nodes earlier.
// However, we still ignore them.
if (output_device.Type() == OrtDevice::CPU) {
if (output_device.UsesCpuMemory()) {
Comment thread
ericcraw marked this conversation as resolved.
const auto& output_name = node_output->Name();
const auto consumers = graph_viewer_.GetConsumerNodes(output_name);
for (const auto* consumer : consumers) {
Expand Down
112 changes: 81 additions & 31 deletions onnxruntime/core/framework/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,53 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider) {
return provider.GetDevice().Type() == OrtDevice::CPU;
}

// Returns true if src memory can satisfy tgt's requirements without a data copy.
//
// HOST_ACCESSIBLE → DEFAULT is valid: the device can access HOST_ACCESSIBLE memory directly.
// DEFAULT → HOST_ACCESSIBLE is NOT valid: HOST_ACCESSIBLE implies CPU consumers, and DEFAULT
// memory is device-only — the CPU cannot read it.
//
// For the mixed case, src alignment must meet tgt's minimum requirement.
// Alignment 0 on tgt means "no alignment requirement". Alignment 0 on src means "unknown" and
// does not satisfy a non-zero tgt alignment requirement.
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt) {
const bool src_is_cpu_mem = src.UsesCpuMemory();
const bool tgt_is_cpu_mem = tgt.UsesCpuMemory();

// Identical devices are always compatible.
if (src == tgt) {
return true;
}

// Alignment 0 means "unspecified" — treat tgt as compatible with any alignment requirement.
const bool is_alignment_satisfied = tgt.GetAlignment() == 0 ||
src.GetAlignment() >= tgt.GetAlignment();

const bool is_same_physical_device = src.Type() == tgt.Type() &&
src.Vendor() == tgt.Vendor() &&
src.Id() == tgt.Id();

// Both are CPU-accessible (CPU type or HOST_ACCESSIBLE memory).
if (src_is_cpu_mem && tgt_is_cpu_mem) {
// CPU target can read from any CPU or HOST_ACCESSIBLE source, regardless of the source device
if (tgt.Type() == OrtDevice::CPU) {
return is_alignment_satisfied;
}
Comment thread
ericcraw marked this conversation as resolved.
// Both are HOST_ACCESSIBLE on some device: require the same physical device.
return is_same_physical_device && is_alignment_satisfied;
}

// HOST_ACCESSIBLE source can serve a DEFAULT target on the same physical device —
// the device can DMA from HOST_ACCESSIBLE memory directly.
// The reverse (DEFAULT → HOST_ACCESSIBLE) is unsafe: HOST_ACCESSIBLE implies CPU consumers,
// and DEFAULT memory is device-only so the CPU cannot read it.
if (src_is_cpu_mem && !tgt_is_cpu_mem) {
return is_same_physical_device && is_alignment_satisfied;
}

return false;
Comment thread
yuslepukhin marked this conversation as resolved.
}

bool IsMemcpyNode(const Node& node) {
return node.Domain() == kOnnxDomain &&
(node.OpType() == "MemcpyFromHost" || node.OpType() == "MemcpyToHost");
Expand Down Expand Up @@ -117,6 +164,33 @@ const std::string& GetNodeInputProviderType(const SessionState::NodeInfo& info)
return required_provider_type;
}

// Populate device_fetches for the output-copy path.
// When the user pre-allocates a fetch buffer, reuse it directly as the EP's output buffer if
// the user's buffer (tgt) can satisfy the EP's output device (src) requirements — i.e.,
// CanSourceSatisfyTarget(tgt, src). This avoids a post-execution copy.
// Otherwise inserts an empty placeholder for the EP to allocate into.
static void PopulateDeviceFetches(gsl::span<const MLValueCopyInfo> fetch_copy_info,
Comment thread
ericcraw marked this conversation as resolved.
const std::vector<OrtValue>& fetches,
std::vector<OrtValue>& device_fetches) {
ORT_ENFORCE(fetch_copy_info.size() == fetches.size());
device_fetches.clear();
device_fetches.reserve(fetches.size());
for (size_t i = 0; i < fetches.size(); ++i) {
const auto& src = fetch_copy_info[i].source_device;
const auto& tgt = fetch_copy_info[i].target_device;

// The swapped order is intentional. We're checking if a user's fetch buffer (tgt)
// can be reused for EP's output (src) buffer — i.e. CanSourceSatisfyTarget(tgt, src).
// Example: A user provided CPU buffer cannot satisfy a non-CPU host accessible EP output device so a copy should be
// inserted.
if (CanSourceSatisfyTarget(tgt, src) && fetches[i].IsAllocated()) {
Comment thread
ericcraw marked this conversation as resolved.
device_fetches.push_back(fetches[i]);
} else {
device_fetches.push_back({});
}
}
}

// Copy MLValue. Uses DataTransferManager for device copy if necessary. If copy_tensor_pairs/copy_sparse_pairs is provided,
// src/dst pairs that need a device copy are added to copy_pairs so copying can be batches by the DataTransferManager
// implementation for performance reasons.
Expand All @@ -132,8 +206,9 @@ static Status BatchOrCopyMLValue(const SessionState& session_state,
std::vector<IDataTransfer::SrcDstPair>* copy_tensor_pairs = nullptr)
#endif
{
// same device so direct copy
if (copy_info.source_device == copy_info.target_device) {
// No data transfer needed if devices are identical, or the source can satisfy the target
// (HOST_ACCESSIBLE source serving a DEFAULT target on the same physical device).
if (CanSourceSatisfyTarget(copy_info.source_device, copy_info.target_device)) {
target_mlvalue = source_mlvalue;
return Status::OK();
}
Expand Down Expand Up @@ -324,7 +399,7 @@ static bool FinalizeCopyInfoForFeeds(gsl::span<const OrtDevice> feed_locations,
for (size_t i = 0, end = feed_locations.size(); i < end; ++i) {
copy_info[i].source_device = feed_locations[i];

if (copy_info[i].source_device != copy_info[i].target_device) {
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
copy_needed = true;
}
}
Expand All @@ -345,7 +420,7 @@ static bool FinalizeCopyInfoForFetches(gsl::span<const OrtDevice* const>& fetch_
copy_info[i].target_device = *alloc_info;
}

if (copy_info[i].source_device != copy_info[i].target_device) {
if (!CanSourceSatisfyTarget(copy_info[i].source_device, copy_info[i].target_device)) {
copy_needed = true;
}
}
Expand Down Expand Up @@ -652,22 +727,9 @@ ExecuteGraphImpl(const SessionState& session_state,
feeds_to_use = device_feeds;
}

auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();

if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);

for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}

PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches);
p_fetches = &device_fetches;
}

Expand Down Expand Up @@ -808,22 +870,10 @@ common::Status ExecutePartialGraphImpl(const SessionState& session_state, FeedsF
p_feeds = device_feeds;
}

auto num_outputs = fetches.size();
const auto& fetch_copy_info = feeds_fetches_manager.GetFetchesDeviceCopyInfo();

if (device_copy_checks.output_copy_needed == DeviceCopyCheck::Copy) {
// need intermediate fetches. use pre-allocated fetches where possible.
device_fetches.reserve(num_outputs);

for (size_t i = 0; i < num_outputs; ++i) {
if (fetch_copy_info[i].source_device == fetch_copy_info[i].target_device && fetches[i].IsAllocated()) {
device_fetches.push_back(fetches[i]);
} else {
// use temporary value
device_fetches.push_back({});
}
}

PopulateDeviceFetches(fetch_copy_info, fetches, device_fetches);
p_fetches = &device_fetches;
}

Expand Down
5 changes: 5 additions & 0 deletions onnxruntime/core/framework/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ bool ProviderIsCpuBased(const IExecutionProvider& provider);

bool IsMemcpyNode(const Node& node);

// Returns true if src memory can satisfy tgt's requirements without a data copy.
// HOST_ACCESSIBLE -> DEFAULT is valid (device can access HOST_ACCESSIBLE memory directly).
// DEFAULT -> HOST_ACCESSIBLE is NOT valid (CPU cannot read device-only memory).
bool CanSourceSatisfyTarget(const OrtDevice& src, const OrtDevice& tgt);

common::Status CopyOneInputAcrossDevices(const SessionState& session_state, const std::string& input_name,
const OrtValue& orig_mlvalue, OrtValue& new_mlvalue);

Expand Down
38 changes: 13 additions & 25 deletions onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,35 +141,23 @@ struct PluginEpMetaDefNameFunctor {
// PluginExecutionProvider
//

static OrtDevice GetOrtDeviceForPluginEp(gsl::span<const OrtEpDevice* const> ep_devices) {
// Get the OrtDevice from OrtEpDevice.device_memory_info if it is set. Otherwise, we set it to CPU.
// If there are multiple OrtEpDevice instances, the device_memory_info must be consistent for all.
static OrtDevice GetOrtDeviceForPluginEp(const OrtEp& ep, gsl::span<const OrtEpDevice* const> ep_devices) {
// Resolve the EP's default device. If the EP implements GetDefaultMemoryDevice, use its
// answer directly. Otherwise fall back to the first OrtEpDevice's default memory info.

ORT_ENFORCE(!ep_devices.empty()); // Should not be possible to create an EP without OrtEpDevices.

const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;

// Check assertion that all OrtEpDevice instances must have equivalent device_memory_infos
bool all_match = std::all_of(ep_devices.begin() + 1, ep_devices.end(),
[mem_a = device_memory_info](const OrtEpDevice* ep_device) {
const OrtMemoryInfo* mem_b = ep_device->device_memory_info;

if (mem_a == mem_b) {
return true; // Point to the same OrtMemoryInfo instance.
}

if (mem_a == nullptr || mem_b == nullptr) {
return false; // One is nullptr and the other is not.
}

// Both non-null but point to different instances. Use operator==.
return *mem_a == *mem_b;
});
if (!all_match) {
ORT_THROW("Error creating execution provider '", ep_devices[0]->ep_name,
"': expected all OrtEpDevice instances to use the same device_memory_info.");
if (ep.ort_version_supported >= 27 && ep.GetDefaultMemoryDevice != nullptr) {
Comment thread
ericcraw marked this conversation as resolved.
const OrtMemoryDevice* memory_device = nullptr;
Ort::ThrowOnError(ep.GetDefaultMemoryDevice(&ep, &memory_device));
if (memory_device != nullptr) {
return *static_cast<const OrtDevice*>(memory_device);
}
}
Comment thread
ericcraw marked this conversation as resolved.

// If there's no explicit default memory device, choose the first default memory info.
const OrtMemoryInfo* device_memory_info = ep_devices[0]->device_memory_info;

return device_memory_info != nullptr ? device_memory_info->device : OrtDevice();
}

Expand All @@ -189,7 +177,7 @@ PluginExecutionProvider::PluginExecutionProvider(UniqueOrtEp ep, const OrtSessio
gsl::span<const OrtEpDevice* const> ep_devices,
std::shared_ptr<KernelRegistry> kernel_registry,
const logging::Logger& logger)
: IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(ep_devices),
: IExecutionProvider(ep->GetName(ep.get()), GetOrtDeviceForPluginEp(*ep, ep_devices),
std::vector<const OrtEpDevice*>(ep_devices.begin(), ep_devices.end()), logger),
ort_ep_(std::move(ep)),
Comment thread
ericcraw marked this conversation as resolved.
ep_factory_(ep_factory),
Expand Down
9 changes: 9 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C
CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr
GetCompiledModelCompatibilityInfo = GetCompiledModelCompatibilityInfoImpl; // compatibility info for compiled models
Sync = SyncImpl; // optional. can be nullptr
GetDefaultMemoryDevice = GetDefaultMemoryDeviceImpl; // optional. can be nullptr

IGNORE_ORTSTATUS(ort_api.Logger_LogMessage(&logger_,
OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO,
Expand Down Expand Up @@ -602,6 +603,14 @@ OrtStatus* ORT_API_CALL ExampleEp::SyncImpl(_In_ OrtEp* this_ptr) noexcept {
return nullptr;
}

/*static*/
OrtStatus* ORT_API_CALL ExampleEp::GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr,
_Outptr_ const OrtMemoryDevice** device) noexcept {
const auto* ep = static_cast<const ExampleEp*>(this_ptr);
*device = ep->ep_api.MemoryInfo_GetMemoryDevice(ep->factory_.GetDefaultMemoryInfo());
return nullptr;
}

//
// Implementation of ExampleNodeComputeInfo
//
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/test/autoep/library/example_plugin_ep/ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ class ExampleEp : public OrtEp, public ApiPtrs {

static OrtStatus* ORT_API_CALL SyncImpl(_In_ OrtEp* this_ptr) noexcept;

static OrtStatus* ORT_API_CALL GetDefaultMemoryDeviceImpl(_In_ const OrtEp* this_ptr,
_Outptr_ const OrtMemoryDevice** device) noexcept;

OrtStatus* CreateEpContextNodes(gsl::span<const OrtNode*> fused_nodes,
/*out*/ gsl::span<OrtNode*> ep_context_nodes);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs {
return vendor_id_;
}

const OrtMemoryInfo* GetDefaultMemoryInfo() const {
return default_memory_info_;
}

const OrtLogger& default_logger_; // default logger for the EP factory

private:
Expand Down
Loading
Loading