diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index e51e20e9994dd..33c9e19911557 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -498,17 +498,19 @@ set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep") set (ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR "${TEST_SRC_DIR}/ep_graph") set (onnxruntime_shared_lib_test_SRC - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc) + ) if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 9f3a9eeabff6b..7e49275e59b8b 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -15,6 +15,7 @@ #include "core/common/status.h" #include "core/framework/allocator.h" #include "core/framework/execution_provider.h" 
+#include "core/framework/data_transfer_manager.h" #include "core/platform/device_discovery.h" #include "core/platform/threadpool.h" @@ -140,6 +141,10 @@ class Environment { OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator); Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type); + + const DataTransferManager& GetDataTransferManager() const { + return data_transfer_mgr_; + } #endif // !defined(ORT_MINIMAL_BUILD) // return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator @@ -185,6 +190,23 @@ class Environment { using OrtAllocatorUniquePtr = std::unique_ptr>; + // if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with + // OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator. + // we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in + // shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an + // OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently. + // we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_. + // + // TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator + // or an OrtAllocator instance to reduce the indirection a little. + // with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the + // IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_. + // + // Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena + // implementation directly. They're free to copy BFCArena as it came from TF originally. 
Or we could provide a + // cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source. + std::unordered_map> arena_ort_allocators_; + #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries @@ -207,7 +229,9 @@ class Environment { std::unique_ptr library; std::vector> execution_devices; - std::vector internal_factories; // factories that can create IExecutionProvider instances + std::vector factories; + std::vector internal_factories; // factories that can create IExecutionProvider instances + std::vector data_transfers; // data transfer instances for this EP. private: EpInfo() = default; @@ -223,6 +247,9 @@ class Environment { // lookup set for internal EPs so we can create an IExecutionProvider directly std::unordered_set internal_ep_factories_; + + DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances + #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82e782112974f..69a0f4cd7a487 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -64,6 +64,7 @@ extern "C" { #define _Outptr_result_maybenull_ #define _Outptr_result_maybenull_z_ #define _In_reads_(X) +#define _In_reads_opt_ #define _Inout_updates_(X) #define _Out_writes_(X) #define _Out_writes_opt_(X) @@ -322,6 +323,7 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions); ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); +ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. 
#ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -426,10 +428,14 @@ typedef enum OrtAllocatorType { */ // Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc typedef enum OrtMemType { - OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider - OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeDefault = 0, ///< The default allocator for execution provider + /// Any CPU memory used by non-CPU execution provider + OrtMemTypeCPUInput = -2, + /// CPU accessible memory outputted by non-CPU execution provider, i.e. HOST_ACCESSIBLE + OrtMemTypeCPUOutput = -1, + /// CPU accessible memory allocated by non-CPU execution provider, i.e. HOST_ACCESSIBLE + OrtMemTypeCPU = OrtMemTypeCPUOutput, + /// The default allocator for execution provider + OrtMemTypeDefault = 0, } OrtMemType; /** \brief This matches OrtDevice::MemoryType values */ @@ -1743,7 +1749,7 @@ struct OrtApi { */ ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); - /** \brief Get the id from ::OrtMemoryInfo + /** \brief Get the device id from ::OrtMemoryInfo */ ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); @@ -5383,10 +5389,32 @@ struct OrtApi { * \since Version 1.23 */ ORT_API2_STATUS(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); + /** \brief Get the device memory type from 
::OrtMemoryInfo + * + * \param[in] ptr The OrtMemoryInfo instance to query. + * \param[out] out The device memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out); + + /** \brief Get the vendor id from ::OrtMemoryInfo + * + * \param[in] ptr The OrtMemoryInfo instance to query. + * \param[out] out The vendor id. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out); + /// \name OrtValueInfo /// @{ @@ -6067,11 +6095,14 @@ struct OrtApi { /** \brief Get the OrtMemoryInfo for the device. * * \param[in] ep_device The OrtEpDevice instance to query. - * \return A pointer to the OrtMemoryInfo for the device. + * \param[in] memory_type The memory type to return. + * \return A pointer to the OrtMemoryInfo for the device. This may be nullptr if not set. + * If memory_type is OrtDeviceMemoryType_DEFAULT and nullptr is returned the EP uses CPU memory. * * \since Version 1.23 */ - ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); + ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType memory_type); /** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv. * @@ -6163,6 +6194,141 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Get the OrtMemoryInfo for each input of the session. + * + * The memory info can be used to determine where the input tensors are required. + * + * The session must be fully initialized before calling this function as the input locations are not known until + * this has occurred. 
+ * + * \param[in] session The OrtSession instance. + * \param[out] inputs_memory_info Pre-allocated array of size `num_inputs` that will be filled with the + * OrtMemoryInfo* value for each input. + * The order is the same as returned by SessionGetInputName. + * \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs); + + /** \brief Get the OrtMemoryInfo for each output of the session. + * + * The memory info can be used to determine the device the output tensors are produced on. + * The user can pre-allocate an OrtValue using this information or use IOBinding to keep the data on the device. + * ORT will copy the output to CPU otherwise. + * + * The session must be fully initialized before calling this function as the output locations are not known until + * this has occurred. + * + * \param[in] session The OrtSession instance. + * \param[out] outputs_memory_info Pre-allocated array of size `num_outputs` that will be filled with + * OrtMemoryInfo* values for each output. + * The order is the same as returned by SessionGetOutputName. + * \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs); + + /** \brief Get the OrtEpDevice (if available) for each input of the session. 
+ * + * An OrtEpDevice will be available if auto EP selection is enabled by calling + * SessionOptionsSetEpSelectionPolicy or SessionOptionsSetEpSelectionPolicyDelegate, + * or if the OrtEpDevice was manually added to the session using SessionOptionsAppendExecutionProvider_V2. + * + * If an OrtEpDevice is not available for the input a nullptr is returned. + * + * The returned OrtEpDevice can be used to create an OrtSyncStream via CreateSyncStreamForEpDevice to asynchronously + * provide input to the inference session Run. + * + * The session must be fully initialized before calling this function as the assigned EPs are not known until + * this has occurred. + * + * \param[in] session The OrtSession instance. + * \param[out] inputs_ep_devices Pre-allocated array of size `num_inputs` that will be filled with + * OrtEpDevice* values for each input. + * The order is the same as returned by SessionGetInputName. + * \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetEpDeviceForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_inputs); + + /** \brief Create an OrtSyncStream for the given OrtEpDevice. + * + * The OrtSyncStream can be used to enable asynchronous operations. + * e.g. async usage of CopyTensors to provide input to an OrtSession Run call. + * + * An error code of ORT_NOT_IMPLEMENTED will be returned if the EP does not support OrtSyncStream. + * + * \param[in] ep_device The OrtEpDevice instance to create the sync stream for. + * \param[in] stream_options Options for OrtSyncStream creation. May be nullptr. + * \param[out] stream Output parameter set to the created OrtSyncStream instance. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** stream); + + /** \brief Get the native handle of the sync stream. + * + * This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams. + * + * \param[in] stream The OrtSyncStream instance to get the handle from. + * + * \returns The native handle of the stream. + * + * \since Version 1.23 + */ + ORT_API_T(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream); + + ORT_CLASS_RELEASE(SyncStream); + + /** \brief Copy OrtValue instances containing Tensors between devices. + * + * The overall copy must be between a single source device and a single destination device. i.e. + * - all src_tensors must have matching OrtMemoryInfo, + * - all dst_tensors must have matching OrtMemoryInfo. + * + * OrtValue instances can be created by: + * - Use GetSharedAllocator to get the shared allocator for the OrtMemoryInfo if you need to allocate memory + * on the device. + * - Use CreateTensorAsOrtValue, CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue + * to create an OrtValue containing a tensor depending on whether you have existing data or not, and whether + * you want ORT to free the existing data once it is done with the OrtValue. + * + * \param[in] env The OrtEnv instance to use. The data transfer implementation is provided by an execution provider + * that is registered in this OrtEnv. + * \param[in] src_tensors Array of OrtValue instances containing the source tensors to copy. + * \param[in] dst_tensors Array of OrtValue instances to copy the source tensors to. + * \param[in] stream Optional OrtSyncStream that can be used to perform the copy asynchronously. May be nullptr. + * \param[in] num_tensors The number of tensors to copy. 
The size of `src_tensors` and `dst_tensors` must match. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 5d00ce4940d02..0dc4105ebe855 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -13,12 +13,12 @@ ORT_RUNTIME_CLASS(EpGraphSupportInfo); ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice ORT_RUNTIME_CLASS(NodeComputeContext); -// Opaque class to create an onnxruntime::Stream. Will be filled out in separate PR. -// Adding here for OrtDataTransferImpl as the stream type is required by the IDataTransfer API. -ORT_RUNTIME_CLASS(SyncStream); +ORT_RUNTIME_CLASS(DataTransferImpl); +ORT_RUNTIME_CLASS(SyncNotificationImpl); +ORT_RUNTIME_CLASS(SyncStreamImpl); // struct that an EP implements for IDataTransfer to copy between devices it uses and CPU -typedef struct OrtDataTransferImpl { +struct OrtDataTransferImpl { uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION /** \brief Release the OrtDataTransferImpl instance. @@ -30,7 +30,7 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. */ - ORT_API_T(void, Release, _In_ void* this_ptr); + ORT_API_T(void, Release, _In_ OrtDataTransferImpl* this_ptr); /** \brief Check if the implementation can copy between the source and destination memory devices. * @@ -41,7 +41,7 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. 
*/ - ORT_API_T(bool, CanCopy, _In_ void* this_ptr, + ORT_API_T(bool, CanCopy, _In_ const OrtDataTransferImpl* this_ptr, _In_ const OrtMemoryDevice* src_memory_device, _In_ const OrtMemoryDevice* dst_memory_device); /** \brief Copy tensors from src_tensors to dst_tensors using the provided streams. @@ -60,12 +60,119 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. */ - ORT_API2_STATUS(CopyTensors, _In_ void* this_ptr, + ORT_API2_STATUS(CopyTensors, _In_ OrtDataTransferImpl* this_ptr, _In_reads_(num_tensors) const OrtValue** src_tensors, _In_reads_(num_tensors) OrtValue** dst_tensors, _In_reads_(num_tensors) OrtSyncStream** streams, _In_ size_t num_tensors); -} OrtDataTransferImpl; +}; + +/** \brief Struct that an EP implements for Stream Notifications. + * + * \since Version 1.23. + */ +struct OrtSyncNotificationImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Release the OrtSyncNotificationImpl instance. + * + * This is called by ORT when the OrtSyncNotificationImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API_T(void, Release, _In_ OrtSyncNotificationImpl* this_ptr); + + /** \brief Called by ORT to activate the notification. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Activate, _In_ OrtSyncNotificationImpl* this_ptr); + + /** \brief Wait for a device to device operation to complete. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * \param[in] stream The OrtSyncStream instance that will wait on this notification to be activated. + * + * \since Version 1.23. 
+ */ + ORT_API2_STATUS(WaitOnDevice, _In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* consumer_stream); + + /** \brief Wait for a device to host operation to complete. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(WaitOnHost, _In_ OrtSyncNotificationImpl* this_ptr); +}; + +/** \brief Struct that an EP implements if it wishes to implement Stream support. + * + * This struct provides the overrides for onnxruntime::Stream's virtual methods. + * + * \since Version 1.23. + */ +struct OrtSyncStreamImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Release the OrtSyncStreamImpl instance. + * + * This is called by ORT when the OrtSyncStreamImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API_T(void, Release, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Get the handle of the stream. + * + * This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * \return The handle of the stream. + * + * \since Version 1.23. + */ + ORT_API_T(void*, GetHandle, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Create an OrtSyncNotificationImpl for the OrtSyncStreamImpl instance. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance + * \param[out] notification The new OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateNotification, _In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification); + + /** \brief Flush the stream. + * + * This is called by ORT to flush the stream, ensuring that all operations submitted to the stream are completed. 
+ * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Flush, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Notify the stream that a session run has ended. + * + * This is called by ORT to notify the stream that a session run has ended, allowing the stream to perform any + * necessary cleanup or finalization. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr); +}; struct OrtNodeFusionOptions; typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; @@ -522,6 +629,45 @@ struct OrtEp { * \since Version 1.23. */ ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); + + /** \brief Create an OrtAllocator for the given OrtMemoryInfo for an OrtSession. + * + * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo. + * Any allocator specific options should be read from the session options. + * + * If nullptr OrtEpFactory::CreateAllocator will be used. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr. + * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateAllocator, _In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator); + + /** \brief Create a synchronization stream for the given memory device for an OrtSession. + * + * This is used to create a synchronization stream for the execution provider and is used to synchronize + * operations on the device during model execution. + * Any stream specific options should be read from the session options. 
+ * + * If nullptr OrtEpFactory::CreateSyncStreamForDevice will be used. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for. + * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -683,16 +829,15 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Create an OrtAllocator for the given OrtMemoryInfo. + /** \brief Create an OrtAllocator that can be shared across sessions for the given OrtMemoryInfo. * - * This is used to create an allocator that an execution provider requires. The factory that creates the EP is - * responsible for providing the required allocators. + * The factory that creates the EP is responsible for providing the allocators required by the EP. * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] memory_info The OrtMemoryInfo to create the allocator for. + * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr. * \param[in] allocator_options Optional key-value pairs for allocator options, can be nullptr. - * \param[out] allocator The created OrtAllocator instance. + * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used. 
* * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -700,8 +845,8 @@ struct OrtEpFactory { */ ORT_API2_STATUS(CreateAllocator, _In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryInfo* memory_info, - _In_ const OrtKeyValuePairs* allocator_options, - _Outptr_ OrtAllocator** allocator); + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_result_maybenull_ OrtAllocator** allocator); /** \brief Release an OrtAllocator created by the factory. * @@ -715,13 +860,42 @@ struct OrtEpFactory { * that the execution provider supports. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[out] data_transfer The created OrtDataTransferImpl instance. + * \param[out] data_transfer The created OrtDataTransferImpl instance. Set to nullptr if not required. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer); + + /** \brief Check if execution providers created by the factory are stream aware. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return True if the factory creates execution providers that are stream aware and it implements CreateSyncStreamForDevice. + * + * \since Version 1.23. + */ + ORT_API_T(bool, IsStreamAware, _In_ const OrtEpFactory* this_ptr); + + /** \brief Create a synchronization stream for the given memory device. + * + * This is used to create a synchronization stream for the memory device that can be used for operations outside of + * a session. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for. + * \param[in] stream_options Options for stream creation. May be nullptr. + * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
*/ - ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr, _Outptr_ OrtDataTransferImpl** data_transfer); + ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStreamImpl** stream); }; #ifdef __cplusplus diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 9324fa76ded4f..9b734559776e3 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -18,13 +18,14 @@ #include "core/framework/bfc_arena.h" -using Status = onnxruntime::common::Status; +using namespace onnxruntime; +using Status = common::Status; Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg) { cfg = OrtArenaCfg{}; // reset to default values const auto from_string = [](const std::string& key, const std::string& str, auto& value) -> Status { - if (!onnxruntime::ParseStringWithClassicLocale(str, value).IsOK()) { + if (!ParseStringWithClassicLocale(str, value).IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse value for ", key, " from ", str); } @@ -250,7 +251,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA } ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType type, _Outptr_ OrtMemoryInfo** out) { // map the public enum values to internal OrtDevice values @@ -275,7 +276,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid device type specified."); } - *out = new OrtMemoryInfo(name, type, 
OrtDevice{dt, mt, vendor_id, device_id, alignment}, + *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, narrow(device_id), alignment}, mem_type == OrtDeviceMemoryType_DEFAULT ? OrtMemTypeDefault : OrtMemTypeCPU); return nullptr; } @@ -313,3 +314,13 @@ ORT_API_STATUS_IMPL(OrtApis::CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, ORT_API(void, OrtApis::MemoryInfoGetDeviceType, _In_ const OrtMemoryInfo* info, _Out_ OrtMemoryInfoDeviceType* out) { *out = static_cast(info->device.Type()); } + +ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out) { + *out = static_cast(ptr->device.MemType()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out) { + *out = ptr->device.Vendor(); + return nullptr; +} diff --git a/onnxruntime/core/framework/plugin_ep_stream.cc b/onnxruntime/core/framework/plugin_ep_stream.cc new file mode 100644 index 0000000000000..1eb6ad4162f33 --- /dev/null +++ b/onnxruntime/core/framework/plugin_ep_stream.cc @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/plugin_ep_stream.h" +#include "core/framework/error_code_helper.h" + +namespace onnxruntime { +namespace plugin_ep { + +// TODO: Is num_consumers meaningful? Unused everywhere currently. 
+OrtStatus* Stream::CreateNotificationImpl(size_t /*num_consumers*/, std::unique_ptr& result) { + OrtSyncNotificationImpl* notification_impl = nullptr; + ORT_API_RETURN_IF_ERROR(impl_.CreateNotification(&impl_, ¬ification_impl)); + + result = std::make_unique(*this, *notification_impl, logger_); + return nullptr; +} +} // namespace plugin_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/plugin_ep_stream.h b/onnxruntime/core/framework/plugin_ep_stream.h new file mode 100644 index 0000000000000..2b89e76e16b76 --- /dev/null +++ b/onnxruntime/core/framework/plugin_ep_stream.h @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/common/logging/logging.h" +#include "core/framework/stream_handles.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +// OrtSyncStream is an alias in the C API for onnxruntime::Stream +// OrtSyncNotification is an alias in the C API for onnxruntime::synchronize::Notification +struct OrtSyncStream : public onnxruntime::Stream {}; +struct OrtSyncNotification : onnxruntime::synchronize::Notification {}; + +using onnxruntime::logging::Logger; + +#define LOG_AND_RETURN_IF_ORT_ERROR(fn, logger) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + LOGS(logger, ERROR) << "Plug-in EP Error: [" << OrtApis::GetErrorCode(_status) << "] " \ + << OrtApis::GetErrorMessage(_status); \ + OrtApis::ReleaseStatus(_status); \ + return; \ + } \ + } while (0) + +namespace onnxruntime { +namespace plugin_ep { + +class Notification : public synchronize::Notification { + public: + Notification(Stream& stream, OrtSyncNotificationImpl& impl, const Logger& logger) + : synchronize::Notification(stream), impl_{impl}, logger_{logger} { + } + + static void WaitNotificationOnDevice(onnxruntime::Stream* stream, synchronize::Notification& notification) { + auto* 
this_ptr = static_cast(¬ification); + + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnDevice(&this_ptr->impl_, static_cast(stream)), + this_ptr->logger_); + } + + static void WaitNotificationOnHost(onnxruntime::Stream* /*stream*/, synchronize::Notification& notification) { + auto* this_ptr = static_cast(¬ification); + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnHost(&this_ptr->impl_), this_ptr->logger_); + } + + void Activate() override { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Activate(&impl_), logger_); + } + + ~Notification() override { + impl_.Release(&impl_); + } + + private: + OrtSyncNotificationImpl& impl_; + const Logger& logger_; +}; + +class Stream : public onnxruntime::Stream { + public: + Stream(const OrtDevice& memory_device, OrtSyncStreamImpl& impl, const logging::Logger& logger) + : onnxruntime::Stream(impl.GetHandle(&impl), memory_device), impl_{impl}, logger_{logger} { + } + + std::unique_ptr CreateNotification(size_t num_consumers) override { + std::unique_ptr plugin_notification; + + auto* ort_status = CreateNotificationImpl(num_consumers, plugin_notification); + if (ort_status != nullptr) { + ORT_THROW("Failed to create Notification: [", OrtApis::GetErrorCode(ort_status), "] ", + OrtApis::GetErrorMessage(ort_status)); + } + + return plugin_notification; + } + + void Flush() override { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Flush(&impl_), logger_); + } + + Status CleanUpOnRunEnd() override { + auto* ort_status = impl_.OnSessionRunEnd(&impl_); + return ToStatusAndRelease(ort_status); + } + + WaitNotificationFn GetWaitNotificationFn() const override { + return Notification::WaitNotificationOnDevice; + } + + ~Stream() override { + impl_.Release(&impl_); + } + + private: + OrtSyncStream* ToApiStream() { + return static_cast(static_cast(this)); + } + + OrtStatus* CreateNotificationImpl(size_t num_consumers, std::unique_ptr& result); + + OrtSyncStreamImpl& impl_; + const Logger& logger_; +}; +} // namespace plugin_ep +} // namespace onnxruntime + 
+#undef LOG_AND_RETURN_IF_ORT_ERROR diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f00bf51ae143d..bf6e8c0c7e5cc 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -7,8 +7,9 @@ #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_options.h" -#include #include +#include +#include #include @@ -16,6 +17,7 @@ #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #include "core/providers/cuda/cuda_allocator.h" +#include "core/providers/cuda/cuda_stream_handle.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" @@ -307,13 +309,366 @@ CUDA_Provider* GetProvider() { } // namespace onnxruntime -#include "core/framework/error_code_helper.h" +// +// Plug-in EP infrastructure +// + +#include "core/session/abi_devices.h" #include "onnxruntime_config.h" // for ORT_VERSION +struct ErrorHelper { + static const OrtApi* ort_api; + + static OrtStatus* ToOrtStatus(const Status& status) { + if (status.IsOK()) { + return nullptr; // no error + } + + return ort_api->CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } +}; + +const OrtApi* ErrorHelper::ort_api = nullptr; + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_STATUS_NOTOK(fn) \ + do { \ + Status _status = (fn); \ + if (!_status.IsOK()) { \ + return ErrorHelper::ToOrtStatus(_status); \ + } \ + } while (0) + +#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_STATUS_NOTOK(CUDA_CALL(expr)) + +struct CudaOrtAllocator : OrtAllocator { + CudaOrtAllocator(const OrtMemoryInfo* mem_info, const OrtApi& api) : memory_info_{mem_info} { + version = 
ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + + const OrtEpApi& ep_api = *api.GetEpApi(); + const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); + uint32_t device_id = ep_api.MemoryDevice_GetDeviceId(mem_device); + const char* name = nullptr; + auto* status = api.MemoryInfoGetName(mem_info, &name); + static_cast(status); // GetName never fails + + if (ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + allocator_ = std::make_unique(device_id, name); + } else { + allocator_ = std::make_unique(device_id, name); + } + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + return impl.allocator_->Alloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& impl = *static_cast(this_); + impl.allocator_->Free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const CudaOrtAllocator& impl = *static_cast(this_); + return impl.memory_info_; + } + + private: + const OrtMemoryInfo* memory_info_; + std::unique_ptr allocator_; +}; + +struct CudaDataTransferImpl : OrtDataTransferImpl { + CudaDataTransferImpl(const OrtApi& ort_api_in) + : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + const auto& impl = *static_cast(this_ptr); + + // logic copied from GPUDataTransfer::CanCopy + OrtMemoryInfoDeviceType src_type = 
impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); + auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); + auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); + + if ((src_type == OrtDevice::GPU && src_vendor_id != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_vendor_id != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must be GPU to GPU or between GPU and CPU + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); + } + + static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + bool need_stream_sync = false; + + for (size_t idx = 0; idx < num_tensors; ++idx) { + const OrtValue* src_tensor = src_tensors[idx]; + OrtValue* dst_tensor = dst_tensors[idx]; + OrtSyncStream* stream = streams ? 
streams[idx] : nullptr; + + const OrtMemoryDevice *src_device = nullptr, *dst_device = nullptr; + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensor, &src_device)); + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensor, &dst_device)); + + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensor, &bytes)); + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensor, &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensor, &dst_data)); + + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + + const bool src_is_gpu_default = src_type == OrtMemoryInfoDeviceType_GPU && + src_mem_type == OrtDeviceMemoryType_DEFAULT; + const bool dst_is_gpu_default = dst_type == OrtMemoryInfoDeviceType_GPU && + dst_mem_type == OrtDeviceMemoryType_DEFAULT; + + cudaStream_t cuda_stream = nullptr; + if (stream) { + cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + } + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. 
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } else { + // copy from pinned or non-pinned CPU memory to GPU + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, cuda_stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not + // have completed. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } + } else if (src_is_gpu_default) { + // copying from GPU to CPU memory, this is blocking + + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } + } else { + // copying between CPU accessible memory + + if (dst_data != src_data) { + if (cuda_stream) { + if (src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + } + + memcpy(dst_data, src_data, bytes); + } + } + } + + if (need_stream_sync) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + + return nullptr; + } + + static void ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { + // no-op as we have a single shared instance in OrtEpFactory which is returned from CreateDataTransferImpl, and is + // owned by and freed by the factory. 
+ } + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct CudaSyncNotificationImpl : OrtSyncNotificationImpl { + static OrtStatus* Create(cudaStream_t stream, const OrtApi& ort_api, + std::unique_ptr& notification) { + notification.reset(new CudaSyncNotificationImpl(stream, ort_api)); // can't use make_unique with private ctor + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(¬ification->event_, cudaEventDisableTiming)); + + return nullptr; + } + + static void ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; + } + + static OrtStatus* WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* consumer_stream) noexcept { + auto& impl = *static_cast(this_ptr); + + // setup the consumer stream to wait on our event. 
+ void* consumer_handle = impl.ort_api.SyncStream_GetHandle(consumer_stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(consumer_handle), impl.event_)); + + return nullptr; + } + + static OrtStatus* WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; + } + + ~CudaSyncNotificationImpl() { + cudaEventDestroy(event_); + } + + private: + CudaSyncNotificationImpl(cudaStream_t stream, const OrtApi& ort_api_in) + : stream_{stream}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + } + + cudaStream_t& stream_; + cudaEvent_t event_; + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct CudaSyncStreamImpl : OrtSyncStreamImpl { + CudaSyncStreamImpl(cudaStream_t&& stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_cuda_stream, + const OrtApi& ort_api_in) + : stream_{ + stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, /*own*/ true, + /*external_cudnn_handle*/ nullptr, + /*external_cublas_handle*/ nullptr, + // ep_info is used by GetResource which seems to be a somewhat ugly way to make arbitrary info that is + // unrelated to the stream available to a custom op. + // avoiding adding GetResource to OrtSyncStreamImpl as we should have a cleaner setup for custom ops, + // so this argument value isn't used and doesn't matter. 
+ /*ep_info*/ CUDAExecutionProviderInfo{}}, + ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + static void ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static void* GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.stream_.GetHandle(); + } + + static OrtStatus* CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification_impl) noexcept { + auto& impl = *static_cast(this_ptr); + *notification_impl = nullptr; + + std::unique_ptr notification; + cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + + RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + *notification_impl = notification.release(); + + return nullptr; + } + + static OrtStatus* FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + impl.stream_.Flush(); + + return nullptr; + } + + static OrtStatus* OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + RETURN_IF_STATUS_NOTOK(impl.stream_.CleanUpOnRunEnd()); + + return nullptr; + } + + private: + // this is a little onion-ish as CudaStream is a onnxruntime::Stream and this is an OrtSyncStreamImpl that will be + // used via plugin_ep::Stream, which is also an onnxruntime::Stream. in a 'real' plugin EP implementation + // CudaStream would go away and the logic it has would be implemented directly here. + CudaStream stream_; + const OrtApi& ort_api; +}; + // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. 
struct CudaEpFactory : OrtEpFactory { - CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { - ort_version_supported = ORT_API_VERSION; + using MemoryInfoUniquePtr = std::unique_ptr>; + + CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in}, + ep_api{*ort_api_in.GetEpApi()}, + data_transfer_impl{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; @@ -321,16 +676,24 @@ struct CudaEpFactory : OrtEpFactory { GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->ep_name.c_str(); + const auto& factory = *static_cast(this_ptr); + return factory.ep_name.c_str(); } static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor.c_str(); + const auto& factory = *static_cast(this_ptr); + return factory.vendor.c_str(); } static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { @@ -349,15 +712,84 @@ struct CudaEpFactory : OrtEpFactory { size_t max_ep_devices, size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; - auto* factory = static_cast(this_ptr); + auto& factory = *static_cast(this_ptr); + + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory.CreateMemoryInfoForDevices(num_cuda_devices)); + /* in theory we can match on the LUID in the OrtHardwareDevice metadata, but that requires the CUDA Driver API + std::vector device_to_luid; + device_to_luid.resize(num_cuda_devices); + + for (int i = 0; i < num_cuda_devices; ++i) { + CUdevice 
device; + cuDeviceGet(&device, i); + + char luid[8]; + unsigned int nodeMask; + if (cuDeviceGetLuid(luid, &nodeMask, device) == CUDA_SUCCESS) { + device_to_luid[i] = *reinterpret_cast(luid); + } + } + */ + + int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(&device) == 0x10de) { - ORT_API_RETURN_IF_ERROR( - factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); + if (factory.ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory.ort_api.HardwareDevice_VendorId(&device) == 0x10de) { + /* ideally we'd match on LUID here + for now we use an incrementing device id. could be a mismatch if you have multiple different CUDA GPUs. + alternative is to limit to one device only. + + // find the device id. On Windows we have the LUID in the OrtHardwareDevice metadata. + const OrtKeyValuePairs* metadata = factory.ort_api.HardwareDevice_Metadata(&device); + const char* luid_str = factory.ort_api.GetKeyValue(metadata, "LUID"); + + if (!luid_str && num_devices > 1) { + // if there's no LUID we can't match device + return factory.ort_api.CreateStatus(ORT_EP_FAIL, "OrtHardwareDevice does not have LUID"); + } + + char* luid_end = nullptr; + uint64_t luid = std::strtoull(luid_str, &luid_end, 10); + for (; device_id < num_cuda_devices; ++device_id) { + if (device_to_luid[device_id] == luid) { + break; + } + } + + if (device_id == num_cuda_devices) { + std::string msg("Could not match LUID to a CUDA device. 
LUID="); + msg += luid_str; + + return factory.ort_api.CreateStatus(ORT_EP_FAIL, msg.c_str()); + } + */ + + // create the EP options and add the device id + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory.ort_api.CreateKeyValuePairs(&ep_options); + factory.ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); + + // create the OrtEpDevice + OrtEpDevice* ep_device = nullptr; + RETURN_IF_ERROR(factory.ort_api.GetEpApi()->CreateEpDevice(&factory, &device, ep_metadata, ep_options, + &ep_device)); + + factory.ort_api.ReleaseKeyValuePairs(ep_options); + + const OrtMemoryInfo* gpu_mem_info = factory.gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* host_accessible_mem_info = factory.host_accessible_memory_infos[device_id].get(); + + RETURN_IF_ERROR(factory.ep_api.EpDevice_AddAllocatorInfo(ep_device, gpu_mem_info)); + RETURN_IF_ERROR(factory.ep_api.EpDevice_AddAllocatorInfo(ep_device, host_accessible_mem_info)); + + ep_devices[num_ep_devices++] = ep_device; + + ++device_id; } } @@ -378,10 +810,124 @@ struct CudaEpFactory : OrtEpFactory { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + // this function is free to return the same allocator instance for all calls and make ReleaseAllocator a no-op + // e.g. allocator instance is in unique_ptr in the OrtEpFactory instance. + // ORT will create a shared allocator in the environment and the user can choose to use it in an inference session. + // Otherwise ORT will create an allocator when adding the EP to an inference session. 
+ auto& factory = *static_cast(this_ptr); + + auto cuda_allocator = std::make_unique(memory_info, factory.ort_api); + *allocator = cuda_allocator.release(); + + return nullptr; + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { + delete static_cast(allocator); + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + *data_transfer = &factory.data_transfer_impl; + + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + + // the OrtEpFactory could have a cache of stream instances if it wants to avoid creating a new one on every + // call. the CudaStreamSyncImpl::Release could return the instance to the cache. + cudaStream_t stream = nullptr; + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + // Currently this API is only used for creating a stream that is used outside of a session, as we're using the + // 'real' CUDA IExecutionProvider implementation for the EP. Due to that we need to connect it up to an internal + // onnxruntime::Stream that has the correct settings for the session. + // We do that externally by passing the cudaStream_t in via the "user_compute_stream" provider option. 
+ // + // For use within an inference session in a completely plugin EP we'd need the session's CPU allocator to be + // available, as well as for relevant EP instance specific options such as whether graph capture is enabled + // to be applied. + + const OrtDevice* ort_device = static_cast(memory_device); + // This OrtSyncStream isn't used for running the inference, so we don't need a CPU allocator for + // CPU scratch buffers to be created by operator kernels. + AllocatorPtr null_allocator; + + auto impl = std::make_unique(std::move(stream), *ort_device, nullptr, + /*release_cpu_buffer_on_cuda_stream*/ true, + factory.ort_api); + *ort_stream = impl.release(); + + return nullptr; + } + + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { + gpu_memory_infos.reserve(num_devices); + host_accessible_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CUDA", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + + gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CUDA host accessible", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + + host_accessible_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + } + + return nullptr; + } + const OrtApi& ort_api; + const OrtEpApi& ep_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name uint32_t 
vendor_id{0x1414}; // Microsoft vendor ID + + // per-device memory info + std::vector gpu_memory_infos; + std::vector host_accessible_memory_infos; + + // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to + // CreateDataTransferImpl. + CudaDataTransferImpl data_transfer_impl; + + CudaEpFactory(const CudaEpFactory&) = delete; + CudaEpFactory& operator=(const CudaEpFactory&) = delete; + + CudaEpFactory(CudaEpFactory&&) = default; + CudaEpFactory& operator=(CudaEpFactory&&) = default; }; extern "C" { @@ -391,6 +937,7 @@ extern "C" { OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + ErrorHelper::ort_api = ort_api; // setup our error helper // Factory could use registration_name or define its own EP name. std::unique_ptr factory = std::make_unique(*ort_api); diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 785177ce37788..6dcb64940dec6 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -133,6 +133,12 @@ struct QnnEpFactory : OrtEpFactory { GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. @@ -201,6 +207,43 @@ struct QnnEpFactory : OrtEpFactory { // no-op as we never create an EP here. 
} + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + // we don't add allocator info to the OrtEpDevice we return so this should never be called. + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "QNN EP factory does not support CreateAllocator."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* /*allocator*/) noexcept { + // we don't support CreateAllocator so this should never be called. + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + *ort_stream = nullptr; + + // should never be called as IsStreamAware returns false + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "QNN EP factory does not support CreateSyncStreamForDevice."); + } + const OrtApi& ort_api; const std::string ep_name; // EP name const std::string ep_vendor{"Microsoft"}; // EP vendor name diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 67253a83ab490..50469126996b2 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -64,7 +64,11 @@ struct OrtEpDevice { OrtKeyValuePairs ep_metadata; OrtKeyValuePairs ep_options; - OrtEpFactory* ep_factory; + 
mutable OrtEpFactory* ep_factory; const OrtMemoryInfo* device_memory_info{nullptr}; const OrtMemoryInfo* host_accessible_memory_info{nullptr}; + + // the user provides const OrtEpDevice instances, but the OrtEpFactory API takes non-const instances for all + // get/create methods to be as flexible as possible. this helper converts to a non-const factory instance. + OrtEpFactory* GetMutableFactory() const { return ep_factory; } }; diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 8acf1df06b46d..493c0a106074c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -9,6 +9,7 @@ #include "core/framework/allocator.h" #include "core/framework/allocator_utils.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_data_transfer.h" #include "core/graph/constants.h" #include "core/graph/op.h" #include "core/platform/device_discovery.h" @@ -163,8 +164,8 @@ Status Environment::UnregisterAllocator(const OrtMemoryInfo& mem_info) { Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found) { auto it = FindExistingAllocator(shared_allocators_, mem_info); - - if (error_if_not_found && it == shared_allocators_.end()) { + const bool found_shared_allocator = it != shared_allocators_.end(); + if (!found_shared_allocator && error_if_not_found) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No allocator for this device has been registered for sharing."); } @@ -174,12 +175,18 @@ Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool // so when we remove that from shared_allocators_ we release the OrtAllocator instance. 
// shared_ort_allocators_ are internal only so never an error if there's no match - auto it2 = FindExistingAllocator(shared_ort_allocators_, mem_info); - if (it2 != shared_ort_allocators_.end()) { + if (auto it2 = FindExistingAllocator(shared_ort_allocators_, mem_info); it2 != shared_ort_allocators_.end()) { shared_ort_allocators_.erase(it2); } - shared_allocators_.erase(it); + // also remove an arena wrapped allocator from an EP if the user called CreateSharedAllocator to create one + if (auto it3 = arena_ort_allocators_.find(&mem_info); it3 != arena_ort_allocators_.end()) { + arena_ort_allocators_.erase(it3); + } + + if (found_shared_allocator) { + shared_allocators_.erase(it); + } return Status::OK(); } @@ -420,7 +427,10 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ provider_type + " is not implemented in CreateAndRegisterAllocatorV2()"}; } -Environment::~Environment() = default; +Environment::~Environment() { + // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed + shared_allocators_.clear(); +} Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) { std::lock_guard lock{mutex_}; @@ -444,6 +454,29 @@ Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocat } #if !defined(ORT_MINIMAL_BUILD) + +// +// Plugin EP support +// + +namespace { +Status CreateDataTransferForFactory(OrtEpFactory& ep_factory, + std::unique_ptr& data_transfer) { + OrtDataTransferImpl* data_transfer_impl = nullptr; + OrtStatus* status = ep_factory.CreateDataTransfer(&ep_factory, &data_transfer_impl); + if (status != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + } + + if (data_transfer_impl != nullptr) { + data_transfer = std::make_unique(*data_transfer_impl); + } + + return Status::OK(); +} +} // namespace + Status 
Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, const std::vector& internal_factories) { @@ -477,6 +510,16 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra } } + for (auto* factory : ep_info->factories) { + std::unique_ptr data_transfer; + ORT_RETURN_IF_ERROR(CreateDataTransferForFactory(*factory, data_transfer)); + + if (data_transfer) { + ep_info->data_transfers.push_back(data_transfer.get()); // store so we can unregister in the unload + ORT_RETURN_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::move(data_transfer))); + } + } + for (const auto& internal_factory : internal_factories) { internal_ep_factories_.insert(internal_factory); } @@ -534,6 +577,10 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam // something goes wrong in any of the following steps.. ep_libraries_.erase(ep_name); + for (auto* data_transfer : ep_info->data_transfers) { + ORT_RETURN_IF_ERROR(data_transfer_mgr_.UnregisterDataTransfer(data_transfer)); + } + for (auto* internal_factory : ep_info->internal_factories) { internal_ep_factories_.erase(internal_factory); } @@ -590,7 +637,22 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator_out, bool replace_existing) { - // if we're replacing an existing allocator we don't care who added it + // NOTE: memory_info is guaranteed to come from the OrtEpDevice when this is called + + // we need to remove from shared_ort_allocators_ first in case the entry in shared_allocators_ owns the pointer in + // shared_ort_allocators_. 
+ if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); + it != shared_ort_allocators_.end()) { + shared_ort_allocators_.erase(it); + } + + // if a previous call created an arena wrapped allocator for the EP's memory_info we also need to remove that + if (auto it = arena_ort_allocators_.find(&memory_info); it != arena_ort_allocators_.end()) { + arena_ort_allocators_.erase(it); + } + + // we only want one shared allocator for an OrtDevice in the shared_allocators_ so that it's deterministic which + // one will be used for an inference session. ignore the name so that is the case. if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); it != shared_allocators_.end()) { if (!replace_existing) { @@ -600,12 +662,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, shared_allocators_.erase(it); } - // clear out any exact match in the internal shared allocators - if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); - it != shared_ort_allocators_.end()) { - shared_ort_allocators_.erase(it); - } - OrtAllocator* allocator = nullptr; auto* ort_status = ep_device.ep_factory->CreateAllocator(ep_device.ep_factory, &memory_info, allocator_options, &allocator); @@ -613,10 +669,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, return ToStatusAndRelease(ort_status); } - if (allocator_out != nullptr) { - *allocator_out = allocator; - } - auto ort_allocator = OrtAllocatorUniquePtr(allocator, [&ep_device](OrtAllocator* allocator) { ep_device.ep_factory->ReleaseAllocator(ep_device.ep_factory, allocator); @@ -625,15 +677,13 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, AllocatorPtr shared_allocator; if (allocator_type == OrtArenaAllocator) { - // wrap with arena + // wrap with ORT arena OrtArenaCfg arena_cfg; if (allocator_options != nullptr) { auto status = 
OrtArenaCfg::FromKeyValuePairs(*allocator_options, arena_cfg); } - // pending Stream support being added to plugin EP API in separate PR - // ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); - bool stream_aware_arena = false; + bool stream_aware_arena = ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); AllocatorCreationInfo alloc_creation_info{ [&ort_allocator](int) -> std::unique_ptr { @@ -646,13 +696,27 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, }; shared_allocator = CreateAllocator(alloc_creation_info); + + // need an OrtAllocator to return to the user so we need yet another layer. + // we pass in a copy of the AllocatorPtr (which is a shared_ptr) in order to maintain the overall condition that + // shared_allocators_ is the main owner of the allocator and the last place we delete from when removing + // from shared_ort_allocators_, arena_ort_allocators_ and shared_allocators_. + auto arena_ort_allocator = std::make_unique(AllocatorPtr(shared_allocator)); + allocator = arena_ort_allocator.get(); + + // store the entry using the EPs memory info for easier lookup when removing + arena_ort_allocators_.insert({&memory_info, std::move(arena_ort_allocator)}); } else { + shared_ort_allocators_.insert(allocator); shared_allocator = std::make_shared(std::move(ort_allocator)); } - shared_ort_allocators_.insert(allocator); shared_allocators_.push_back(std::move(shared_allocator)); + if (allocator_out != nullptr) { + *allocator_out = allocator; + } + return Status::OK(); } @@ -702,13 +766,13 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u instance.internal_factories = internal_factories; ORT_RETURN_IF_ERROR(instance.library->Load()); - const auto& factories = instance.library->GetFactories(); + instance.factories = instance.library->GetFactories(); // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. 
// the set of hardware devices is static so this can also be static. const static std::vector sorted_devices = SortDevicesByType(); - for (auto* factory_ptr : factories) { + for (auto* factory_ptr : instance.factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:", instance.library->RegistrationName()); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index ad965845041f7..6fdd47c537cb3 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -12,6 +12,7 @@ #include "core/framework/ort_value.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/tensor.h" #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index a0904c32011a7..77528565eced7 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -45,6 +45,34 @@ struct ForwardToFactory { session_options, logger, ep); } + static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_ OrtAllocator** allocator) noexcept { + return static_cast(this_ptr)->CreateAllocator(memory_info, allocator_options, allocator); + } + + static void ORT_API_CALL ReleaseAllocator(_In_ OrtEpFactory* this_ptr, _In_ OrtAllocator* allocator) noexcept { + return static_cast(this_ptr)->ReleaseAllocator(allocator); + } + + static OrtStatus* ORT_API_CALL CreateDataTransfer(_In_ OrtEpFactory* this_ptr, + _Outptr_opt_ OrtDataTransferImpl** data_transfer) noexcept { + return static_cast(this_ptr)->CreateDataTransfer(data_transfer); + } + + static bool ORT_API_CALL IsStreamAware(_In_ const OrtEpFactory* this_ptr) noexcept { + return 
static_cast(this_ptr)->IsStreamAware(); + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDevice(_In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_opt_ OrtSyncStreamImpl** stream) noexcept { + // ignore the OrtEp input as we won't ever have one for internal EPs + return static_cast(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index fa4ef2515ca92..9804aa6a5c42d 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,14 +14,8 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, - GetSupportedFunc&& get_supported_func, - CreateFunc&& create_func) - : ep_name_{ep_name}, - vendor_{vendor}, - vendor_id_{vendor_id}, - get_supported_func_{std::move(get_supported_func)}, - create_func_{create_func} { +EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) + : impl_{std::move(impl)} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; @@ -31,20 +25,17 @@ EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::stri OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; + OrtEpFactory::CreateAllocator = Forward::CreateAllocator; + OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; + OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; + OrtEpFactory::IsStreamAware = Forward::IsStreamAware; + OrtEpFactory::CreateSyncStreamForDevice = 
Forward::CreateSyncStreamForDevice; } const char* EpFactoryInternal::GetVersion() const noexcept { return ORT_VERSION; } -OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept { - return get_supported_func_(this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); -} - OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t /*num_devices*/, @@ -54,25 +45,23 @@ OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); } -OrtStatus* EpFactoryInternal::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "EpFactoryInternal currently only supports one device at a time."); +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. 
+ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } } - return create_func_(this, devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); -} - -void EpFactoryInternal::ReleaseEp(OrtEp* /*ep*/) { - // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); + return ep_options; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index ee08e2233c529..ae450efa394e8 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -10,43 +10,100 @@ #include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" namespace onnxruntime { +class EpFactoryInternal; class EpLibraryInternal; struct SessionOptions; -class EpFactoryInternal : public OrtEpFactory { +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { public: - // factory is non-const as a pointer to the factory is added to OrtEpDevice and needs to be non-const - // to provide flexibility to the CreateEp call to modify internal state. 
- using GetSupportedFunc = std::function; - - using CreateFunc = std::function* ep)>; - - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, - GetSupportedFunc&& get_supported_func, - CreateFunc&& create_func); + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + *data_transfer = 
nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; + +// this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. +class EpFactoryInternal : public OrtEpFactory { + public: + EpFactoryInternal(std::unique_ptr impl); + + const char* GetName() const noexcept { return impl_->GetName(); } + const char* GetVendor() const noexcept { return impl_->GetVendor(); } + uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } + const char* GetVersion() const noexcept; + OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, _Inout_ OrtEpDevice** ep_devices, _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept; + _Out_ size_t* num_ep_devices) noexcept { + return impl_->GetSupportedDevices(*this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); + } // we don't implement this. CreateIExecutionProvider should be used. 
OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -60,19 +117,43 @@ class EpFactoryInternal : public OrtEpFactory { _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, _In_ size_t num_devices, _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ std::unique_ptr* ep); + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) { + return impl_->CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, session_options, logger, ep); + } + + OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* memory_info, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_ OrtAllocator** allocator) noexcept { + return impl_->CreateAllocator(memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(_In_ OrtAllocator* allocator) noexcept { + return impl_->ReleaseAllocator(allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + return impl_->CreateDataTransfer(data_transfer); + } + + bool IsStreamAware() const { + return impl_->IsStreamAware(); + } + + OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); + } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* ep); + void ReleaseEp(OrtEp* /*ep*/) { + // we never create an OrtEp so we should never be trying to release one + ORT_THROW("Internal error. 
No ReleaseEp call is required for EpFactoryInternal."); + } private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID - const GetSupportedFunc get_supported_func_; // function to return supported devices - const CreateFunc create_func_; // function to create the EP instance - - std::vector> eps_; // EP instances created by this factory + std::unique_ptr impl_; }; // IExecutionProviderFactory for EpFactoryInternal that is required for SessionOptionsAppendExecutionProvider_V2 diff --git a/onnxruntime/core/session/ep_library.cc b/onnxruntime/core/session/ep_library.cc deleted file mode 100644 index a5aa6e925311c..0000000000000 --- a/onnxruntime/core/session/ep_library.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library.h" - -#include "core/framework/provider_options.h" -#include "core/session/abi_session_options_impl.h" - -namespace onnxruntime { -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. 
-ProviderOptions EpLibrary::GetOptionsFromSessionOptions(const std::string& ep_name, - const SessionOptions& session_options) { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_name.c_str()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/ep_library.h index 648bec022580f..24ab74e1c77fc 100644 --- a/onnxruntime/core/session/ep_library.h +++ b/onnxruntime/core/session/ep_library.h @@ -26,9 +26,5 @@ class EpLibrary { virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); - - protected: - static ProviderOptions GetOptionsFromSessionOptions(const std::string& ep_name, - const SessionOptions& session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index ce5736f601b45..986ccb1fa17fc 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -4,6 +4,7 @@ #include "core/session/ep_library_internal.h" #include "core/framework/error_code_helper.h" +#include "core/framework/ortmemoryinfo.h" #include "core/framework/session_options.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/abi_devices.h" @@ -21,33 +22,38 @@ #endif namespace onnxruntime { -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - const auto get_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : 
EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, &ep_devices[num_ep_devices++])); } } return nullptr; - }; - - const auto create_cpu_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CPU EP factory currently only supports one device at a time."); @@ -58,62 +64,93 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; + } +}; - std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - get_supported, create_cpu_ep); - return std::make_unique(std::move(cpu_factory)); +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto 
cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #if defined(USE_DML) -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - static const std::string ep_name = kDmlExecutionProvider; - const auto is_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { std::unique_ptr ep_options; - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with - // a specific device. + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. // How would we know what options should not allow user overrides if set in OrtEpDevice? 
+ int32_t device_id = 0; // If no device_id was found default to 0 if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { ep_options = std::make_unique(); - ep_options->Add("device_id", it->second.c_str()); + device_id = std::stoi(it->second); } - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices++]); + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. 
+ // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); if (api_status != nullptr) { return api_status; } + + ++num_ep_devices; } } return nullptr; - }; - - const auto create_dml_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + *ep = nullptr; + if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "DML EP factory currently only supports one device at a time."); } - auto ep_options = GetOptionsFromSessionOptions(ep_name, session_options->value); + auto ep_options = GetOptionsFromSessionOptions(session_options->value); auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, ep_options); @@ -121,45 +158,73 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; + } + + OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept override { + // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That + // requires pulling lots of things out of the DML EP to get the D3D12 device and create a + // BucketizedBufferAllocator. 
See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp + //*allocator = device_allocators[memory_info->device.Id()].get(); + *allocator = nullptr; + return nullptr; + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; + } - auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_dml_ep); + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; - return std::make_unique(std::move(dml_factory)); +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #endif #if defined(USE_WEBGPU) -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - static const std::string ep_name = kWebGpuExecutionProvider; - - const auto is_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == 
OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { // TODO: any metadata or options to add? - ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, &ep_devices[num_ep_devices++])); } } return nullptr; - }; - - const auto create_webgpu_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + *ep = nullptr; + if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "WebGPU EP factory currently only supports one device at a time."); @@ -170,12 +235,28 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; - - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_webgpu_ep); + } + + /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
+ *data_transfer = nullptr; + return nullptr; + } + */ +}; - return std::make_unique(std::move(webgpu_factory)); +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #endif diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 70937bdc5d3e8..ae553891beaa7 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -13,6 +13,81 @@ #include "core/session/ep_factory_internal.h" namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. 
+ for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); + } + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + Status EpLibraryProviderBridge::Load() { std::lock_guard lock{mutex_}; @@ -34,47 +109,9 @@ Status 
EpLibraryProviderBridge::Load() { // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. for (const auto& factory : ep_library_plugin_->GetFactories()) { - const auto is_supported_fn = [&factory](OrtEpFactory* ep_factory_internal, // from factory_ptrs_ - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) -> OrtStatus* { - ORT_API_RETURN_IF_ERROR(factory->GetSupportedDevices(factory, devices, num_devices, ep_devices, max_ep_devices, - num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. - for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = ep_factory_internal; - } - } + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); - return nullptr; - }; - - const auto create_fn = [this, &factory](OrtEpFactory* /*ep_factory_internal from factory_ptrs_*/, - const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* logger, std::unique_ptr* ep) { - // get the provider options - auto ep_options = GetOptionsFromSessionOptions(factory->GetName(factory), session_options->value); - auto& provider = provider_library_->Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *logger, *ep); - - return ToOrtStatus(status); - }; - - auto internal_factory = std::make_unique(factory->GetName(factory), - factory->GetVendor(factory), - factory->GetVendorId(factory), - is_supported_fn, - create_fn); 
factory_ptrs_.push_back(internal_factory.get()); internal_factory_ptrs_.push_back(internal_factory.get()); factories_.push_back(std::move(internal_factory)); diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/ep_library_provider_bridge.h index 3c7f083df227e..0717ccd957de7 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/ep_library_provider_bridge.h @@ -49,7 +49,7 @@ class EpLibraryProviderBridge : public EpLibrary { std::unique_ptr provider_library_; // provider bridge EP library // EpLibraryPlugin that provides the CreateEpFactories and ReleaseEpFactory implementations. - // we wrap the OrtEpFactory instances it contains to pass through GetSupportedDevices calls, and + // we wrap the OrtEpFactory instances it contains to pass through function calls, and // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index 878a5384dfee7..1a596acdd486d 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -12,6 +12,7 @@ #include "core/framework/error_code_helper.h" #include "core/framework/model_metadef_id_generator.h" #include "core/framework/plugin_data_transfer.h" +#include "core/framework/plugin_ep_stream.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -556,7 +557,11 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { for (const auto* memory_info : allocator_mem_infos_) { OrtAllocator* ort_allocator_ptr = nullptr; - OrtStatus* ort_status = ep_factory_.CreateAllocator(&ep_factory_, memory_info, nullptr, &ort_allocator_ptr); + // prefer OrtEp function if available, otherwise fall back to 
using the OrtEpFactory implementation. + OrtStatus* ort_status = ort_ep_->CreateAllocator + ? ort_ep_->CreateAllocator(ort_ep_.get(), memory_info, &ort_allocator_ptr) + : ep_factory_.CreateAllocator(&ep_factory_, memory_info, /*options*/ nullptr, + &ort_allocator_ptr); // throw or log? start with throw if (ort_status != nullptr) { @@ -574,4 +579,39 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { return allocators; } +void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& registry, + AllocatorMap& /*allocators*/) const { + if (!ep_factory_.IsStreamAware(&ep_factory_)) { + return; + } + + for (const auto* mem_info : allocator_mem_infos_) { + if (mem_info->device.UsesCpuMemory()) { + // CPU memory does not need a stream + continue; + } + + auto device_type = mem_info->device.Type(); + + registry.RegisterCreateStreamFn( + device_type, + [mem_info, this](const OrtDevice& device) { + OrtSyncStreamImpl* stream = nullptr; + const OrtMemoryDevice* memory_device = static_cast(&mem_info->device); + + // prefer OrtEp function if available, otherwise fall back to using the OrtEpFactory implementation. + OrtStatus* status = ort_ep_->CreateSyncStreamForDevice + ? 
ort_ep_->CreateSyncStreamForDevice(ort_ep_.get(), memory_device, &stream) + : ep_factory_.CreateSyncStreamForDevice(&ep_factory_, memory_device, + /*stream_options*/ nullptr, &stream); + + ORT_ENFORCE(status == nullptr && stream != nullptr, + "Error creating sync stream for device: ", ToStatusAndRelease(status).ToString()); + return std::make_unique(device, *stream, *GetLogger()); + }); + + registry.RegisterWaitFn(device_type, device_type, plugin_ep::Notification::WaitNotificationOnDevice); + registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost); + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h index 3ba3118fcaa36..728f959ad67cb 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h @@ -94,6 +94,8 @@ class PluginExecutionProvider : public IExecutionProvider { std::unique_ptr GetDataTransfer() const override; + void RegisterStreamHandlers(IStreamCommandHandleRegistry&, AllocatorMap&) const override; + // create per-session allocators // longer term we should prefer shared allocators in Environment and only create per-session allocators as // needed based on matching against allocator_mem_infos_. 
diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 15dc6f377766a..25cabd256e318 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -3326,6 +3326,92 @@ std::pair InferenceSession::GetModelOutput return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOutputs()); } +common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType type, + InlinedVector& memory_info) const { + memory_info.clear(); + + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair result; + switch (type) { + case SessionInputOutputType::kInput: + result = GetModelInputs(); + break; + case SessionInputOutputType::kOutput: + result = GetModelOutputs(); + break; + case SessionInputOutputType::kOverridableInitializer: + // add if/when needed + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "GetInputOutputMemoryInfo for kOverridableInitializer is not implemented."); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unexpected SessionInputOutputType of ", static_cast(type)); + } + + ORT_RETURN_IF_ERROR(result.first); + + const auto& def_list = *result.second; + memory_info.reserve(def_list.size()); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + if (type == SessionInputOutputType::kOutput) { + ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + } else { + ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + } + + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. 
+ auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } + + return Status::OK(); +} + +common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector& ep_devices) const { + ep_devices.clear(); + +#if defined(ORT_MINIMAL_BUILD) + return common::Status(common::ONNXRUNTIME, common::FAIL, + "GetEpDeviceForInputs is not available in a minimal build."); +#else + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair inputs = GetModelInputs(); + + ORT_RETURN_IF_ERROR(inputs.first); + + const auto& def_list = *inputs.second; + ep_devices.reserve(def_list.size()); + + const auto& available_eps = environment_.GetOrtEpDevices(); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + + // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map + // instead of doing a linear search each time. + const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + + ep_devices.push_back(it != available_eps.end() ? 
*it : nullptr);
+  }
+
+  return Status::OK();
+#endif
+}
+
 common::Status InferenceSession::NewIOBinding(std::unique_ptr<IOBinding>* io_binding) {
   {
     std::lock_guard l(session_mutex_);
diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h
index 4e25187ff1b47..8bea15c169ed4 100644
--- a/onnxruntime/core/session/inference_session.h
+++ b/onnxruntime/core/session/inference_session.h
@@ -465,6 +465,25 @@ class InferenceSession {
    */
  std::pair<common::Status, const OutputDefList*> GetModelOutputs() const;
 
+  enum class SessionInputOutputType : uint8_t {
+    kInput = 0,
+    kOutput = 1,
+    kOverridableInitializer = 2
+  };
+
+  /**
+   * Get the OrtMemoryInfo for the inputs or outputs of the model.
+   *
+   * This is required for a user to know the location of the input/output when autoep selection is enabled.
+   */
+  common::Status GetInputOutputMemoryInfo(SessionInputOutputType type,
+                                          InlinedVector<const OrtMemoryInfo*>& memory_info) const;
+  /**
+   * Get the OrtEpDevice (if available) for the inputs of the model.
+   *
+   * This is required for a user to know the location of the input/output when autoep selection is enabled.
+   */
+  common::Status GetEpDeviceForInputs(InlinedVector<const OrtEpDevice*>& ep_devices) const;
 
   /**
    * Get the current number of in-progress concurrent Run calls.
*/ diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index db2a62c77d1bc..64b93ee0f592a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -22,6 +22,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/ort_value.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/tensor.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" @@ -32,6 +33,7 @@ #include "core/graph/model_editor_api_types.h" #include "core/graph/ep_api_types.h" #include "core/providers/get_execution_providers.h" +#include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" @@ -298,7 +300,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* alloc API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, +ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, + _In_ const int64_t* dense_shape, size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) @@ -1312,9 +1315,17 @@ ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElement, _In_ const OrtValue* value, using DefListResult = std::pair; using GetDefListFn = DefListResult (*)(const ::onnxruntime::InferenceSession*); -const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelInputs(); }; -const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelOutputs(); }; -const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* 
session) -> DefListResult { return session->GetOverridableInitializers(); }; +const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetModelInputs(); +}; + +const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetModelOutputs(); +}; + +const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetOverridableInitializers(); +}; static ORT_STATUS_PTR GetNodeDefListCountHelper(const OrtSession* sess, GetDefListFn get_fn, size_t* out) { API_IMPL_BEGIN @@ -3187,6 +3198,148 @@ ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { return OrtExecutionProviderApi::GetEpApi(); } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_values) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_values) { + API_IMPL_BEGIN + if (ort_session == nullptr || inputs_ep_devices == nullptr || num_values == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to SessionGetEpDeviceForInputs."); + } + + auto session = reinterpret_cast(ort_session); + + InlinedVector ep_devices; + + ORT_API_RETURN_IF_STATUS_NOT_OK(session->GetEpDeviceForInputs(ep_devices)); + + auto num_found = ep_devices.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of inputs ", num_found, " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + inputs_ep_devices[i] = (i < num_found) ? 
ep_devices[i] : nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** ort_stream) { + API_IMPL_BEGIN + if (ep_device == nullptr || ort_stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and stream must be provided."); + } + + const OrtDevice* device = ep_device->device_memory_info ? &ep_device->device_memory_info->device : nullptr; + + if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device."); + } + + const auto* factory = ep_device->ep_factory; + if (!factory->IsStreamAware(factory)) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support streams."); + } + + // get the stream implementation from the EP factory + OrtSyncStreamImpl* stream_impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateSyncStreamForDevice(ep_device->GetMutableFactory(), + static_cast(device), // alias + stream_options, &stream_impl)); + + if (stream_impl == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Failed to get a stream implementation from the EP factory."); + } + + // create the wrapper class that uses the EP implementation + auto stream = std::make_unique(ep_device->device_memory_info->device, + *stream_impl, LoggingManager::DefaultLogger()); + + // cast to base type, and to API alias type + *ort_stream = static_cast(static_cast(stream.release())); + + return nullptr; + + API_IMPL_END +} + +ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* stream) { + return stream->GetHandle(); +} + +ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* ort_stream) { + // convert from API alias to internal type + auto* stream = static_cast(ort_stream); + + // the only way for the user 
to get a non-const OrtSyncStream is from CreateSyncStreamForEpDevice, + // so we can safely cast to the plugin_ep::Stream type. + std::unique_ptr ep_stream(reinterpret_cast(stream)); +} + +ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors) { + API_IMPL_BEGIN + if (env == nullptr || src_tensors == nullptr || dst_tensors == nullptr || num_tensors == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to CopyTensors."); + } + + const OrtMemoryInfo* src_memory_info = nullptr; + const OrtMemoryInfo* dst_memory_info = nullptr; + + const auto validate_and_get_mem_info = + [](const OrtValue* const* values, size_t num_values, const OrtMemoryInfo*& mem_info) -> OrtStatus* { + for (size_t i = 0; i < num_values; ++i) { + const OrtValue* value = values[i]; + if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue must contain Tensor with data."); + } + + if (i == 0) { + mem_info = &value->Get().Location(); + } else if (*mem_info != value->Get().Location()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All OrtValue instances must have the same OrtMemoryInfo"); + } + } + + return nullptr; + }; + + ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, num_tensors, src_memory_info)); + ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(const_cast(dst_tensors), num_tensors, + dst_memory_info)); + + auto& data_transfer_mgr = env->GetEnvironment().GetDataTransferManager(); + const auto* data_transfer = data_transfer_mgr.GetDataTransfer(src_memory_info->device, dst_memory_info->device); + + if (data_transfer == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "Data transfer implementation between source and destination device was not found."); + } + + 
std::vector pairs; + pairs.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + pairs.push_back({ + src_tensors[i]->Get(), + *dst_tensors[i]->GetMutable(), + stream, + }); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(data_transfer->CopyTensors(pairs)); + + return nullptr; + + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { @@ -3218,7 +3371,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS size_t /*num_ep_options*/) { API_IMPL_BEGIN return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); - API_IMPL_END } @@ -3227,6 +3379,41 @@ ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* /*ort_session*/, + _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, + _In_ size_t /*num_values*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* /*ep_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_ OrtSyncStream** /*ort_stream*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + API_IMPL_END +} + +ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* /*stream*/) { + fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n"); + return nullptr; +} + +ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* /*ort_stream*/) { + fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n"); +} + +ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, + _In_reads_(num_tensors) const OrtValue* const* 
/*src_tensors*/,
+                    _In_reads_(num_tensors) OrtValue* const* /*dst_tensors*/,
+                    _In_opt_ OrtSyncStream* /*stream*/,
+                    _In_ size_t /*num_tensors*/) {
+  API_IMPL_BEGIN
+  return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build.");
+  API_IMPL_END
+}
+
 #endif  // !defined(ORT_MINIMAL_BUILD)
 
 // OrtEpDevice accessors
