Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
84f0621
Initial pieces for Stream support.
skottmckay Jul 2, 2025
8725d76
Fix build warning
skottmckay Jul 2, 2025
d386a5f
Add ability to get OrtMemoryInfo for session inputs/outputs
skottmckay Jul 3, 2025
6df4f97
Define _In_reads_opt_ on linux
skottmckay Jul 3, 2025
566bdf3
Merge
skottmckay Jul 4, 2025
f68b472
Update CUDA EP with allocator, data transfer and stream implementations.
skottmckay Jul 4, 2025
a8ae38b
Add test with data copy of inputs. Uses a shared allocator and CopyTe…
skottmckay Jul 7, 2025
affc925
CUDA OrtSyncStreamImpl works.
skottmckay Jul 8, 2025
862d6c4
Merge
skottmckay Jul 8, 2025
1d60ee1
Merge remote-tracking branch 'origin/main' into skottmckay/OrtSyncStream
skottmckay Jul 8, 2025
860c692
Fix a couple of issues
skottmckay Jul 8, 2025
e35d8d7
Cleanups. Removed things not currently in use.
skottmckay Jul 9, 2025
9b3378f
Minor refinements
skottmckay Jul 9, 2025
686632d
Fix linux build errors
skottmckay Jul 9, 2025
c1ad6e6
Fix minimal build
skottmckay Jul 9, 2025
3b3217f
Fix x86 build
skottmckay Jul 9, 2025
c1adcff
More CI fixes.
skottmckay Jul 9, 2025
34c0717
Allow for CUDA library to not be found.
skottmckay Jul 10, 2025
6bbd4fd
Update example EP to be more testable.
skottmckay Jul 10, 2025
2b524d0
Fix some tests
skottmckay Jul 10, 2025
f75ac14
Auto-unregister EP in autoep unit test.
skottmckay Jul 10, 2025
91e7d17
Change tolerance for CIs. The values indicate the data was copied.
skottmckay Jul 11, 2025
410faef
Add optional arg for OrtEp to OrtEpFactory.CreateAllocator. Matches C…
skottmckay Jul 11, 2025
4924a7c
Merge remote-tracking branch 'origin/main' into skottmckay/OrtSyncStream
skottmckay Jul 11, 2025
43bc6f5
Merge remote-tracking branch 'origin/main' into skottmckay/OrtSyncStream
skottmckay Jul 14, 2025
02b6c96
Merge
skottmckay Jul 17, 2025
67f0452
Address PR comments.
skottmckay Jul 17, 2025
55c83ab
Fix Linux/Android build issue
skottmckay Jul 17, 2025
b8d965e
Fix minimal and qnn builds
skottmckay Jul 17, 2025
e40d961
Fix x86 build
skottmckay Jul 18, 2025
33e807b
Fix unused parameter warning.
skottmckay Jul 18, 2025
0914cdd
Merge remote-tracking branch 'origin/main' into skottmckay/OrtSyncStream
skottmckay Jul 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -498,17 +498,19 @@ set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep")
set (ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR "${TEST_SRC_DIR}/ep_graph")

set (onnxruntime_shared_lib_test_SRC
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h
${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc)
)

if (NOT onnxruntime_MINIMAL_BUILD)
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
Expand Down
29 changes: 28 additions & 1 deletion include/onnxruntime/core/session/environment.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "core/common/status.h"
#include "core/framework/allocator.h"
#include "core/framework/execution_provider.h"
#include "core/framework/data_transfer_manager.h"
#include "core/platform/device_discovery.h"
#include "core/platform/threadpool.h"

Expand Down Expand Up @@ -140,6 +141,10 @@
OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type,
const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator);
Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type);

// Read-only accessor for the environment's DataTransferManager (data_transfer_mgr_).
const DataTransferManager& GetDataTransferManager() const { return data_transfer_mgr_; }
#endif // !defined(ORT_MINIMAL_BUILD)

// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator
Expand Down Expand Up @@ -185,6 +190,23 @@

using OrtAllocatorUniquePtr = std::unique_ptr<OrtAllocator, std::function<void(OrtAllocator*)>>;

// if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with
// OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator.
// we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in
// shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an
// OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently.
// we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_.
//
// TODO(skottmckay): we could split out the BFCArena implementation so it can be plugged into either an IAllocator

Check warning on line 200 in include/onnxruntime/core/session/environment.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: include/onnxruntime/core/session/environment.h:200: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
// or an OrtAllocator instance to reduce the indirection a little.
// with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the
// IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_.
//
// Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena
// implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a
// cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source.
std::unordered_map<const OrtMemoryInfo*, std::unique_ptr<OrtAllocatorImplWrappingIAllocator>> arena_ort_allocators_;

#if !defined(ORT_MINIMAL_BUILD)
// register EPs that are built into the ORT binary so they can take part in AutoEP selection
// added to ep_libraries
Expand All @@ -207,7 +229,9 @@

std::unique_ptr<EpLibrary> library;
std::vector<std::unique_ptr<OrtEpDevice>> execution_devices;
std::vector<EpFactoryInternal*> internal_factories; // factories that can create IExecutionProvider instances
std::vector<OrtEpFactory*> factories;
std::vector<EpFactoryInternal*> internal_factories; // factories that can create IExecutionProvider instances
std::vector<plugin_ep::DataTransfer*> data_transfers; // data transfer instances for this EP.

private:
EpInfo() = default;
Expand All @@ -223,6 +247,9 @@

// lookup set for internal EPs so we can create an IExecutionProvider directly
std::unordered_set<EpFactoryInternal*> internal_ep_factories_;

DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances

#endif // !defined(ORT_MINIMAL_BUILD)
};

Expand Down
182 changes: 174 additions & 8 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ extern "C" {
#define _Outptr_result_maybenull_
#define _Outptr_result_maybenull_z_
#define _In_reads_(X)
// No-op stand-in for the SAL annotation. It must be function-like (as the neighbouring _In_reads_(X) and
// _Out_writes_opt_(X) stubs are) so that a use such as `_In_reads_opt_(count)` expands to nothing instead of
// leaving a stray `(count)` in the declaration.
#define _In_reads_opt_(X)
#define _Inout_updates_(X)
#define _Out_writes_(X)
#define _Out_writes_opt_(X)
Expand Down Expand Up @@ -322,6 +323,7 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions);
ORT_RUNTIME_CLASS(HardwareDevice);
ORT_RUNTIME_CLASS(EpDevice);
ORT_RUNTIME_CLASS(KeyValuePairs);
ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream.

#ifdef _MSC_VER
typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr;
Expand Down Expand Up @@ -426,10 +428,14 @@ typedef enum OrtAllocatorType {
*/
// Whenever this enum is updated, please also update the MakeKey function in onnxruntime/core/framework/execution_provider.cc
typedef enum OrtMemType {
OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider
OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED
OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED
OrtMemTypeDefault = 0, ///< The default allocator for execution provider
/// Any CPU memory used by non-CPU execution provider
OrtMemTypeCPUInput = -2,
/// CPU accessible memory outputted by non-CPU execution provider, i.e. HOST_ACCESSIBLE
OrtMemTypeCPUOutput = -1,
/// CPU accessible memory allocated by non-CPU execution provider, i.e. HOST_ACCESSIBLE
OrtMemTypeCPU = OrtMemTypeCPUOutput,
/// The default allocator for execution provider
OrtMemTypeDefault = 0,
} OrtMemType;

/** \brief This matches OrtDevice::MemoryType values */
Expand Down Expand Up @@ -1743,7 +1749,7 @@ struct OrtApi {
*/
ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out);

/** \brief Get the id from ::OrtMemoryInfo
/** \brief Get the device id from ::OrtMemoryInfo
*/
ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out);

Expand Down Expand Up @@ -5383,10 +5389,32 @@ struct OrtApi {
* \since Version 1.23
*/
ORT_API2_STATUS(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type,
_In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
_In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type,
_In_ size_t alignment, enum OrtAllocatorType allocator_type,
_Outptr_ OrtMemoryInfo** out);

/** \brief Get the device memory type from ::OrtMemoryInfo
*
* \param[in] ptr The OrtMemoryInfo instance to query.
* \param[out] out The device memory type.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out);

/** \brief Get the vendor id from ::OrtMemoryInfo
*
* \param[in] ptr The OrtMemoryInfo instance to query.
* \param[out] out The vendor id.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out);

/// \name OrtValueInfo
/// @{

Expand Down Expand Up @@ -6067,11 +6095,14 @@ struct OrtApi {
/** \brief Get the OrtMemoryInfo for the device.
*
* \param[in] ep_device The OrtEpDevice instance to query.
* \return A pointer to the OrtMemoryInfo for the device.
* \param[in] memory_type The memory type to return.
* \return A pointer to the OrtMemoryInfo for the device. This may be nullptr if not set.
* If memory_type is OrtDeviceMemoryType_DEFAULT and nullptr is returned the EP uses CPU memory.
*
* \since Version 1.23
*/
ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device);
ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device,
_In_ OrtDeviceMemoryType memory_type);

/** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv.
*
Expand Down Expand Up @@ -6163,6 +6194,141 @@ struct OrtApi {
* \since Version 1.23.
*/
ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out);

/** \brief Get the OrtMemoryInfo for each input of the session.
*
* The memory info can be used to determine where the input tensors are required.
*
* The session must be fully initialized before calling this function as the input locations are not known until
* this has occurred.
*
* \param[in] session The OrtSession instance.
* \param[out] inputs_memory_info Pre-allocated array of size `num_inputs` that will be filled with the
* OrtMemoryInfo* value for each input.
* The order is the same as returned by SessionGetInputName.
* \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session,
_Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info,
_In_ size_t num_inputs);

/** \brief Get the OrtMemoryInfo for each output of the session.
*
* The memory info can be used to determine the device the output tensors are produced on.
* The user can pre-allocate an OrtValue using this information or use IOBinding to keep the data on the device.
* ORT will copy the output to CPU otherwise.
*
* The session must be fully initialized before calling this function as the output locations are not known until
* this has occurred.
*
* \param[in] session The OrtSession instance.
* \param[out] outputs_memory_info Pre-allocated array of size `num_outputs` that will be filled with
* OrtMemoryInfo* values for each output.
* The order is the same as returned by SessionGetOutputName.
* \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session,
_Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info,
_In_ size_t num_outputs);

/** \brief Get the OrtEpDevice (if available) for each input of the session.
*
* An OrtEpDevice will be available if auto EP selection is enabled by calling
* SessionOptionsSetEpSelectionPolicy or SessionOptionsSetEpSelectionPolicyDelegate,
* or if the OrtEpDevice was manually added to the session using SessionOptionsAppendExecutionProvider_V2.
*
* If an OrtEpDevice is not available for the input a nullptr is returned.
*
* The returned OrtEpDevice can be used to create an OrtSyncStream via CreateSyncStreamForEpDevice to asynchronously
* provide input to the inference session Run.
*
* The session must be fully initialized before calling this function as the assigned EPs are not known until
* this has occurred.
*
* \param[in] session The OrtSession instance.
* \param[out] inputs_ep_devices Pre-allocated array of size `num_inputs` that will be filled with
* OrtEpDevice* values for each input.
* The order is the same as returned by SessionGetInputName.
* \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(SessionGetEpDeviceForInputs, _In_ const OrtSession* session,
_Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices,
_In_ size_t num_inputs);

/** \brief Create an OrtSyncStream for the given OrtEpDevice.
*
* The OrtSyncStream can be used to enable asynchronous operations.
* e.g. async usage of CopyTensors to provide input to an OrtSession Run call.
*
* An error code of ORT_NOT_IMPLEMENTED will be returned if the EP does not support OrtSyncStream.
*
* \param[in] ep_device The OrtEpDevice instance to create the sync stream for.
* \param[in] stream_options Options for OrtSyncStream creation. May be nullptr.
* \param[out] stream Output parameter set to the created OrtSyncStream instance.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device,
_In_opt_ const OrtKeyValuePairs* stream_options,
_Outptr_ OrtSyncStream** stream);

/** \brief Get the native handle of the sync stream.
*
* This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams.
*
* \param[in] stream The OrtSyncStream instance to get the handle from.
*
* \returns The native handle of the stream.
*
* \since Version 1.23
*/
ORT_API_T(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream);

ORT_CLASS_RELEASE(SyncStream);

/** \brief Copy OrtValue instances containing Tensors between devices.
*
* The overall copy must be between a single source device and a single destination device. i.e.
* - all src_tensors must have matching OrtMemoryInfo,
* - all dst_tensors must have matching OrtMemoryInfo.
*
* To create the OrtValue instances:
* - Use GetSharedAllocator to get the shared allocator for the OrtMemoryInfo if you need to allocate memory
* on the device.
* - Use CreateTensorAsOrtValue, CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue
* to create an OrtValue containing a tensor depending on whether you have existing data or not, and whether
* you want ORT to free the existing data once it is done with the OrtValue.
*
* \param[in] env The OrtEnv instance to use. The data transfer implementation is provided by an execution provider
* that is registered in this OrtEnv.
* \param[in] src_tensors Array of OrtValue instances containing the source tensors to copy.
* \param[in] dst_tensors Array of OrtValue instances to copy the source tensors to.
* \param[in] stream Optional OrtSyncStream that can be used to perform the copy asynchronously. May be nullptr.
* \param[in] num_tensors The number of tensors to copy. `src_tensors` and `dst_tensors` must each contain this many entries.
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*
* \since Version 1.23
*/
ORT_API2_STATUS(CopyTensors, _In_ const OrtEnv* env,
_In_reads_(num_tensors) const OrtValue* const* src_tensors,
_In_reads_(num_tensors) OrtValue* const* dst_tensors,
_In_opt_ OrtSyncStream* stream,
_In_ size_t num_tensors);
};

/*
Expand Down
Loading
Loading