@@ -3270,8 +3457,61 @@ ORT_API(const OrtHardwareDevice*, OrtApis::EpDevice_Device, _In_ const OrtEpDevi
   return ep_device->device;
 }
 
-ORT_API(const OrtMemoryInfo*, OrtApis::EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device) {
-  return ep_device->device_memory_info;
+ORT_API(const OrtMemoryInfo*, OrtApis::EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device,
+        _In_ OrtDeviceMemoryType memory_type) {
+  switch (memory_type) {
+    case OrtDeviceMemoryType_DEFAULT:
+      return ep_device->device_memory_info;
+    case OrtDeviceMemoryType_HOST_ACCESSIBLE:
+      return ep_device->host_accessible_memory_info;
+    default:
+      return nullptr;
+  }
+}
+
+namespace {
+OrtStatus* GetInputOutputMemoryInfo(const OrtSession* ort_session,
+                                    InferenceSession::SessionInputOutputType type,
+                                    const OrtMemoryInfo** memory_info,
+                                    _In_ size_t num_values) {
+  auto session = reinterpret_cast<const ::onnxruntime::InferenceSession*>(ort_session);
+
+  InlinedVector<const OrtMemoryInfo*> mem_info;
+  ORT_API_RETURN_IF_STATUS_NOT_OK(
+      session->GetInputOutputMemoryInfo(type, mem_info));
+
+  auto num_found = mem_info.size();
+  if (num_found > num_values) {
+    auto msg = MakeString("Number of ",
+                          type == InferenceSession::SessionInputOutputType::kOutput ? "outputs " : "inputs ",
+                          mem_info.size(), " exceeds the provided size of ", num_values);
+    return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str());
+  }
+
+  for (size_t i = 0; i < num_values; ++i) {
+    memory_info[i] = (i < num_found) ?
mem_info[i] : nullptr; + } + + return nullptr; +} +} // namespace + +ORT_API_STATUS_IMPL(OrtApis::SessionGetMemoryInfoForInputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs) { + API_IMPL_BEGIN + return GetInputOutputMemoryInfo(ort_session, InferenceSession::SessionInputOutputType::kInput, + inputs_memory_info, num_inputs); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs) { + API_IMPL_BEGIN + return GetInputOutputMemoryInfo(session, InferenceSession::SessionInputOutputType::kOutput, + outputs_memory_info, num_outputs); + API_IMPL_END } static constexpr OrtApiBase ort_api_base = { @@ -3708,6 +3948,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::AllocatorGetStats, &OrtApis::CreateMemoryInfo_V2, + &OrtApis::MemoryInfoGetDeviceMemType, + &OrtApis::MemoryInfoGetVendorId, &OrtApis::ValueInfo_GetValueProducer, &OrtApis::ValueInfo_GetValueNumConsumers, @@ -3764,6 +4006,16 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetTensorData, &OrtApis::GetSessionOptionsConfigEntries, + + &OrtApis::SessionGetMemoryInfoForInputs, + &OrtApis::SessionGetMemoryInfoForOutputs, + &OrtApis::SessionGetEpDeviceForInputs, + + &OrtApis::CreateSyncStreamForEpDevice, + &OrtApis::SyncStream_GetHandle, + &OrtApis::ReleaseSyncStream, + + &OrtApis::CopyTensors, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ab927006c320..5f68c56cb044e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -604,10 +604,13 @@ ORT_API_STATUS_IMPL(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ ORT_API_STATUS_IMPL(AllocatorGetStats, _In_ const OrtAllocator* ptr, _Outptr_ OrtKeyValuePairs** out); ORT_API_STATUS_IMPL(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); +ORT_API_STATUS_IMPL(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out); +ORT_API_STATUS_IMPL(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out); + // OrtValueInfo ORT_API_STATUS_IMPL(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index); @@ -685,7 +688,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); -ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); +ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType memory_type); ORT_API_STATUS_IMPL(CreateSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, _In_ OrtDeviceMemoryType mem_type, _In_ OrtAllocatorType allocator_type, @@ -700,4 +704,30 @@ ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDe ORT_API_STATUS_IMPL(GetTensorData, _In_ const 
OrtValue* value, _Outptr_ const void** out); ORT_API_STATUS_IMPL(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); + +ORT_API_STATUS_IMPL(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs); + +ORT_API_STATUS_IMPL(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs); + +ORT_API_STATUS_IMPL(SessionGetEpDeviceForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_inputs); + +ORT_API_STATUS_IMPL(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** stream); + +ORT_API(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream); + +ORT_API(void, ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* stream); + +ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); } // namespace OrtApis diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 44b3f9a213abf..e65fd013d14e7 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -13,6 +13,7 @@ #include #include "ep_factory.h" +#include "ep_stream_support.h" /// /// Example implementation of ONNX Mul. Does not handle many things like broadcasting. @@ -163,6 +164,8 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C GetCapability = GetCapabilityImpl; Compile = CompileImpl; ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + CreateAllocator = CreateAllocatorImpl; // optional. 
can be nullptr + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -435,6 +438,43 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept { + // A per-session allocator could be created here. + // Logging of any issues should use ep->logger_ which is the session logger. + + ExampleEp* ep = static_cast(this_ptr); + + // for simplicity in this example we use the factory implementation. + return ep->factory_.CreateAllocator(&ep->factory_, memory_info, nullptr, allocator); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept { + // A per-session OrtSyncStreamImpl can be created here if the session options affect the implementation. + // Logging of any issues should use logger_ which is the session logger. + + ExampleEp* ep = static_cast(this_ptr); + + // we only create streams for the default device memory. + if (auto mem_type = ep->factory_.ep_api.MemoryDevice_GetMemoryType(memory_device); + mem_type != OrtDeviceMemoryType_DEFAULT) { + std::string error = "Invalid OrtMemoryDevice. Expected OrtDeviceMemoryType_DEFAULT(0). 
Got "; + error += std::to_string(mem_type); + return ep->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, error.c_str()); + } + + auto sync_stream = std::make_unique(ep->factory_, ep, nullptr); + *stream = sync_stream.release(); + + return nullptr; +} + // // Implementation of ExampleNodeComputeInfo // diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index dfebcc52a0caf..fa6eb24c5cc04 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -30,12 +30,23 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; - static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, size_t num_node_compute_infos) noexcept; diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/ep_allocator.h index 5e6f81fd0aa9e..624b4fcb484cd 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ 
b/onnxruntime/test/autoep/library/ep_allocator.h @@ -5,16 +5,35 @@ #include "example_plugin_ep_utils.h" +// from onnxruntime/core/framework/allocator_stats.h +struct AllocatorStats { + int64_t num_allocs; // Number of allocations. + int64_t num_reserves; // Number of reserves. (Number of calls to Reserve() in arena-based allocators) + int64_t bytes_in_use; // Number of bytes in use. + int64_t total_allocated_bytes; // The total number of allocated bytes by the allocator. + int64_t max_bytes_in_use; // The maximum bytes in use. + int64_t max_alloc_size; // The max single allocation seen. + int64_t bytes_limit; // The upper limit what the allocator can allocate, if such a limit + // is known. Certain allocator may return 0 to indicate the limit is + // unknown. +}; + struct CustomAllocator : OrtAllocator { - CustomAllocator(const OrtMemoryInfo* mem_info) : memory_info{mem_info} { + CustomAllocator(const OrtMemoryInfo* mem_info, const ApiPtrs& api_ptrs_in) + : memory_info{mem_info}, api_ptrs{api_ptrs_in} { + version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = GetStatsImpl; // this can be set to nullptr if you don't want to implement it } - static void* ORT_API_CALL AllocImpl(struct OrtAllocator* /*this_*/, size_t size) { - // CustomAllocator& impl = *static_cast(this_); + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + CustomAllocator& impl = *static_cast(this_); + ++impl.stats.num_allocs; + impl.stats.max_alloc_size = std::max(size, impl.stats.max_alloc_size); + return malloc(size); } @@ -29,6 +48,24 @@ struct CustomAllocator : OrtAllocator { return impl.memory_info; } + static OrtStatus* ORT_API_CALL GetStatsImpl(const struct OrtAllocator* this_, OrtKeyValuePairs** out) 
noexcept { + const CustomAllocator& impl = *static_cast(this_); + + OrtKeyValuePairs* kvps; + impl.api_ptrs.ort_api.CreateKeyValuePairs(&kvps); + + // if you wish to return stats the values in GetStatus should be formatted like this: + // https://github.com/microsoft/onnxruntime/blob/2f878c60296de169a8a523e692d3d65893f7c133/onnxruntime/core/session/allocator_adapters.cc#L75-L85 + + impl.api_ptrs.ort_api.AddKeyValuePair(kvps, "NumAllocs", std::to_string(impl.stats.num_allocs).c_str()); + impl.api_ptrs.ort_api.AddKeyValuePair(kvps, "MaxAllocSize", std::to_string(impl.stats.max_alloc_size).c_str()); + + *out = kvps; + return nullptr; + } + private: const OrtMemoryInfo* memory_info; + const ApiPtrs api_ptrs; + AllocatorStats stats{}; }; diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.cc b/onnxruntime/test/autoep/library/ep_data_transfer.cc index 48f97fe88ec44..ca8d0cf089fc6 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.cc +++ b/onnxruntime/test/autoep/library/ep_data_transfer.cc @@ -7,12 +7,10 @@ #include /*static*/ -bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr, +bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept { - static constexpr uint32_t VendorId = 0xBE57; // Example vendor ID for demonstration purposes. - - auto& impl = *static_cast(this_ptr); + const auto& impl = *static_cast(this_ptr); bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); @@ -24,30 +22,37 @@ bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr, // and the vendor and device IDs as needed. 
OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); - // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device); - // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device); - // uint32_t src_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); - // uint32_t dst_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); - // uint32_t src_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); - // uint32_t dst_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device); + // we can copy to/from CPU or CPU accessible memory if (src_is_our_device) { - // check device type and vendor to see if compatible - return (dst_device_type == OrtMemoryInfoDeviceType_CPU); + return (dst_device_type == OrtMemoryInfoDeviceType_CPU || dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); } if (dst_is_our_device) { - // check device type and vendor to see if compatible - return (src_device_type == OrtMemoryInfoDeviceType_CPU); + return (src_device_type == OrtMemoryInfoDeviceType_CPU || src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); } return false; } +namespace { +void CopyImpl(const void* src_data, void* dst_data, size_t bytes, OrtSyncStream* stream) { + // in our example setup this is really CPU to CPU + + if (stream) { + // EP can do an async copy using the stream. e.g. an NVIDIA EP would provide the stream to cudaMemcpyAsync + } + + memcpy(dst_data, src_data, bytes); +} +} // namespace + // function to copy one or more tensors. 
// implementation can optionally use async copy if a stream is available for the input. /*static*/ -OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, +OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, @@ -56,11 +61,10 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); - auto streams = gsl::make_span(streams_ptr, num_tensors); for (size_t i = 0; i < num_tensors; ++i) { - // NOTE: Stream support will be a separate PR. ignore teh streams_ptr values for now - + // the implementation for a 'real' EP would be something along these lines. + // See CudaDataTransferImpl in onnxruntime\core\providers\cuda\cuda_provider_factory.cc const OrtMemoryDevice* src_device = nullptr; const OrtMemoryDevice* dst_device = nullptr; RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device)); @@ -71,13 +75,16 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); - // bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || - // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + // bool copy_involves_host_accessible_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; const void* src_data = nullptr; void* dst_data = nullptr; + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data)); RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); + 
RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensors[i], &bytes)); if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { if (src_device_type == OrtMemoryInfoDeviceType_GPU) { @@ -88,15 +95,18 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { // GPU -> CPU } else { - // CPU -> CPU involves copy to/from pinned memory and a synchronize may be required first + // CPU -> CPU. may involve copy a to/from host accessible memory and a synchronize may be required first } + + // but in our example EP it's simpler as it's really a (fake) CPU to CPU copy + CopyImpl(src_data, dst_data, bytes, streams_ptr ? streams_ptr[i] : nullptr); } return nullptr; } /*static*/ -void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(void* /*this_ptr*/) noexcept { +void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) // diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.h b/onnxruntime/test/autoep/library/ep_data_transfer.h index d73b9e457b844..da74d42b4affe 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.h +++ b/onnxruntime/test/autoep/library/ep_data_transfer.h @@ -7,28 +7,26 @@ struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs { ExampleDataTransfer(ApiPtrs api_ptrs, - const OrtMemoryDevice* device_mem_info_, - const OrtMemoryDevice* shared_mem_info_ = nullptr) - : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} { + const OrtMemoryDevice* device_mem_info_) + : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_} { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; Release = ReleaseImpl; } - static bool ORT_API_CALL CanCopyImpl(void* this_ptr, + static 
bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept; // function to copy one or more tensors. // implementation can optionally use async copy if a stream is available for the input. - static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, + static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, size_t num_tensors) noexcept; - static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; private: - const OrtMemoryDevice* device_mem_info; - const OrtMemoryDevice* shared_mem_info; + const OrtMemoryDevice* device_mem_info; // device our EP runs on }; diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 19a44008b8c97..97cdea15187d9 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -8,6 +8,7 @@ #include "ep.h" #include "ep_allocator.h" #include "ep_data_transfer.h" +#include "ep_stream_support.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { @@ -27,35 +28,27 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) CreateDataTransfer = CreateDataTransferImpl; - // for the sake of this example we specify a CPU allocator with no arena and 1K alignment (arbitrary) - // as well as GPU and GPU shared memory. the actual EP implementation would typically define two at most for a - // device (one for device memory and one for shared memory for data transfer between device and CPU) + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // setup the OrtMemoryInfo instances required by the EP. 
+ // We pretend the device the EP is running on is GPU. OrtMemoryInfo* mem_info = nullptr; - auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP CPU", OrtMemoryInfoDeviceType_CPU, + auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0xBE57, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 1024, - OrtAllocatorType::OrtDeviceAllocator, // no arena + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. + default_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - cpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + // create data transfer for the device + const OrtMemoryDevice* device = ep_api.MemoryInfo_GetMemoryDevice(default_memory_info_.get()); + data_transfer_impl_ = std::make_unique(apis, device); - // - // GPU allocator OrtMemoryInfo for example purposes - mem_info = nullptr; - status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, - /*vendor*/ 0xBE57, /* device_id */ 0, - OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, - OrtAllocatorType::OrtDeviceAllocator, - &mem_info); - assert(status == nullptr); // should never fail. - default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // HOST_ACCESSIBLE memory should use the non-CPU device type + // HOST_ACCESSIBLE memory example. use the non-CPU device type so it's clear which device the memory is also + // accessible from. we infer from the type of HOST_ACCESSIBLE that it's CPU accessible. mem_info = nullptr; status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU pinned", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0xBE57, /* device_id */ 0, @@ -63,17 +56,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - assert(status == nullptr); // should never fail. 
- host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // if we were to use GPU we'd create it like this - data_transfer_impl_ = std::make_unique( - apis, - ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()), // device memory - ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get()) // shared memory - ); - - data_transfer_impl_.reset(); // but we're CPU only so we return nullptr for the IDataTransfer. + ort_api.ReleaseMemoryInfo(mem_info); } /*static*/ @@ -137,9 +120,8 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* } // register the allocator info required by the EP. - // in this example we register CPU info which is unnecessary unless you need to override the default ORT allocator - // for a non-CPU EP this would be device info (GPU/NPU) and possible host accessible info. - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->cpu_memory_info_.get())); + // registering OrtMemoryInfo for host accessible memory would be done in an additional call. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); ep_devices[num_ep_devices++] = ep_device; } @@ -218,32 +200,17 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this auto& factory = *static_cast(this_ptr); *allocator = nullptr; - // NOTE: The factory implementation can return a shared OrtAllocator* instead of creating a new instance on each call. - // To do this just make ReleaseAllocatorImpl a no-op. - - // NOTE: If OrtMemoryInfo has allocator type (call MemoryInfoGetType) of OrtArenaAllocator, an ORT BFCArena - // will be added to wrap the returned OrtAllocator. The EP is free to implement its own arena, and if it - // wants to do this the OrtMemoryInfo MUST be created with an allocator type of OrtDeviceAllocator. 
- - // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based - // matching should work. - if (memory_info == factory.cpu_memory_info_.get()) { - // create a CPU allocator. use the basic OrtAllocator for this example. - auto cpu_allocator = std::make_unique(memory_info); - *allocator = cpu_allocator.release(); - } else if (memory_info == factory.default_gpu_memory_info_.get()) { - // create a GPU allocator - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); - } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) { - // create a pinned/shared memory allocator. Use the real device type (i.e. GPU/NPU) and id and a memory type of - // OrtDeviceMemoryType_HOST_ACCESSIBLE. - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); - } else { + if (memory_info != factory.default_memory_info_.get()) { return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. " "Value did not come directly from an OrtEpDevice returned by this factory."); } + // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new + // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make + // ReleaseAllocatorImpl a no-op. + auto cpu_allocator = std::make_unique(memory_info, factory); + *allocator = cpu_allocator.release(); return nullptr; } @@ -260,3 +227,25 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateDataTransferImpl(OrtEpFactory* t return nullptr; } + +/*static*/ +bool ORT_API_CALL ExampleEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; // the example EP implements stream synchronization. 
+} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept { + auto& factory = *static_cast(this_ptr); + *stream = nullptr; + + // we only need stream synchronization on the device stream + if (factory.ep_api.MemoryDevice_GetMemoryType(memory_device) == OrtDeviceMemoryType_DEFAULT) { + auto sync_stream = std::make_unique(factory, /*OrtEp**/ nullptr, stream_options); + *stream = sync_stream.release(); + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 72fa1c1301841..261730b8adf83 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -52,6 +52,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -59,12 +66,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. using MemoryInfoUniquePtr = std::unique_ptr>; - MemoryInfoUniquePtr cpu_memory_info_; - - // for example purposes. if the EP used GPU, and pinned/shared memory was required for data transfer, these are the - // OrtMemoryInfo instance required for that. 
- MemoryInfoUniquePtr default_gpu_memory_info_; - MemoryInfoUniquePtr host_accessible_gpu_memory_info_; + MemoryInfoUniquePtr default_memory_info_; std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc new file mode 100644 index 0000000000000..a948fe1bfce1e --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "ep_stream_support.h" + +// +// StreamImpl implementation +// + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification) noexcept { + auto& impl = *static_cast(this_ptr); + *notification = std::make_unique(impl).release(); + return nullptr; +} + +/*static*/ +void* ORT_API_CALL StreamImpl::GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.handle_; +} + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { + return nullptr; +} + +// callback for EP library to release any internal state +/*static*/ +void ORT_API_CALL StreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// Notification support +// + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + static_cast(impl); + + // e.g. 
+ // CUDA: cudaEventRecord + // CANN: aclrtRecordEvent + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept { + auto& impl = *static_cast(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); + static_cast(handle); + + // Setup the event or similar that will be activated on notification. + // See CudaNotification or CannNotification for examples. + // + // e.g. + // CUDA: cudaStreamWaitEvent(static_cast(device_stream.GetHandle()), event_) + // CANN: aclrtStreamWaitEvent(static_cast(device_stream.GetHandle()), event_) + // + // `event_` should be a member that is created in the ctor. + // The stream handle should come from the StreamImpl instance and can be the real type so no static_cast is needed. + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + static_cast(impl); + + // e.g. + // CUDA: cudaEventSynchronize(event_) + // CANN: aclrtSynchronizeEvent(event_) + return nullptr; +} + +/*static*/ +void ORT_API_CALL NotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/ep_stream_support.h new file mode 100644 index 0000000000000..10c4804722f8b --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_stream_support.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "example_plugin_ep_utils.h" + +// +// Class implementing Stream support for synchronization. 
+// +class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { + public: + StreamImpl(ApiPtrs apis, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(apis), ep_{ep} { + ort_version_supported = ORT_API_VERSION; + CreateNotification = CreateNotificationImpl; + GetHandle = GetHandleImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + private: + static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** sync_notification) noexcept; + static void* ORT_API_CALL GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + + void* handle_{nullptr}; // use the real stream type, like cudaStream_t or aclrtStream, etc. + + // EP instance if the stream is being created internally for inferencing. + // nullptr when the stream is created outside of an inference session for data copies. + const OrtEp* ep_; +}; + +// +// Class implementing synchronization notification support. 
+// +class NotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs { + public: + NotificationImpl(ApiPtrs apis) : ApiPtrs(apis) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + Release = ReleaseImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + } + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + + void* event_{NULL}; // placeholder. e.g. CANN uses aclrtEvent, CUDA uses cudaEvent_t +}; diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc new file mode 100644 index 0000000000000..84b6e284ccb8e --- /dev/null +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include +#include + +#include "core/framework/allocator.h" +#include "core/session/abi_key_value_pairs.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +namespace { +struct DummyAllocator : OrtAllocator { + DummyAllocator(const OrtMemoryInfo* mem_info) + : memory_info{mem_info} { + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = nullptr; // this can be set to nullptr if not implemented + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + ++impl.stats.num_allocs; + impl.stats.max_alloc_size = std::max(size, impl.stats.max_alloc_size); + + return malloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* /*this_*/, void* p) { + return free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const DummyAllocator& impl = *static_cast(this_); + return impl.memory_info; + } + + private: + const OrtMemoryInfo* memory_info; + AllocatorStats stats{}; +}; +} // namespace + +// validate CreateSharedAllocator allows adding an arena to the shared allocator +TEST(SharedAllocators, AddArenaToSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + + const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(example_ep.get(), OrtDeviceMemoryType_DEFAULT); + + // validate there is a shared allocator + OrtAllocator* allocator = 
nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + // call CreateSharedAllocator to replace with arena based allocator. arena is configured with kvps + OrtKeyValuePairs allocator_options; + auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts + allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); + + ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), + OrtDeviceMemoryType_DEFAULT, OrtArenaAllocator, &allocator_options, + &allocator)); + + // first allocation should init the arena to the initial chunk size + void* mem = allocator->Alloc(allocator, 16); + allocator->Free(allocator, mem); + + // stats should prove the arena was used + OrtKeyValuePairs* allocator_stats = nullptr; + ASSERT_ORTSTATUS_OK(allocator->GetStats(allocator, &allocator_stats)); + + using ::testing::Contains; + using ::testing::Pair; + const auto& stats = allocator_stats->Entries(); + EXPECT_THAT(stats, Contains(Pair("NumAllocs", "1"))); + EXPECT_THAT(stats, Contains(Pair("NumArenaExtensions", "1"))); + EXPECT_THAT(stats, Contains(Pair("TotalAllocated", initial_chunk_size))); + + // optional. ORT owns the allocator but we want to test the release implementation + ASSERT_ORTSTATUS_OK(c_api.ReleaseSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT)); +} + +TEST(SharedAllocators, GetSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + + // default CPU allocator should be available. 
+ // create a memory info with a different name to validate the shared allocator lookup ignores the name + OrtMemoryInfo* test_cpu_memory_info = nullptr; + ASSERT_ORTSTATUS_OK(c_api.CreateMemoryInfo_V2("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, + OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator, + &test_cpu_memory_info)); + + const auto get_allocator_and_check_name = [&](const std::string& expected_name) { + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, test_cpu_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + const OrtMemoryInfo* ort_cpu_memory_info = nullptr; + ASSERT_ORTSTATUS_OK(c_api.AllocatorGetInfo(allocator, &ort_cpu_memory_info)); + const char* allocator_name; + ASSERT_ORTSTATUS_OK(c_api.MemoryInfoGetName(ort_cpu_memory_info, &allocator_name)); + ASSERT_EQ(expected_name, allocator_name); // Default ORT CPU allocator + }; + + // check we get the default ORT CPU allocator initially + get_allocator_and_check_name(onnxruntime::CPU); + + // register custom allocator and make sure that is accessible by exact match + DummyAllocator dummy_alloc{test_cpu_memory_info}; + c_api.RegisterAllocator(*ort_env, &dummy_alloc); + + // GetSharedAllocator should now match the custom allocator + get_allocator_and_check_name("dummy"); + + // unregister custom allocator + ASSERT_ORTSTATUS_OK(c_api.UnregisterAllocator(*ort_env, test_cpu_memory_info)); + + // there should always be a CPU allocator available + get_allocator_and_check_name(onnxruntime::CPU); + + c_api.ReleaseMemoryInfo(test_cpu_memory_info); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_utils.cc b/onnxruntime/test/autoep/test_autoep_utils.cc new file mode 100644 index 0000000000000..7045ccca2f576 --- /dev/null +++ b/onnxruntime/test/autoep/test_autoep_utils.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include "test/autoep/test_autoep_utils.h" + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" + +namespace onnxruntime { +namespace test { + +Utils::ExamplePluginInfo Utils::example_ep_info; + +void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { + const OrtApi& c_api = Ort::GetApi(); + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices; + ASSERT_ORTSTATUS_OK(c_api.GetEpDevices(env, &ep_devices, &num_devices)); + + auto it = std::find_if(ep_devices, ep_devices + num_devices, + [&c_api, &ep_name](const OrtEpDevice* ep_device) { + // example ep uses registration name as ep name + return c_api.EpDevice_EpName(ep_device) == ep_name; + }); + + if (it == ep_devices + num_devices) { + ep_device = nullptr; + } else { + ep_device = *it; + } +} + +void Utils::RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { + const OrtApi& c_api = Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, + example_ep_info.registration_name.c_str(), + example_ep_info.library_path.c_str())); + const OrtEpDevice* example_ep = nullptr; + GetEp(env, example_ep_info.registration_name, example_ep); + ASSERT_NE(example_ep, nullptr); + + registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, c_api](const OrtEpDevice* /*ep*/) { + c_api.UnregisterExecutionProviderLibrary(env, example_ep_info.registration_name.c_str()); + }); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_utils.h b/onnxruntime/test/autoep/test_autoep_utils.h new file mode 100644 index 0000000000000..2dd7b5f0428e2 --- /dev/null +++ b/onnxruntime/test/autoep/test_autoep_utils.h @@ -0,0 +1,37 @@ +// 
Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace test { + +using RegisteredEpDeviceUniquePtr = std::unique_ptr>; + +struct Utils { + struct ExamplePluginInfo { + const std::filesystem::path library_path = +#if _WIN32 + "example_plugin_ep.dll"; +#else + "libexample_plugin_ep.so"; +#endif + const std::string registration_name = "example_ep"; + }; + + static ExamplePluginInfo example_ep_info; + + // get the OrtEpDevice for an arbitrary EP from the environment + static void GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device); + + // Register the example EP library, get the OrtEpDevice for it, and return a unique pointer that will + // automatically unregister the EP library. + static void RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& example_ep); +}; +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/test_data_transfer.cc b/onnxruntime/test/autoep/test_data_transfer.cc new file mode 100644 index 0000000000000..cc09699b754b6 --- /dev/null +++ b/onnxruntime/test/autoep/test_data_transfer.cc @@ -0,0 +1,81 @@ + + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/common/random_generator.h" +#include "test/util/include/api_asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +TEST(OrtEpLibrary, DataTransfer) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* ep_device = example_ep.get(); + + const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + + // create a tensor using the default CPU allocator + Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{2, 3, 4}; // shape doesn't matter + const size_t num_elements = 2 * 3 * 4; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(shape, 0.0f, 2.f); + Ort::Value cpu_tensor = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + + // create an on-device Tensor using the example EPs alternative CPU allocator. + // it has a different vendor to the default ORT CPU allocator so we can copy between them even though both are + // really CPU based. 
+ OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs{cpu_tensor}; + std::vector dst_tensor_ptrs{device_tensor}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, + src_tensor_ptrs.size())); + + const float* src_data = nullptr; + const float* dst_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(device_tensor, reinterpret_cast(&dst_data))); + + size_t bytes; + ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + ASSERT_EQ(bytes, num_elements * sizeof(float)); + + ASSERT_NE(src_data, dst_data) << "Should have copied between two different memory locations"; + + auto src_span = gsl::make_span(src_data, num_elements); + auto dst_span = gsl::make_span(dst_data, num_elements); + + EXPECT_THAT(src_span, ::testing::ContainerEq(dst_span)); + + // must release this before we unload the EP and the allocator is deleted + device_tensor = Ort::Value(); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc new file mode 100644 index 0000000000000..f1ef67e1f6ba4 --- /dev/null +++ b/onnxruntime/test/autoep/test_execution.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +// #include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +namespace { +void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Create input + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data(6, 2.0f); + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); +} +} // namespace + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. 
+TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* plugin_ep_device = example_ep.get(); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + + RunModelWithPluginEp(session_options); +} + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + + { + // PREFER_CPU picks our example EP over ORT CPU EP. TODO: Actually assert this. + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunModelWithPluginEp(session_options); + } +} + +// Generate an EPContext model with a plugin EP. +// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. +TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* plugin_ep_device = example_ep.get(); + + { + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + + session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + + // Create model compilation options from the session options. 
+ Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Make sure the compiled model was generated. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } +} +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc new file mode 100644 index 0000000000000..88c2e320990e1 --- /dev/null +++ b/onnxruntime/test/autoep/test_registration.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = Utils::example_ep_info.library_path; + const std::string& registration_name = Utils::example_ep_info.registration_name; + + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, registration_name.c_str(), + library_path.c_str())); + + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices = 0; + + ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); + // should be one device for the 
example EP + auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, + [&registration_name, &c_api](const OrtEpDevice* device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return c_api->EpDevice_EpName(device) == registration_name; + }); + ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + + // and this should unload it + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, + registration_name.c_str())); +} + +TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::example_ep_info.library_path; + const std::string& registration_name = Utils::example_ep_info.registration_name; + + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // should be one device for the example EP + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [&registration_name](Ort::ConstEpDevice& device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. 
expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); + + auto options = test_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); + + // the CPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_selection.cc similarity index 67% rename from onnxruntime/test/autoep/test_autoep_selection.cc rename to onnxruntime/test/autoep/test_selection.cc index 01dece34e50b0..72f39be917f90 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_selection.cc @@ -5,18 +5,17 @@ #ifdef _WIN32 #include -#include +// #include +#include #include -#include "core/common/common.h" -#include "core/framework/provider_options.h" #include "core/graph/constants.h" #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" 
#include "test_allocator.h" +#include "test/autoep/test_autoep_utils.h" #include "test/shared_lib/utils.h" #include "test/util/include/api_asserts.h" #include "test/util/include/asserts.h" @@ -501,212 +500,6 @@ TEST(AutoEpSelection, PolicyDelegateReturnsError) { Ort::Exception); } -namespace { -struct ExamplePluginInfo { - const std::filesystem::path library_path = -#if _WIN32 - "example_plugin_ep.dll"; -#else - "libexample_plugin_ep.so"; -#endif - const std::string registration_name = "example_ep"; -}; - -static const ExamplePluginInfo example_plugin_info; -} // namespace - -TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - OrtEnv* c_api_env = *ort_env; - const OrtApi* c_api = &Ort::GetApi(); - // this should load the library and create OrtEpDevice - ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(c_api_env, registration_name.c_str(), - library_path.c_str())); - - const OrtEpDevice* const* ep_devices = nullptr; - size_t num_devices = 0; - - ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(c_api_env, &ep_devices, &num_devices)); - // should be one device for the example EP - auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, - [&registration_name, &c_api](const OrtEpDevice* device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. 
- return c_api->EpDevice_EpName(device) == registration_name; - }); - ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; - - // and this should unload it - ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(c_api_env, - registration_name.c_str())); -} - -TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - // this should load the library and create OrtEpDevice - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - std::vector ep_devices = ort_env->GetEpDevices(); - - // should be one device for the example EP - auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), - [&registration_name](Ort::ConstEpDevice& device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return device.EpName() == registration_name; - }); - ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; - - // test all the C++ getters. 
expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc - ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); - - auto metadata = test_ep_device->EpMetadata(); - ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); - ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); - - auto options = test_ep_device->EpOptions(); - ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); - - // the CPU device info will vary by machine so check for the lowest common denominator values - Ort::ConstHardwareDevice device = test_ep_device->Device(); - ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); - ASSERT_GE(device.VendorId(), 0); - ASSERT_GE(device.DeviceId(), 0); - ASSERT_NE(device.Vendor(), nullptr); - Ort::ConstKeyValuePairs device_metadata = device.Metadata(); - std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); - ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows - - // and this should unload it without throwing - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -static void RunModelWithPluginEp(Ort::SessionOptions& session_options) { - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); - - // Create input - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input0_data(6, 2.0f); - std::vector ort_inputs; - std::vector ort_input_names; - - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); - ort_input_names.push_back("X"); - - // Run session and get outputs - std::array output_names{"Y"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check expected output values - Ort::Value& 
ort_output = ort_outputs[0]; - const float* output_data = ort_output.GetTensorData(); - gsl::span output_span(output_data, 6); - EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); -} - -// Creates a session with the example plugin EP and runs a model with a single Mul node. -// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. -TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - std::vector ep_devices = ort_env->GetEpDevices(); - - // Find the OrtEpDevice associated with our example plugin EP. - Ort::ConstEpDevice plugin_ep_device; - for (Ort::ConstEpDevice& device : ep_devices) { - if (std::string(device.EpName()) == registration_name) { - plugin_ep_device = device; - break; - } - } - ASSERT_NE(plugin_ep_device, nullptr); - - // Create session with example plugin EP - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - - RunModelWithPluginEp(session_options); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -// Creates a session with the example plugin EP and runs a model with a single Mul node. -// Uses the PREFER_CPU policy to append the example plugin EP to the session. -TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. 
- Ort::SessionOptions session_options; - session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - RunModelWithPluginEp(session_options); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -// Generate an EPContext model with a plugin EP. -// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. -TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); - const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); - std::filesystem::remove(output_model_file); - - std::vector ep_devices = ort_env->GetEpDevices(); - - // Find the OrtEpDevice associated with our example plugin EP. - Ort::ConstEpDevice plugin_ep_device; - for (Ort::ConstEpDevice& device : ep_devices) { - if (std::string(device.EpName()) == registration_name) { - plugin_ep_device = device; - break; - } - } - ASSERT_NE(plugin_ep_device, nullptr); - - // Create session with example plugin EP - Ort::SessionOptions session_options; - std::unordered_map ep_options; - - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - - // Create model compilation options from the session options. - Ort::ModelCompilationOptions compile_options(*ort_env, session_options); - compile_options.SetInputModelPath(input_model_file); - compile_options.SetOutputModelPath(output_model_file); - - // Compile the model. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); - - // Make sure the compiled model was generated. 
- ASSERT_TRUE(std::filesystem::exists(output_model_file)); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9bc50ce88ef16..2335db69ff571 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/float16.h" #include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc new file mode 100644 index 0000000000000..c09dbda745b76 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/graph/constants.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" +#include "test/shared_lib/utils.h" + +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#include +#endif + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +using StreamUniquePtr = std::unique_ptr>; + +#ifdef USE_CUDA +// test copying input to CUDA using an OrtEpFactory based EP. 
+// tests GetSharedAllocator, CreateSyncStreamForEpDevice and CopyTensors APIs +TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { + OrtEnv* env = *ort_env; + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + +#ifdef _WIN32 + std::string cuda_lib = "onnxruntime_providers_cuda.dll"; +#else + std::string cuda_lib = "onnxruntime_providers_cuda.so"; +#endif + + if (!std::filesystem::exists(cuda_lib)) { + GTEST_SKIP() << "CUDA library was not found"; + } + + // register the provider bridge based CUDA EP so allocator and data transfer is available + // not all the CIs have the provider library in the expected place so we allow for that + const char* ep_registration_name = "ORT CUDA"; + ASSERT_ORTSTATUS_OK(api->RegisterExecutionProviderLibrary(env, ep_registration_name, + ORT_TSTR("onnxruntime_providers_cuda"))); + + const OrtEpDevice* cuda_device = nullptr; + for (const auto& ep_device : ort_env->GetEpDevices()) { + std::string vendor{ep_device.EpVendor()}; + std::string name = {ep_device.EpName()}; + if (vendor == std::string("Microsoft") && name == kCudaExecutionProvider) { + cuda_device = ep_device; + break; + } + } + + if (!cuda_device) { // device running tests may not have an nvidia card + return; + } + + const auto run_test = [&](bool use_streams) { + Ort::SessionOptions options; + options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); + + // we pass in the CUDA cudaStream_t from the OrtSyncStream via provider options so need to create it upfront. + // in the future the stream should be an input to the Session Run. 
+ OrtSyncStream* stream = nullptr; + StreamUniquePtr stream_ptr; + if (use_streams) { + ASSERT_ORTSTATUS_OK(api->CreateSyncStreamForEpDevice(cuda_device, /*options*/ nullptr, &stream)); + stream_ptr = StreamUniquePtr(stream, [api](OrtSyncStream* stream) { api->ReleaseSyncStream(stream); }); + + size_t stream_addr = reinterpret_cast(api->SyncStream_GetHandle(stream)); + options.AddConfigEntry("ep.cudaexecutionprovider.user_compute_stream", std::to_string(stream_addr).c_str()); + // we explicitly specify user_compute_stream, so why do we also need to set has_user_compute_stream? + options.AddConfigEntry("ep.cudaexecutionprovider.has_user_compute_stream", "1"); + } + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mnist.onnx"), options); + + size_t num_inputs = session.GetInputCount(); + + // find the input location so we know which inputs can be provided on device. + std::vector input_locations; + input_locations.resize(num_inputs, nullptr); + ASSERT_ORTSTATUS_OK(api->SessionGetMemoryInfoForInputs(session, input_locations.data(), num_inputs)); + + std::vector cpu_tensors; + + // info for device copy + std::vector src_tensor_ptrs; + std::vector dst_tensor_ptrs; + + // values we'll call Run with + std::vector input_tensors; + + ASSERT_EQ(num_inputs, 1); + + // create cpu based input data. 
+ Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{1, 1, 28, 28}; + std::vector input_data(28 * 28, 0.5f); + Ort::Value input_value = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + cpu_tensors.push_back(std::move(input_value)); + + for (size_t idx = 0; idx < num_inputs; ++idx) { + const OrtMemoryInfo* mem_info = input_locations[idx]; + OrtDeviceMemoryType mem_type; + OrtMemoryInfoDeviceType device_type; + ASSERT_ORTSTATUS_OK(api->MemoryInfoGetDeviceMemType(mem_info, &mem_type)); + api->MemoryInfoGetDeviceType(mem_info, &device_type); + + if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) { + // copy to device + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(api->GetSharedAllocator(env, mem_info, &allocator)); + + // allocate new on-device memory + auto src_shape = cpu_tensors[idx].GetTensorTypeAndShapeInfo().GetShape(); + Ort::Value device_value = Ort::Value::CreateTensor(allocator, src_shape.data(), src_shape.size()); + + /* if you have existing memory on device use one of these instead of CreateTensorAsOrtValue + CopyTensors + void* existing_data; + size_t data_length = 128 * sizeof(float); + api->CreateTensorWithDataAsOrtValue(input_locations[0], existing_data, data_length, shape, 2, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); + + // existing with ownership transfer. ORT will use the allocator to free the memory once it is no longer required + api->CreateTensorWithDataAndDeleterAsOrtValue(allocator, existing_data, data_length, shape, 2, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); + */ + + src_tensor_ptrs.push_back(cpu_tensors[idx]); + dst_tensor_ptrs.push_back(device_value); + input_tensors.push_back(std::move(device_value)); + } else { + // input is on CPU accessible memory. 
move to input_tensors + input_tensors.push_back(std::move(cpu_tensors[idx])); + } + } + + if (!src_tensor_ptrs.empty()) { + ASSERT_ORTSTATUS_OK(api->CopyTensors(env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), stream, + src_tensor_ptrs.size())); + + // Stream support is still a work in progress. + // + // CUDA EP can use a user provided stream via provider options, so we can pass in the cudaStream_t from the + // OrtSyncStream used in CopyTensors call that way. + // + // Alternatively you can manually sync the device via IoBinding. + // Ort::IoBinding iobinding(session); + // iobinding.SynchronizeInputs(); // this doesn't actually require any bound inputs + } + + std::vector input_names = {"Input3"}; + std::vector output_names = {"Plus214_Output_0"}; + Ort::Value output; + + session.Run(Ort::RunOptions{}, input_names.data(), input_tensors.data(), input_tensors.size(), + output_names.data(), &output, 1); + + const float* results = nullptr; + ASSERT_ORTSTATUS_OK(api->GetTensorData(output, reinterpret_cast(&results))); + + // expected results from the CPU EP. can check/re-create by running with PREFER_CPU. + std::vector expected = { + -0.701670527f, + -0.583666623f, + 0.0480501056f, + 0.550699294f, + -1.25372827f, + 1.17879760f, + 0.838122189f, + -1.51267099f, + 0.902430952f, + 0.243748352f, + }; + + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_NEAR(expected[i], results[i], 1e-3) << "i=" << i; + } + }; + + run_test(/*use_streams*/ true); + run_test(/*use_streams*/ false); + + ASSERT_ORTSTATUS_OK(api->UnregisterExecutionProviderLibrary(env, ep_registration_name)); +} +#endif // USE_CUDA + +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD