diff --git a/.github/workflows/windows_openvino.yml b/.github/workflows/windows_openvino.yml index f38fcdae57a35..96289e65502d9 100644 --- a/.github/workflows/windows_openvino.yml +++ b/.github/workflows/windows_openvino.yml @@ -54,12 +54,12 @@ jobs: core.exportVariable('ACTIONS_CACHE_URL', process.env.ACTIONS_CACHE_URL || ''); core.exportVariable('ACTIONS_RUNTIME_TOKEN', process.env.ACTIONS_RUNTIME_TOKEN || ''); - - name: Download OpenVINO Toolkit v2025.0.0 + - name: Download OpenVINO Toolkit v2025.2.0 env: - OpenVINOVersion: 2025.0.0 + OpenVINOVersion: 2025.2.0 shell: pwsh run: | - $Url = "https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.0/windows/openvino_toolkit_windows_2025.0.0.17942.1f68be9f594_x86_64.zip" + $Url ="https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.2/windows/openvino_toolkit_windows_2025.2.0.19140.c01cd93e24d_x86_64.zip" $OutputPath = "$env:RUNNER_TEMP\openvino.zip" $ExtractPath = "$env:RUNNER_TEMP\openvino-v$env:OpenVINOVersion" $TempExtractPath = "$env:RUNNER_TEMP\openvino_temp" @@ -102,7 +102,7 @@ jobs: shell: pwsh # Use $GITHUB_ENV to set the variable for subsequent steps run: | - $openVinoRootDir = Join-Path $env:RUNNER_TEMP "openvino-v2025.0.0" + $openVinoRootDir = Join-Path $env:RUNNER_TEMP "openvino-v2025.2.0" echo "OpenVINORootDir=$openVinoRootDir" >> $env:GITHUB_ENV - name: Print OpenVINORootDir after downloading OpenVINO diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index d1fb06a95f4c9..33c9e19911557 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -498,17 +498,19 @@ set (ONNXRUNTIME_AUTOEP_TEST_SRC_DIR "${TEST_SRC_DIR}/autoep") set (ONNXRUNTIME_EP_GRAPH_TEST_SRC_DIR "${TEST_SRC_DIR}/ep_graph") set (onnxruntime_shared_lib_test_SRC - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_allocator.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_data_copy.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_fixture.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_loading.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_nontensor_types.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_ort_format_models.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_run_options.cc + ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_session_options.cc ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.h ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/utils.cc - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.h - ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/custom_op_utils.cc) + ) if (NOT onnxruntime_MINIMAL_BUILD) list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc) @@ -722,6 +724,7 @@ endif() if(onnxruntime_USE_QNN AND NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_REDUCED_OPS_BUILD) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/*) list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/qnn_node_group/*) + list(APPEND onnxruntime_test_framework_src_patterns ${TEST_SRC_DIR}/providers/qnn/optimizer/*) list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_qnn) list(APPEND onnxruntime_test_providers_dependencies 
onnxruntime_providers_qnn) if(NOT onnxruntime_BUILD_QNN_EP_STATIC_LIB) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 9fb1eb9107774..0aad80c4ddab9 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -89,7 +89,12 @@ struct OrtTensorRTProviderOptionsV2 { size_t trt_onnx_bytestream_size{0}; // size of the byte stream provided as "trt_onnx_bytestream" // can be updated using: UpdateTensorRTProviderOptionsWithValue - const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix - int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true - const char* trt_op_types_to_exclude{}; // Exclude specific ops from running on TRT. + const void* trt_external_data_bytestream{nullptr}; // The byte stream containing the weights to override the ones provided in the ONNX model. + // can be updated using: UpdateTensorRTProviderOptionsWithValue + size_t trt_external_data_bytestream_size{0}; // size of the byte stream provided as "trt_external_data_bytestream" + // can be updated using: UpdateTensorRTProviderOptionsWithValue + const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + int trt_engine_hw_compatible{0}; // Enable hardware compatibility. Default 0 = false, nonzero = true + const char* trt_op_types_to_exclude{}; // Exclude specific ops from running on TRT. + int trt_load_user_initializer{0}; // Save initializers locally instead of to disk. Default 0 = false, nonzero = true }; diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 37665542f614f..f649576658d00 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -665,11 +665,11 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr Ort::Status status{ort_api.ReadOpAttr(&ort_attr, attr_type, nullptr, 0, &total_attr_bytes)}; std::string* str = attr_proto.mutable_s(); - str->resize(total_attr_bytes, '\0'); + str->resize(total_attr_bytes); ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.ReadOpAttr(&ort_attr, attr_type, str->data(), total_attr_bytes, &total_attr_bytes)); - str->resize(total_attr_bytes - 1); // remove extra ending terminating '\0' character. 
+ str->resize(total_attr_bytes); break; } case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 9f3a9eeabff6b..7e49275e59b8b 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -15,6 +15,7 @@ #include "core/common/status.h" #include "core/framework/allocator.h" #include "core/framework/execution_provider.h" +#include "core/framework/data_transfer_manager.h" #include "core/platform/device_discovery.h" #include "core/platform/threadpool.h" @@ -140,6 +141,10 @@ class Environment { OrtDeviceMemoryType mem_type, OrtAllocatorType allocator_type, const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator); Status ReleaseSharedAllocator(const OrtEpDevice& ep_device, OrtDeviceMemoryType mem_type); + + const DataTransferManager& GetDataTransferManager() const { + return data_transfer_mgr_; + } #endif // !defined(ORT_MINIMAL_BUILD) // return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator @@ -185,6 +190,23 @@ class Environment { using OrtAllocatorUniquePtr = std::unique_ptr>; + // if the user calls CreateSharedAllocator and wraps the plugin EP's allocator with an arena we end up with + // OrtAllocator from EP -> wrapped in IAllocatorImplWrappingOrtAllocator -> inside a BFCArena IAllocator. + // we can put that in shared_allocators_ for sessions to use, but to have an OrtAllocator available in + // shared_ort_allocators_ that can be used outside of a session we need to additionally wrap that in an + // OrtAllocatorImplWrappingIAllocator. way too many levels of indirection but that is what it is currently. + // we need something to own that final OrtAllocator, so we add it to arena_ort_allocators_. + // + // TODO: we could split out the BFCArena implementation so it can be plugged into either an IAllocator + // or an OrtAllocator instance to reduce the indirection a little. + // with that we get an OrtAllocator from the EP, wrap it with an OrtAllocator based BFCArena, and wrap that with the + // IAllocatorImplWrappingOrtAllocator which takes ownership of the OrtAllocator and is in shared_allocators_. + // + // Alternatively we can disable wrapping an EP's allocator with a BFCArena and say the EP should provide the arena + // implementation directly. They're free to copy BFCArena as it came from TF originally. Or we could provide a + // cut-and-paste BFCArena implementation that works using the EP API that can be included in the EP source. + std::unordered_map> arena_ort_allocators_; + #if !defined(ORT_MINIMAL_BUILD) // register EPs that are built into the ORT binary so they can take part in AutoEP selection // added to ep_libraries @@ -207,7 +229,9 @@ class Environment { std::unique_ptr library; std::vector> execution_devices; - std::vector internal_factories; // factories that can create IExecutionProvider instances + std::vector factories; + std::vector internal_factories; // factories that can create IExecutionProvider instances + std::vector data_transfers; // data transfer instances for this EP. 
private: EpInfo() = default; @@ -223,6 +247,9 @@ class Environment { // lookup set for internal EPs so we can create an IExecutionProvider directly std::unordered_set internal_ep_factories_; + + DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances + #endif // !defined(ORT_MINIMAL_BUILD) }; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 82e782112974f..aff2f1860d8e5 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -64,6 +64,7 @@ extern "C" { #define _Outptr_result_maybenull_ #define _Outptr_result_maybenull_z_ #define _In_reads_(X) +#define _In_reads_opt_ #define _Inout_updates_(X) #define _Out_writes_(X) #define _Out_writes_opt_(X) @@ -322,6 +323,7 @@ ORT_RUNTIME_CLASS(ModelCompilationOptions); ORT_RUNTIME_CLASS(HardwareDevice); ORT_RUNTIME_CLASS(EpDevice); ORT_RUNTIME_CLASS(KeyValuePairs); +ORT_RUNTIME_CLASS(SyncStream); // Opaque class to create an onnxruntime::Stream. #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -426,10 +428,14 @@ typedef enum OrtAllocatorType { */ // Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc typedef enum OrtMemType { - OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider - OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeDefault = 0, ///< The default allocator for execution provider + /// Any CPU memory used by non-CPU execution provider + OrtMemTypeCPUInput = -2, + /// CPU accessible memory outputted by non-CPU execution provider, i.e. HOST_ACCESSIBLE + OrtMemTypeCPUOutput = -1, + /// CPU accessible memory allocated by non-CPU execution provider, i.e. HOST_ACCESSIBLE + OrtMemTypeCPU = OrtMemTypeCPUOutput, + /// The default allocator for execution provider + OrtMemTypeDefault = 0, } OrtMemType; /** \brief This matches OrtDevice::MemoryType values */ @@ -1743,7 +1749,7 @@ struct OrtApi { */ ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); - /** \brief Get the id from ::OrtMemoryInfo + /** \brief Get the device id from ::OrtMemoryInfo */ ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); @@ -3667,7 +3673,8 @@ struct OrtApi { * * \param[in] name Name of the attribute * \param[in] data Data content of the attribute - * \param[in] len Number of elements if data represents an array (e.g., ORT_OP_ATTR_INTS). Otherwise, set to 1. + * \param[in] len Number of bytes stored in data for ORT_OP_ATTR_STRING. + Number of elements if data represents an array (e.g., ORT_OP_ATTR_INTS). Otherwise, set to 1. 
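To make the clarified `len` semantics concrete, here is a minimal illustrative sketch (an editor's example, not part of this patch) that creates a string attribute and reads it back with the usual size-query-then-read pattern. The helper name is hypothetical, and it assumes `len` counts the string's bytes without a terminating '\0'.

#include <cstring>
#include <string>
#include "onnxruntime_c_api.h"

// Hypothetical sketch: create a string OrtOpAttr and read it back via ReadOpAttr.
void StringAttrSketch(const OrtApi* ort) {
  const char* value = "attribute_value";
  OrtOpAttr* attr = nullptr;

  // For ORT_OP_ATTR_STRING, len is the number of bytes in `data` (assumed to exclude any '\0').
  if (OrtStatus* st = ort->CreateOpAttr("my_attr", value, static_cast<int>(std::strlen(value)),
                                        ORT_OP_ATTR_STRING, &attr)) {
    ort->ReleaseStatus(st);
    return;
  }

  // The first ReadOpAttr call queries the required buffer size; it is expected to return a
  // failure status while still setting total_bytes (the same pattern ort_graph_to_proto.h uses).
  size_t total_bytes = 0;
  if (OrtStatus* st = ort->ReadOpAttr(attr, ORT_OP_ATTR_STRING, nullptr, 0, &total_bytes)) {
    ort->ReleaseStatus(st);
  }

  std::string str(total_bytes, '\0');
  if (OrtStatus* st = ort->ReadOpAttr(attr, ORT_OP_ATTR_STRING, str.data(), str.size(), &total_bytes)) {
    ort->ReleaseStatus(st);
  }

  ort->ReleaseOpAttr(attr);
}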
* \param[in] type Data type * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr * @@ -5383,10 +5390,32 @@ struct OrtApi { * \since Version 1.23 */ ORT_API2_STATUS(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); + /** \brief Get the device memory type from ::OrtMemoryInfo + * + * \param[in] ptr The OrtMemoryInfo instance to query. + * \param[out] out The device memory type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out); + + /** \brief Get the vendor id from ::OrtMemoryInfo + * + * \param[in] ptr The OrtMemoryInfo instance to query. + * \param[out] out The vendor id. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out); + /// \name OrtValueInfo /// @{ @@ -6067,11 +6096,14 @@ struct OrtApi { /** \brief Get the OrtMemoryInfo for the device. * * \param[in] ep_device The OrtEpDevice instance to query. - * \return A pointer to the OrtMemoryInfo for the device. + * \param[in] memory_type The memory type to return. + * \return A pointer to the OrtMemoryInfo for the device. This may be nullptr if not set. + * If memory_type is OrtDeviceMemoryType_DEFAULT and nullptr is returned the EP uses CPU memory. * * \since Version 1.23 */ - ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); + ORT_API_T(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType memory_type); /** \brief Create/replace a shared allocator for the OrtEpDevice in the OrtEnv. * @@ -6163,6 +6195,141 @@ struct OrtApi { * \since Version 1.23. */ ORT_API2_STATUS(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); + + /** \brief Get the OrtMemoryInfo for each input of the session. + * + * The memory info can be used to determine where the input tensors are required. + * + * The session must be fully initialized before calling this function as the input locations are not known until + * this has occurred. + * + * \param[in] session The OrtSession instance. + * \param[out] inputs_memory_info Pre-allocated array of size `num_inputs` that will be filled with the + * OrtMemoryInfo* value for each input. + * The order is the same as returned by SessionGetInputName. + * \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs); + + /** \brief Get the OrtMemoryInfo for each output of the session. + * + * The memory info can be used to determine the device the output tensors are produced on. + * The user can pre-allocate an OrtValue using this information or use IOBinding to keep the data on the device. + * ORT will copy the output to CPU otherwise. 
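As an illustrative sketch of how these new per-input queries might be used together (editor's example, not part of this patch; the helper name is hypothetical):

#include <vector>
#include "onnxruntime_c_api.h"

// Hypothetical sketch: after the session is fully initialized, find out where each input is expected to live.
void InspectInputLocations(const OrtApi* ort, const OrtSession* session) {
  size_t num_inputs = 0;
  if (OrtStatus* st = ort->SessionGetInputCount(session, &num_inputs)) {
    ort->ReleaseStatus(st);
    return;
  }

  std::vector<const OrtMemoryInfo*> input_infos(num_inputs, nullptr);
  if (OrtStatus* st = ort->SessionGetMemoryInfoForInputs(session, input_infos.data(), num_inputs)) {
    ort->ReleaseStatus(st);
    return;
  }

  for (size_t i = 0; i < num_inputs; ++i) {
    const char* name = nullptr;
    OrtDeviceMemoryType mem_type{};
    if (OrtStatus* st = ort->MemoryInfoGetName(input_infos[i], &name)) { ort->ReleaseStatus(st); continue; }
    if (OrtStatus* st = ort->MemoryInfoGetDeviceMemType(input_infos[i], &mem_type)) { ort->ReleaseStatus(st); continue; }
    // Use `name`/`mem_type` to decide whether to pre-allocate this input on the device or on the host.
  }
}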
+ * + * The session must be fully initialized before calling this function as the output locations are not known until + * this has occurred. + * + * \param[in] session The OrtSession instance. + * \param[out] outputs_memory_info Pre-allocated array of size `num_outputs` that will be filled with + * OrtMemoryInfo* values for each output. + * The order is the same as returned by SessionGetOutputName. + * \param[in] num_outputs The number of outputs in the session. Must match SessionGetOutputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs); + + /** \brief Get the OrtEpDevice (if available) for each input of the session. + * + * An OrtEpDevice will be available if auto EP selection is enabled by calling + * SessionOptionsSetEpSelectionPolicy or SessionOptionsSetEpSelectionPolicyDelegate, + * or if the OrtEpDevice was manually added to the session using SessionOptionsAppendExecutionProvider_V2. + * + * If an OrtEpDevice is not available for the input a nullptr is returned. + * + * The returned OrtEpDevice can be used to create an OrtSyncStream via CreateSyncStreamForEpDevice to asynchronously + * provide input to the inference session Run. + * + * The session must be fully initialized before calling this function as the assigned EPs are not known until + * this has occurred. + * + * \param[in] session The OrtSession instance. + * \param[out] inputs_ep_devices Pre-allocated array of size `num_inputs` that will be filled with + * OrtEpDevice* values for each input. + * The order is the same as returned by SessionGetInputName. + * \param[in] num_inputs The number of inputs in the session. Must match SessionGetInputCount. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(SessionGetEpDeviceForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_inputs); + + /** \brief Create an OrtSyncStream for the given OrtEpDevice. + * + * The OrtSyncStream can be used to enable asynchronous operations. + * e.g. async usage of CopyTensors to provide input to an OrtSession Run call. + * + * An error code of ORT_NOT_IMPLEMENTED will be returned if the EP does not support OrtSyncStream. + * + * \param[in] ep_device The OrtEpDevice instance to create the sync stream for. + * \param[in] stream_options Options for OrtSyncStream creation. May be nullptr. + * \param[out] stream Output parameter set to the created OrtSyncStream instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** stream); + + /** \brief Get the native handle of the sync stream. + * + * This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams. + * + * \param[in] stream The OrtSyncStream instance to get the handle from. + * + * \returns The native handle of the stream. + * + * \since Version 1.23 + */ + ORT_API_T(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream); + + ORT_CLASS_RELEASE(SyncStream); + + /** \brief Copy OrtValue instances containing Tensors between devices. + * + * The overall copy must be between a single source device and a single destination device. 
i.e. + * - all src_tensors must have matching OrtMemoryInfo, + * - all dst_tensors must have matching OrtMemoryInfo. + * + * OrtValue instances can be created by: + * - Use GetSharedAllocator to get the shared allocator for the OrtMemoryInfo if you need to allocate memory + * on the device. + * - Use CreateTensorAsOrtValue, CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue + * to create an OrtValue containing a tensor depending on whether you have existing data or not, and whether + * you want ORT to free the existing data once it is done with the OrtValue. + * + * \param[in] env The OrtEnv instance to use. The data transfer implementation is provided by an execution provider + * that is registered in this OrtEnv. + * \param[in] src_tensors Array of OrtValue instances containing the source tensors to copy. + * \param[in] dst_tensors Array of OrtValue instances to copy the source tensors to. + * \param[in] stream Optional OrtSyncStream that can be used to perform the copy asynchronously. May be nullptr. + * \param[in] num_tensors The number of tensors to copy. The size of `src_tensors` and `dst_tensors` must match. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index ba5d53e6c2dd0..18a674f276899 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2627,7 +2627,7 @@ inline std::string ShapeInferContext::GetAttrString(const char* attr_name) { if (status) { std::vector chars(out, '\0'); Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRING, chars.data(), out, &out)); - return {chars.data()}; + return std::string{chars.data(), out}; } else { return {c}; } diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index 5d00ce4940d02..0dc4105ebe855 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -13,12 +13,12 @@ ORT_RUNTIME_CLASS(EpGraphSupportInfo); ORT_RUNTIME_CLASS(MemoryDevice); // opaque class to wrap onnxruntime::OrtDevice ORT_RUNTIME_CLASS(NodeComputeContext); -// Opaque class to create an onnxruntime::Stream. Will be filled out in separate PR. -// Adding here for OrtDataTransferImpl as the stream type is required by the IDataTransfer API. -ORT_RUNTIME_CLASS(SyncStream); +ORT_RUNTIME_CLASS(DataTransferImpl); +ORT_RUNTIME_CLASS(SyncNotificationImpl); +ORT_RUNTIME_CLASS(SyncStreamImpl); // struct that an EP implements for IDataTransfer to copy between devices it uses and CPU -typedef struct OrtDataTransferImpl { +struct OrtDataTransferImpl { uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION /** \brief Release the OrtDataTransferImpl instance. @@ -30,7 +30,7 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. */ - ORT_API_T(void, Release, _In_ void* this_ptr); + ORT_API_T(void, Release, _In_ OrtDataTransferImpl* this_ptr); /** \brief Check if the implementation can copy between the source and destination memory devices. 
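As a rough usage sketch for the CopyTensors API documented above (editor's example, not part of this patch): given an existing CPU buffer and an allocator for the target device memory (for example one obtained through GetSharedAllocator for the device's OrtMemoryInfo), a host-to-device copy could look like the following. The function name and parameters are hypothetical, and passing a null stream is assumed to perform the copy without an async stream.

#include <cstdint>
#include "onnxruntime_c_api.h"

// Hypothetical sketch: wrap a CPU float buffer in an OrtValue, allocate a device tensor of the
// same shape, and copy host -> device with CopyTensors. Returns nullptr on success.
OrtStatus* CopyInputToDevice(const OrtApi* ort, const OrtEnv* env,
                             OrtAllocator* device_allocator,  // assumed: allocator for the target device memory
                             float* cpu_data, size_t num_bytes,
                             const int64_t* shape, size_t shape_len,
                             OrtValue** device_value_out) {
  *device_value_out = nullptr;

  // Wrap the existing CPU buffer without copying it.
  OrtMemoryInfo* cpu_info = nullptr;
  if (OrtStatus* st = ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &cpu_info)) return st;

  OrtValue* src = nullptr;
  OrtStatus* st = ort->CreateTensorWithDataAsOrtValue(cpu_info, cpu_data, num_bytes, shape, shape_len,
                                                      ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &src);
  ort->ReleaseMemoryInfo(cpu_info);
  if (st) return st;

  // Allocate the destination tensor on the device, then copy. A null stream is passed,
  // so no OrtSyncStream is involved in this sketch.
  OrtValue* dst = nullptr;
  st = ort->CreateTensorAsOrtValue(device_allocator, shape, shape_len,
                                   ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &dst);
  if (st == nullptr) {
    const OrtValue* srcs[] = {src};
    OrtValue* dsts[] = {dst};
    st = ort->CopyTensors(env, srcs, dsts, /*stream*/ nullptr, 1);
  }

  ort->ReleaseValue(src);
  if (st) {
    if (dst) ort->ReleaseValue(dst);
  } else {
    *device_value_out = dst;  // caller owns the device-resident tensor, e.g. to bind as a session input
  }
  return st;
}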
* @@ -41,7 +41,7 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. */ - ORT_API_T(bool, CanCopy, _In_ void* this_ptr, + ORT_API_T(bool, CanCopy, _In_ const OrtDataTransferImpl* this_ptr, _In_ const OrtMemoryDevice* src_memory_device, _In_ const OrtMemoryDevice* dst_memory_device); /** \brief Copy tensors from src_tensors to dst_tensors using the provided streams. @@ -60,12 +60,119 @@ typedef struct OrtDataTransferImpl { * * \since Version 1.23. */ - ORT_API2_STATUS(CopyTensors, _In_ void* this_ptr, + ORT_API2_STATUS(CopyTensors, _In_ OrtDataTransferImpl* this_ptr, _In_reads_(num_tensors) const OrtValue** src_tensors, _In_reads_(num_tensors) OrtValue** dst_tensors, _In_reads_(num_tensors) OrtSyncStream** streams, _In_ size_t num_tensors); -} OrtDataTransferImpl; +}; + +/** \brief Struct that an EP implements for Stream Notifications. + * + * \since Version 1.23. + */ +struct OrtSyncNotificationImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Release the OrtSyncNotificationImpl instance. + * + * This is called by ORT when the OrtSyncNotificationImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API_T(void, Release, _In_ OrtSyncNotificationImpl* this_ptr); + + /** \brief Called by ORT to activate the notification. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Activate, _In_ OrtSyncNotificationImpl* this_ptr); + + /** \brief Wait for a device to device operation to complete. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * \param[in] stream The OrtSyncStream instance that will wait on this notification to be activated. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(WaitOnDevice, _In_ OrtSyncNotificationImpl* this_ptr, _In_ OrtSyncStream* consumer_stream); + + /** \brief Wait for a device to host operation to complete. + * + * \param[in] this_ptr Pointer to the OrtSyncNotificationImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(WaitOnHost, _In_ OrtSyncNotificationImpl* this_ptr); +}; + +/** \brief Struct that an EP implements if it wishes to implement Stream support. + * + * This struct provides the overrides for onnxruntime::Stream's virtual methods. + * + * \since Version 1.23. + */ +struct OrtSyncStreamImpl { + uint32_t ort_version_supported; ///< Must be initialized to ORT_API_VERSION + + /** \brief Release the OrtSyncStreamImpl instance. + * + * This is called by ORT when the OrtSyncStreamImpl instance is no longer needed. + * The implementation should release any resources held by the instance. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API_T(void, Release, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Get the handle of the stream. + * + * This returns the native handle for the stream. e.g. cudaStream_t for CUDA streams. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * \return The handle of the stream. + * + * \since Version 1.23. + */ + ORT_API_T(void*, GetHandle, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Create an OrtSyncNotificationImpl for the OrtSyncStreamImpl instance. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance + * \param[out] notification The new OrtSyncNotificationImpl instance. 
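For orientation, a minimal OrtSyncNotificationImpl could be filled in as below (editor's sketch, not part of this patch). It models an EP whose device work completes synchronously, so the notification callbacks have nothing to wait on; a real EP would record and wait on a device event instead. The struct and function names are hypothetical, and deriving from the C struct to populate its function pointers is an assumed implementation style.

#include "onnxruntime_c_api.h"
#include "onnxruntime_ep_c_api.h"

// Hypothetical sketch: a no-op notification for an EP whose work is already synchronous.
struct ExampleNotification : OrtSyncNotificationImpl {
  ExampleNotification() {
    ort_version_supported = ORT_API_VERSION;
    Release = ReleaseImpl;
    Activate = ActivateImpl;
    WaitOnDevice = WaitOnDeviceImpl;
    WaitOnHost = WaitOnHostImpl;
  }

  static void ORT_API_CALL ReleaseImpl(OrtSyncNotificationImpl* this_ptr) noexcept {
    delete static_cast<ExampleNotification*>(this_ptr);
  }
  static OrtStatus* ORT_API_CALL ActivateImpl(OrtSyncNotificationImpl* /*this_ptr*/) noexcept {
    return nullptr;  // nothing to signal; a CUDA-like EP would record an event here
  }
  static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(OrtSyncNotificationImpl* /*this_ptr*/,
                                                  OrtSyncStream* /*consumer_stream*/) noexcept {
    return nullptr;  // a real EP would make consumer_stream wait on the recorded event
  }
  static OrtStatus* ORT_API_CALL WaitOnHostImpl(OrtSyncNotificationImpl* /*this_ptr*/) noexcept {
    return nullptr;  // a real EP would block the host until the event fires
  }
};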
+ * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateNotification, _In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification); + + /** \brief Flush the stream. + * + * This is called by ORT to flush the stream, ensuring that all operations submitted to the stream are completed. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Flush, _In_ OrtSyncStreamImpl* this_ptr); + + /** \brief Notify the stream that a session run has ended. + * + * This is called by ORT to notify the stream that a session run has ended, allowing the stream to perform any + * necessary cleanup or finalization. + * + * \param[in] this_ptr Pointer to the OrtSyncStreamImpl instance. + * + * \since Version 1.23. + */ + ORT_API2_STATUS(OnSessionRunEnd, _In_ OrtSyncStreamImpl* this_ptr); +}; struct OrtNodeFusionOptions; typedef struct OrtNodeFusionOptions OrtNodeFusionOptions; @@ -522,6 +629,45 @@ struct OrtEp { * \since Version 1.23. */ ORT_API2_STATUS(OnRunEnd, _In_ OrtEp* this_ptr, _In_ const OrtRunOptions* run_options, _In_ bool sync_stream); + + /** \brief Create an OrtAllocator for the given OrtMemoryInfo for an OrtSession. + * + * The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo. + * Any allocator specific options should be read from the session options. + * + * If nullptr OrtEpFactory::CreateAllocator will be used. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr. + * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateAllocator, _In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator); + + /** \brief Create a synchronization stream for the given memory device for an OrtSession. + * + * This is used to create a synchronization stream for the execution provider and is used to synchronize + * operations on the device during model execution. + * Any stream specific options should be read from the session options. + * + * If nullptr OrtEpFactory::CreateSyncStreamForDevice will be used. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for. + * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream); }; /** \brief The function signature that ORT will call to create OrtEpFactory instances. @@ -683,16 +829,15 @@ struct OrtEpFactory { */ ORT_API_T(const char*, GetVersion, _In_ const OrtEpFactory* this_ptr); - /** \brief Create an OrtAllocator for the given OrtMemoryInfo. + /** \brief Create an OrtAllocator that can be shared across sessions for the given OrtMemoryInfo. * - * This is used to create an allocator that an execution provider requires. The factory that creates the EP is - * responsible for providing the required allocators. + * The factory that creates the EP is responsible for providing the allocators required by the EP. 
* The OrtMemoryInfo instance will match one of the values set in the OrtEpDevice using EpDevice_AddAllocatorInfo. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] memory_info The OrtMemoryInfo to create the allocator for. + * \param[in] memory_info The OrtMemoryInfo to create the allocator for. May be nullptr. * \param[in] allocator_options Optional key-value pairs for allocator options, can be nullptr. - * \param[out] allocator The created OrtAllocator instance. + * \param[out] allocator The created OrtAllocator instance. Set to nullptr if the default CPU allocator is used. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -700,8 +845,8 @@ struct OrtEpFactory { */ ORT_API2_STATUS(CreateAllocator, _In_ OrtEpFactory* this_ptr, _In_ const OrtMemoryInfo* memory_info, - _In_ const OrtKeyValuePairs* allocator_options, - _Outptr_ OrtAllocator** allocator); + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_result_maybenull_ OrtAllocator** allocator); /** \brief Release an OrtAllocator created by the factory. * @@ -715,13 +860,42 @@ struct OrtEpFactory { * that the execution provider supports. * * \param[in] this_ptr The OrtEpFactory instance. - * \param[out] data_transfer The created OrtDataTransferImpl instance. + * \param[out] data_transfer The created OrtDataTransferImpl instance. Set to nullptr if not required. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer); + + /** \brief Check if execution providers created by the factory are stream aware. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \return True if the factory creates execution providers that are stream aware and it implements CreateSyncStreamForDevice. + * + * \since Version 1.23. + */ + ORT_API_T(bool, IsStreamAware, _In_ const OrtEpFactory* this_ptr); + + /** \brief Create a synchronization stream for the given memory device. + * + * This is used to create a synchronization stream for the memory device that can be used for operations outside of + * a session. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] memory_device The OrtMemoryDevice to create the synchronization stream for. + * \param[in] stream_options Options for stream creation. May be nullptr. + * \param[out] stream The created OrtSyncStreamImpl instance. nullptr if the execution provider is not stream aware. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. 
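To illustrate the relaxed factory contract (editor's sketch, not part of this patch): a CPU-only factory can now decline both custom allocation and stream support. These hypothetical callbacks would be assigned to the corresponding OrtEpFactory function pointers.

#include "onnxruntime_c_api.h"
#include "onnxruntime_ep_c_api.h"

// Hypothetical sketch: factory callbacks for a CPU-only plugin EP.
static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* /*this_ptr*/,
                                                   const OrtMemoryInfo* /*memory_info*/,
                                                   const OrtKeyValuePairs* /*allocator_options*/,
                                                   OrtAllocator** allocator) noexcept {
  *allocator = nullptr;  // nullptr output: ORT falls back to its default CPU allocator
  return nullptr;        // nullptr status means success
}

static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept {
  return false;  // this factory's EPs do not implement CreateSyncStreamForDevice
}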
*/ - ORT_API2_STATUS(CreateDataTransfer, _In_ OrtEpFactory* this_ptr, _Outptr_ OrtDataTransferImpl** data_transfer); + ORT_API2_STATUS(CreateSyncStreamForDevice, _In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStreamImpl** stream); }; #ifdef __cplusplus diff --git a/js/react_native/e2e/package-lock.json b/js/react_native/e2e/package-lock.json index fff331486dfc1..cf033192b3ba5 100644 --- a/js/react_native/e2e/package-lock.json +++ b/js/react_native/e2e/package-lock.json @@ -5291,16 +5291,16 @@ } }, "node_modules/compression": { - "version": "1.8.0", - "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.0.tgz", - "integrity": "sha512-k6WLKfunuqCYD3t6AsuPGvQWaKwuLLh2/xHNcX4qE+vIfDNXpSqnrhwA7O53R7WVQUnt8dVAIW+YHr7xTgOgGA==", + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/compression/-/compression-1.8.1.tgz", + "integrity": "sha512-9mAqGPHLakhCLeNyxPkK4xVo746zQ/czLH1Ky+vkitMnWfWZps8r0qXuwhwizagCRttsL4lfG4pIOvaWLpAP0w==", "license": "MIT", "dependencies": { "bytes": "3.1.2", "compressible": "~2.0.18", "debug": "2.6.9", "negotiator": "~0.6.4", - "on-headers": "~1.0.2", + "on-headers": "~1.1.0", "safe-buffer": "5.2.1", "vary": "~1.1.2" }, @@ -10220,9 +10220,9 @@ } }, "node_modules/on-headers": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.0.2.tgz", - "integrity": "sha512-pZAE+FJLoyITytdqK0U5s+FIpjN0JP3OzFi/u8Rx+EV5/W+JTWGXG8xFzevE7AjBfDqHv/8vL8qQsIhHnqRkrA==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/on-headers/-/on-headers-1.1.0.tgz", + "integrity": "sha512-737ZY3yNnXy37FHkQxPzt4UZ2UWPWiCZWLvFZ4fu5cueciegX0zGPnrlY6bwRg4FdQOe9YU8MkmJwGhoMybl8A==", "license": "MIT", "engines": { "node": ">= 0.8" diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h index f01ce985658aa..46d3e7e675e85 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention_helper.h @@ -379,11 +379,6 @@ Status CheckCustomAttentionInputs(const T* position_ids, } if (head_sink != nullptr) { - if (parameters.use_smooth_softmax) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "head_sink should not be provided when use_smooth_softmax is true."); - } - const auto& head_sink_shape = head_sink->Shape(); if (head_sink_shape.NumDimensions() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "head_sink must be a 1D tensor"); diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 691391ccef0d0..e08d120750a40 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -156,6 +156,7 @@ struct GroupQueryAttentionData { int* seqlens_k = nullptr; const T* cos_cache = nullptr; const T* sin_cache = nullptr; + const T* head_sink = nullptr; // Flash buffers T* softmax_lse = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h index c24bf88fa729b..09ead61e7d80d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -123,6 +123,7 @@ struct Flash_fwd_params : public Qkv_params { bool is_rotary_interleaved = false; + void* __restrict__ head_sink_ptr = nullptr; bool 
smooth_softmax = false; int num_splits = 0; // For split-KV version diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index b0241c26aafc6..76704b5b29fcd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -28,6 +28,7 @@ void set_params_fprop(Flash_fwd_params& params, void* q, void* k, void* v, + void* head_sink, void* out, void* cu_seqlens_q_d, void* cu_seqlens_k_d, @@ -50,7 +51,9 @@ void set_params_fprop(Flash_fwd_params& params, params.o_ptr = out; params.is_bf16 = is_bf16; + params.smooth_softmax = use_smooth_softmax; + params.head_sink_ptr = head_sink; // All stride are in elements, not bytes. if (kv_bsnh) { @@ -297,6 +300,8 @@ Status mha_fwd(const cudaDeviceProp& dprops, const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + constexpr void* head_sink = nullptr; + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -304,7 +309,7 @@ Status mha_fwd(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, k, v, out, + q, k, v, head_sink, out, /*cu_seqlens_q*/ nullptr, /*cu_seqlens_k*/ nullptr, /*seqused_k=*/nullptr, @@ -376,6 +381,8 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); const bool paged_KV = block_table != nullptr; + constexpr void* head_sink = nullptr; + Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -383,7 +390,7 @@ Status mha_varlen_fwd(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, k, v, out, + q, k, v, head_sink, out, cu_seqlens_q, cu_seqlens_k, seqused_k, @@ -443,6 +450,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* seqlens_k_, // batch_size void* rotary_cos, // seqlen_ro x (rotary_dim / 2) void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* head_sink, // num_heads int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, @@ -480,7 +488,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, seqlen_q_rounded, seqlen_k_rounded, num_heads, num_heads_k, head_size, head_size_rounded, - q, kcache, vcache, out, + q, kcache, vcache, head_sink, out, /*cu_seqlens_q_d=*/nullptr, /*cu_seqlens_k_d=*/nullptr, /*seqused_k=*/nullptr, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index e28e38ea3ed93..e29dd7c1c231d 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -98,6 +98,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* seqlens_k_, // batch_size void* rotary_cos, // seqlen_ro x (rotary_dim / 2) void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* head_sink, // num_heads int* block_table, // batch_size x max_num_blocks_per_seq int batch_size, int num_heads, diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h index 4110e715c4391..91104b8c3dfe0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -369,8 +369,10 @@ inline __device__ 
void compute_attn_1rowblock(const Params& params, const int bi } // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, params.smooth_softmax); + float sink = (params.head_sink_ptr != nullptr) + ? reinterpret_cast(params.head_sink_ptr)[bidh] + : (params.smooth_softmax ? 0.0f : -kInfinity); + Tensor lse = softmax.template normalize_softmax_lse<>(acc_o, params.scale_softmax, sink); // Convert acc_o from fp32 to fp16/bf16 Tensor rO = flash::convert_type(acc_o); @@ -928,8 +930,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params& params, cons } // Epilogue - - Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, params.smooth_softmax); + float sink = (params.head_sink_ptr != nullptr) + ? reinterpret_cast(params.head_sink_ptr)[bidh] + : (params.smooth_softmax ? 0.0f : -std::numeric_limits::infinity()); + Tensor lse = softmax.template normalize_softmax_lse(acc_o, params.scale_softmax, sink); Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) // Partition sO to match the accumulator partitioning diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h index 7fe506e01a9b9..c7a8476f5beae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -18,6 +18,7 @@ namespace flash { using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// +constexpr float kInfinity = std::numeric_limits::infinity(); template __device__ __forceinline__ void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { @@ -72,9 +73,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso // If max is -inf, then all elements must have been -inf (possibly due to masking). // We don't want (-inf - (-inf)) since that would give NaN. // If we don't have float around M_LOG2E the multiplication is done in fp64. - const float max_scaled = max(mi) == -std::numeric_limits::infinity() - ? 0.f - : max(mi) * (Scale_max ? scale : float(M_LOG2E)); + const float max_scaled = max(mi) == -kInfinity ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); #pragma unroll for (int ni = 0; ni < size<1>(tensor); ++ni) { // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - @@ -85,38 +84,6 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor& tenso } } -// Apply the exp to all the elements. -template -__forceinline__ __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { - static_assert(Layout0::rank == 2, "Only support 2D Tensor"); - static_assert(Layout1::rank == 1, "Only support 1D Tensor"); - CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); -#pragma unroll - for (int mi = 0; mi < size<0>(tensor); ++mi) { - MaxOp max_op; - max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); -#pragma unroll - for (int ni = 1; ni < size<1>(tensor); ni++) { - max(mi) = max_op(max(mi), tensor(mi, ni)); - } - max(mi) = Allreduce<4>::run(max(mi), max_op); - // If max is -inf, then all elements must have been -inf (possibly due to masking). - // We don't want (-inf - (-inf)) since that would give NaN. - const float max_scaled = max(mi) == -std::numeric_limits::infinity() ? 
0.f : max(mi) * scale; - sum(mi) = 0; -#pragma unroll - for (int ni = 0; ni < size<1>(tensor); ++ni) { - // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - - // max * log_2(e)) This allows the compiler to use the ffma - // instruction instead of fadd and fmul separately. - tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); - sum(mi) += tensor(mi, ni); - } - SumOp sum_op; - sum(mi) = Allreduce<4>::run(sum(mi), sum_op); - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// template @@ -143,10 +110,10 @@ struct Softmax { Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); #pragma unroll - for (int mi = 0; mi < size(row_max); ++mi) { + for (int mi = 0; mi < size<0>(row_max); ++mi) { float scores_max_cur = !Check_inf ? row_max(mi) - : (row_max(mi) == -std::numeric_limits::infinity() ? 0.0f : row_max(mi)); + : (row_max(mi) == -kInfinity ? 0.0f : row_max(mi)); float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); row_sum(mi) *= scores_scale; #pragma unroll @@ -154,6 +121,7 @@ struct Softmax { acc_o_rowcol(mi, ni) *= scores_scale; } } + flash::scale_apply_exp2(scores, row_max, softmax_scale_log2); // We don't do the reduce across threads here since we don't need to use the row_sum. // We do that reduce at the end when we need to normalize the softmax. @@ -162,27 +130,62 @@ struct Softmax { }; template - __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, float softmax_scale, bool smooth_softmax) { + __forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0& acc_o, + float softmax_scale, + float sink) { // IMPORTANT: sink is a pre-scaled logit + SumOp sum_op; quad_allreduce_(row_sum, row_sum, sum_op); TensorT lse = make_fragment_like(row_sum); Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + + const bool use_sink = (sink != -kInfinity); + #pragma unroll for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) { - float sum = smooth_softmax ? row_sum(mi) + expf(-row_max(mi) * softmax_scale) : row_sum(mi); - float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + float sum = row_sum(mi); + float max_unscaled = row_max(mi); // Max of the qk scores, NOT scaled. + + if (use_sink) { + const float max_scaled = (max_unscaled == -kInfinity) + ? -kInfinity + : max_unscaled * softmax_scale; + + const float true_max_scaled = max(max_scaled, sink); + + // Rescale the intermediate the output accumulator (acc_o) and sum. + // They were calculated relative to `max_scaled` and must be + // rescaled to be relative to `true_max_scaled`. + const float rescale_factor = expf(max_scaled - true_max_scaled); + +#pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= rescale_factor; + } + + sum *= rescale_factor; + + // Add the sink to the sum. + sum += expf(sink - true_max_scaled); + + // The unscaled max that reflects the sink. It is used for the below LSE calculation. + max_unscaled = true_max_scaled / softmax_scale; + } + lse(mi) = (sum == 0.f || sum != sum) - ? (Split ? -std::numeric_limits::infinity() : std::numeric_limits::infinity()) - : row_max(mi) * softmax_scale + __logf(sum); - float scale = inv_sum; + ? (Split ? 
-kInfinity : kInfinity) + : max_unscaled * softmax_scale + __logf(sum); + + float inv_sum = (sum == 0.f || !isfinite(sum)) ? 1.f : 1.f / sum; #pragma unroll for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { - acc_o_rowcol(mi, ni) *= scale; + acc_o_rowcol(mi, ni) *= inv_sum; } } + return lse; - }; + } }; } // namespace flash diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index 9cb93cbcd3f32..e5d2434a31808 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -78,6 +78,14 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* total_seqlen = context->Input(6); const Tensor* cos_cache = context->Input(7); const Tensor* sin_cache = context->Input(8); + const Tensor* position_ids = context->Input(9); + const Tensor* attention_bias = context->Input(10); + const Tensor* head_sink = context->Input(11); + + if (position_ids != nullptr || attention_bias != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "position_ids and attention_bias are not supported in GroupQueryAttention cuda kernel."); + } auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -99,12 +107,17 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, softcap_, device_prop.maxThreadsPerBlock)); + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckCustomAttentionInputs(position_ids, + attention_bias, + head_sink, + parameters)); parameters.local_window_size = local_window_size_; parameters.is_unidirectional = is_unidirectional_; - parameters.use_smooth_softmax = use_smooth_softmax_; + parameters.use_smooth_softmax = use_smooth_softmax_ || head_sink != nullptr; parameters.zeros_count = kZerosCount; parameters.zero_ptr = zeros_.get(); - // parameters.left_padding = left_padding_; + int sequence_length = parameters.sequence_length; parameters.do_rotary = do_rotary_; parameters.rotary_interleaved = rotary_interleaved_; @@ -276,6 +289,10 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.sin_cache = reinterpret_cast(sin_cache->Data()); } + if (head_sink != nullptr) { + data.head_sink = reinterpret_cast(head_sink->Data()); + } + cublasHandle_t cublas = GetCublasHandle(context); return QkvToContext( diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index bb450e476d5ba..19d496569f79e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -460,11 +460,18 @@ Status FlashAttention( void* present_value = reinterpret_cast(const_cast(data.present_value)); void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + void* head_sink = reinterpret_cast(const_cast(data.head_sink)); bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("Q", reinterpret_cast(query), batch_size, sequence_length, num_heads, head_size); + DUMP_TENSOR("K", reinterpret_cast(present_key), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + DUMP_TENSOR("V", reinterpret_cast(present_value), batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( 
device_prop, stream, query, present_key, present_value, key, value, data.output, - reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, head_sink, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, @@ -475,7 +482,6 @@ Status FlashAttention( // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); // } - DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); @@ -680,6 +686,11 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); +template Status LaunchUnpackQKV( + const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, + const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, + cudaStream_t stream, const int max_threads_per_block); + template struct GroupQueryAttentionData; template Status QkvToContext( @@ -689,11 +700,6 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); - template Status LaunchUnpackQKV( const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.cc b/onnxruntime/contrib_ops/webgpu/bert/attention.cc index 55bcf42f2f04b..dbea308e0b08c 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.cc @@ -98,6 +98,7 @@ Status AttentionProbsProgram::GenerateShaderCode(ShaderHelper& shader) const { shader.AdditionalImplementation() << "var tileQ: array;\n" << "var tileK: array;\n" << "alias f32_val_t = " << (components_ == 4 ? "vec4" : (components_ == 2 ? "vec2" : "f32")) << ";\n"; + shader.MainFunctionBody() << "// x holds the N and y holds the M\n" << "let m = u32(workgroup_idx / uniforms.num_total_seq_length_tile) % uniforms.num_seq_length_tile * TILE_SIZE;\n" << "let n = (workgroup_idx % uniforms.num_total_seq_length_tile) * TILE_SIZE;\n" @@ -224,6 +225,8 @@ Status ComputeAttentionProbs(onnxruntime::webgpu::ComputeContext& context, int o } Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { + bool has_sliding_window = local_window_size_ != -1; + if (has_seqlen_k_) { shader.AddInput("seqlen_k", ShaderUsage::UseUniform); } @@ -241,15 +244,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { std::ostringstream oss; InitVarStub(oss, has_seqlen_k_); shader.MainFunctionBody() << oss.str() - << "let local_offset = local_idx * uniforms.elements_per_thread;\n" - << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n" << "let seq_causal_length = " << (has_seqlen_k_ ? 
"past_sequence_length + workgroup_idx % sequence_length + 1" : "uniforms.total_sequence_length_comp") << ";\n" - << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " thread_max_vector = max(f32_val_t(x[offset + i]), thread_max_vector);\n" - << "}\n" - << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" - << "workgroupBarrier();\n"; + << "let local_offset = local_idx * uniforms.elements_per_thread;\n" + << "let offset = workgroup_idx * uniforms.total_sequence_length_comp + local_offset;\n"; + if (has_sliding_window) { + // Sliding window + shader.MainFunctionBody() + << "let should_apply_local_window = uniforms.local_window_size >= 0 && seq_causal_length > uniforms.local_window_size + 1;\n" + << "let start_offset = select(0, seq_causal_length - uniforms.local_window_size, should_apply_local_window);\n" + << "let effective_seq_length = select(seq_causal_length, uniforms.local_window_size, should_apply_local_window);\n"; + } else { + // No sliding window: we keep the code for sliding window in the shader but + // using const for start_offset and should_apply_local_window will make the compiler optimize it out. + shader.MainFunctionBody() + << "const start_offset = 0;\n" + << "const should_apply_local_window = false;\n" + << "let effective_seq_length = seq_causal_length;\n"; + } + shader.MainFunctionBody() + << "var thread_max_vector = f32_val_t(-3.402823e+38f);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" + << " thread_max_vector = max(f32_val_t(x[offset + i + start_offset]), thread_max_vector);\n" + << " }\n" + << "}\n" + << "thread_max[local_idx] = " << (components_ == 4 ? "max(max(thread_max_vector.x, thread_max_vector.y), max(thread_max_vector.z, thread_max_vector.w))" : (components_ == 2 ? "max(thread_max_vector.x, thread_max_vector.y)" : "thread_max_vector")) << ";\n" + << "workgroupBarrier();\n"; if (has_head_sink_) { // Handle head sink @@ -265,8 +286,11 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { << " max_value = max(thread_max[i], max_value);\n" << "}\n" << "var sum_vector = f32_val_t(0);\n" - << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " sum_vector += exp(f32_val_t(x[offset + i]) - max_value);\n" + << "for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" + << " sum_vector += exp(f32_val_t(x[offset + i + start_offset]) - max_value);\n" + << " }\n" << "}\n" << "thread_sum[local_idx] = " << (components_ == 4 ? "sum_vector.x + sum_vector.y + sum_vector.z + sum_vector.w" : (components_ == 2 ? 
"sum_vector.x + sum_vector.y" : "sum_vector")) << ";\n" << "workgroupBarrier();\n" @@ -282,15 +306,33 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } shader.MainFunctionBody() << "if (sum == 0) {\n" - << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " x[offset + i] = x_value_t(x_element_t(1.0)/x_element_t(seq_causal_length));\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " if (actual_pos < seq_causal_length) {\n" + << " x[offset + i + start_offset] = x_value_t(x_element_t(1.0)/x_element_t(effective_seq_length));\n" + << " }\n" << " }\n" << "} else {\n" + << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < effective_seq_length; i++) {\n" + << " let actual_pos = local_offset + i + start_offset;\n" + << " let pos = offset + i + start_offset;\n" + << " if (!should_apply_local_window || actual_pos < seq_causal_length) {\n" + << " var f32input = f32_val_t(x[pos]);\n" + << " x[pos] = x_value_t(exp(f32input - max_value) / sum);\n" + << " }\n" + << " }\n" + << "}\n"; + + // zero out elements outsize the sliding window + shader.MainFunctionBody() << "if (should_apply_local_window) {\n" << " for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < seq_causal_length; i++) {\n" - << " var f32input = f32_val_t(x[offset + i]);\n" - << " x[offset + i] = x_value_t(exp(f32input - max_value) / sum);\n" + << " let global_pos = i + local_offset;\n" + << " if (global_pos < start_offset) {\n" + << " x[offset + i] = x_value_t(x_element_t(0));\n" + << " }\n" << " }\n" << "}\n"; + if (has_seqlen_k_) { shader.MainFunctionBody() << "for (var total_seq_id: u32 = seq_causal_length; total_seq_id + local_offset < uniforms.total_sequence_length_comp; total_seq_id++) {\n" << " x[offset + total_seq_id] = x_value_t(x_element_t(0));\n" @@ -301,7 +343,7 @@ Status InPlaceSoftmaxProgram::GenerateShaderCode(ShaderHelper& shader) const { } Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tensor* probs, int32_t batch_size, int32_t num_heads, int32_t past_sequence_length, int32_t sequence_length, int32_t total_sequence_length, - const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink) { + const Tensor* seqlen_k, bool is_first_prompt, bool use_smooth_softmax, const Tensor* head_sink, int local_window_size) { const int components = seqlen_k != nullptr ? 1 : (total_sequence_length % 4 == 0 ? 4 : (total_sequence_length % 2 == 0 ? 
2 : 1)); int work_group_size = 64; const int total_sequence_length_comp = (total_sequence_length + components - 1) / components; @@ -310,7 +352,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso } const int elementsPerThread = (total_sequence_length_comp + work_group_size - 1) / work_group_size; - InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr}; + InPlaceSoftmaxProgram program{work_group_size, components, use_smooth_softmax, seqlen_k != nullptr, head_sink != nullptr, local_window_size}; if (seqlen_k != nullptr) { program.AddInput({seqlen_k, ProgramTensorMetadataDependency::TypeAndRank}); } @@ -318,7 +360,7 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso program.AddInput({head_sink, ProgramTensorMetadataDependency::Type}); } program.AddOutputs({{probs, ProgramTensorMetadataDependency::TypeAndRank, components}}) - .CacheHint(work_group_size, use_smooth_softmax) + .CacheHint(work_group_size, use_smooth_softmax, local_window_size != -1) .SetDispatchGroupSize(batch_size * num_heads * sequence_length) .SetWorkgroupSize(work_group_size) .AddUniformVariables({{static_cast(batch_size)}, @@ -327,7 +369,8 @@ Status ComputeInPlaceSoftmax(onnxruntime::webgpu::ComputeContext& context, Tenso {static_cast(sequence_length)}, {static_cast(total_sequence_length_comp)}, {static_cast(elementsPerThread)}, - {static_cast(is_first_prompt ? 1 : 0)}}); + {static_cast(is_first_prompt ? 1 : 0)}, + {static_cast(local_window_size)}}); return context.RunProgram(program); } @@ -467,7 +510,8 @@ Status ComputeVxAttentionScore(onnxruntime::webgpu::ComputeContext& context, int Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, - WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, const Tensor* seqlen_k) { + WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, const Tensor* head_sink, + const Tensor* seqlen_k, int local_window_size) { const int output_count = std::min({context.OutputCount(), 1 + (past_key != nullptr ? 1 : 0) + (past_value != nullptr ? 1 : 0)}); const int past_sequence_length = output_count > 1 ? 
parameters.past_sequence_length_ : 0; const int total_sequence_length = @@ -481,7 +525,7 @@ Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const T parameters, past_sequence_length, total_sequence_length, seqlen_k)); ORT_RETURN_IF_ERROR(ComputeInPlaceSoftmax(context, &probs, - parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink)); + parameters.batch_size_, parameters.num_heads_, parameters.past_sequence_length_, parameters.sequence_length_, total_sequence_length, seqlen_k, parameters.is_first_prompt_, parameters.use_smooth_softmax_, head_sink, local_window_size)); ORT_RETURN_IF_ERROR(ComputeVxAttentionScore(context, output_count, &probs, V, past_value, output, present_value, parameters, past_sequence_length, total_sequence_length, seqlen_k)); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention.h b/onnxruntime/contrib_ops/webgpu/bert/attention.h index e64ca3539c23d..3450705b04908 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention.h @@ -69,8 +69,8 @@ class AttentionProbsProgram final : public Program { class InPlaceSoftmaxProgram final : public Program { public: - InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink) - : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink) { + InPlaceSoftmaxProgram(int work_group_size, int components, bool use_smooth_softmax, bool has_seqlen_k, bool has_head_sink, int local_window_size) + : Program{"InPlaceSoftmax"}, work_group_size_(work_group_size), components_(components), use_smooth_softmax_(use_smooth_softmax), has_seqlen_k_(has_seqlen_k), has_head_sink_(has_head_sink), local_window_size_(local_window_size) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -81,7 +81,8 @@ class InPlaceSoftmaxProgram final : public Program { {"sequence_length", ProgramUniformVariableDataType::Uint32}, {"total_sequence_length_comp", ProgramUniformVariableDataType::Uint32}, {"elements_per_thread", ProgramUniformVariableDataType::Uint32}, - {"is_first_prompt", ProgramUniformVariableDataType::Uint32}); + {"is_first_prompt", ProgramUniformVariableDataType::Uint32}, + {"local_window_size", ProgramUniformVariableDataType::Uint32}); private: int work_group_size_; @@ -89,6 +90,7 @@ class InPlaceSoftmaxProgram final : public Program { bool use_smooth_softmax_; bool has_seqlen_k_; bool has_head_sink_; + int local_window_size_; }; class VxAttentionScoreProgram final : public Program { diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 9d4740ede7143..71161c120a306 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -124,7 +124,7 @@ Status TransferBSDToBNSH(onnxruntime::webgpu::ComputeContext& context, int num_h Status ApplyAttention(const Tensor* Q, const Tensor* K, const Tensor* V, const Tensor* attention_bias, const Tensor* past_key, const Tensor* past_value, Tensor* output, Tensor* present_key, Tensor* present_value, WebgpuAttentionParameters& parameters, onnxruntime::webgpu::ComputeContext& context, - const Tensor* head_sink = nullptr, const Tensor* 
seqlen_k = nullptr); + const Tensor* head_sink = nullptr, const Tensor* seqlen_k = nullptr, int local_window_size = -1); } // namespace webgpu } // namespace contrib diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 1f039177b0a21..40d46cc3fba59 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -200,6 +200,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& if (!do_rotary_ && head_sink == nullptr && !use_smooth_softmax_ && + local_window_size_ == -1 && CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, present_value, parameters, context); @@ -241,7 +242,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& context, parameters.num_heads_, parameters.sequence_length_, parameters.head_size_, query, nullptr, 0, &Q)); if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { // key and value in BNSH format return ApplyAttention(&Q, key, value, attention_bias, past_key, past_value, output, present_key, - present_value, parameters, context, head_sink, seqlen_k); + present_value, parameters, context, head_sink, seqlen_k, local_window_size_); } TensorShapeVector k_new_dims({parameters.batch_size_, parameters.kv_num_heads_, @@ -258,7 +259,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& ORT_RETURN_IF_ERROR(TransferBSDToBNSH(context, parameters.kv_num_heads_, parameters.kv_sequence_length_, parameters.v_head_size_, value, nullptr, 0, &V)); return ApplyAttention(&Q, &K, &V, attention_bias, past_key, past_value, output, present_key, - present_value, parameters, context, head_sink, seqlen_k); + present_value, parameters, context, head_sink, seqlen_k, local_window_size_); } } // namespace webgpu diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc new file mode 100755 index 0000000000000..d2e0beeba00db --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.cc @@ -0,0 +1,246 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
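Note: the sliding-window handling added to InPlaceSoftmaxProgram above is easier to follow against a scalar reference. The sketch below is not part of the change; it is a minimal CPU-side reading of what the shader computes per attention row (the function name WindowedSoftmaxRow is illustrative), assuming local_window_size < 0 means "no window", only the last local_window_size causal positions contribute, and everything before the window is zeroed. The shader's sum == 0 fallback, vectorized components handling, and seqlen_k tail-zeroing are omitted; error checking is elided.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Scalar reference for the windowed in-place softmax over one row of attention scores.
// scores.size() is assumed to be >= seq_causal_length.
std::vector<float> WindowedSoftmaxRow(const std::vector<float>& scores,
                                      size_t seq_causal_length,
                                      int local_window_size) {
  std::vector<float> out(scores.size(), 0.0f);  // positions outside the window stay zero
  const bool apply_window =
      local_window_size >= 0 && seq_causal_length > static_cast<size_t>(local_window_size) + 1;
  const size_t start = apply_window ? seq_causal_length - static_cast<size_t>(local_window_size) : 0;

  float max_v = -std::numeric_limits<float>::infinity();
  for (size_t i = start; i < seq_causal_length; ++i) max_v = std::max(max_v, scores[i]);

  float sum = 0.0f;
  for (size_t i = start; i < seq_causal_length; ++i) sum += std::exp(scores[i] - max_v);

  for (size_t i = start; i < seq_causal_length; ++i) out[i] = std::exp(scores[i] - max_v) / sum;
  return out;
}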
+ +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_utils.h" +#include "core/providers/webgpu/webgpu_supported_types.h" +#include "contrib_ops/webgpu/webgpu_contrib_kernels.h" +#include "contrib_ops/webgpu/quantization/gather_block_quantized.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +Status GatherBlockQuantizedProgram::GenerateShaderCode(ShaderHelper& shader) const { + const auto& x = shader.AddInput("input", ShaderUsage::UseElementTypeAlias); + const auto& x_shape = shader.AddIndices("input_shape", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias); + const auto& indices = shader.AddInput("indices", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseIndicesToOffset); + const auto& scales = shader.AddInput("scales", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias); + const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseShapeAndStride | ShaderUsage::UseValueTypeAlias); + + bool is_4bit = bits_ == 4; + const std::string unpack = (is_signed_) ? "unpack4xI8" : "unpack4xU8"; + + shader.MainFunctionBody() + << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"; + + if (indices_rank_ > 1) { + shader.MainFunctionBody() + << "var indices_indices = indices_indices_t(0);\n" + << "for (var i: u32 = 0; i < " << indices_rank_ << "; i++) {\n" + << " let index = " << output.IndicesGet("output_indices", "uniforms.gather_axis + i") << ";\n" + << " " << indices.IndicesSet("indices_indices", "i", "index") << ";\n};\n"; + } else { + shader.MainFunctionBody() + << "let indices_indices = " << output.IndicesGet("output_indices", "uniforms.gather_axis") << ";\n"; + } + shader.MainFunctionBody() + << "var data_indices = input_shape_indices_t(0);\n" + << "for (var i: u32 = 0; i < uniforms.gather_axis; i++) {\n" + << " let index = " << output.IndicesGet("output_indices", "i") << ";\n " + << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n" + << "var index_from_indices = " << indices.GetByIndices("indices_indices") << ";\n" + << "if (index_from_indices < 0) { index_from_indices += " << x_shape_[gather_axis_] << ";}\n" + << x_shape.IndicesSet("data_indices", "uniforms.gather_axis", "u32(index_from_indices)") << ";\n" + << "for (var i = uniforms.gather_axis + 1; i < " << output_shape_.NumDimensions() << "; i++) {\n" + << " let index = " << output.IndicesGet("output_indices", "i + " + std::to_string(indices_rank_ - 1)) << ";\n " + << x_shape.IndicesSet("data_indices", "i", "index") << ";\n};\n" + << " let data_offset = " << x_shape.IndicesToOffset("data_indices") << ";\n"; + + if (is_4bit) { + shader.MainFunctionBody() + << " let data_index = data_offset % 8;\n" + << " let packed_4bit_quantized_data = " << x.GetByOffset("data_offset / 8") << ";\n" + << " let packed_8bit_quantized_data = (packed_4bit_quantized_data >> (4 * (data_index % 2))) & 0x0f0f0f0f;\n" + << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n" + << " var quantized_data = quantized_data_vec[data_index / 2];\n"; + if (is_signed_) { + shader.MainFunctionBody() + << " if((quantized_data & 0x8) != 0) { quantized_data = quantized_data - 16 ;};\n"; + } + } else { + shader.MainFunctionBody() + << " let data_index = data_offset % 4;\n" + << " let 
packed_8bit_quantized_data = " << x.GetByOffset("data_offset / 4") << ";\n" + << " let quantized_data_vec = " << unpack << "(u32(packed_8bit_quantized_data));\n" + << " var quantized_data = quantized_data_vec[data_index];\n"; + } + + shader.MainFunctionBody() + << " var scale_indices = data_indices;\n" + << " let quantize_axis_index = " << scales.IndicesGet("data_indices", "uniforms.quantize_axis") << "/ uniforms.block_size;\n " + << scales.IndicesSet("scale_indices", "uniforms.quantize_axis", "quantize_axis_index") << ";\n" + << " var scale = " << scales.GetByIndices("scale_indices") << ";\n"; + + if (!has_zeropoint_) { + const std::string default_zero_point = is_uint8_ ? is_4bit ? "input_element_t(8)" : "input_element_t(128)" : "input_element_t(0)"; + shader.MainFunctionBody() + << " let zero_point = " << default_zero_point << ";\n"; + } else { + const auto& zero_point = shader.AddInput("zero_point", ShaderUsage::None); + shader.MainFunctionBody() + << " let zero_point_indices = scale_indices;\n" + << " let zero_point_offset = " << scales.IndicesToOffset("zero_point_indices") << ";\n"; + if (is_4bit) { + shader.MainFunctionBody() + << " let zero_point_index = zero_point_offset % 8;\n" + << " let packed_4bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 8") << ";\n" + << " let packed_8bit_zero_points = (packed_4bit_zero_points >> (4 * (zero_point_index % 2))) & 0x0f0f0f0f;\n" + << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n" + << " var zero_point = zero_point_vec[zero_point_index / 2];\n"; + } else { + shader.MainFunctionBody() + << " let zero_point_index = zero_point_offset % 4;\n" + << " let packed_8bit_zero_points = " << zero_point.GetByOffset("zero_point_offset / 4") << ";\n" + << " let zero_point_vec = " << unpack << "(u32(packed_8bit_zero_points));\n" + << " var zero_point = zero_point_vec[zero_point_index];\n"; + } + if (is_signed_) { + shader.MainFunctionBody() + << " if((zero_point & 0x8) != 0) { zero_point = zero_point - 16 ;};\n"; + } + } + shader.MainFunctionBody() + << " let dequantized_data = (output_value_t(quantized_data) - output_value_t(zero_point)) * scale;\n " + << output.SetByOffset("global_idx", "dequantized_data") << ";\n"; + + return Status::OK(); +} + +TensorShapeVector splice(TensorShapeVector vec, size_t start, size_t deleteCount, const TensorShapeVector toInsert = {}) { + TensorShapeVector new_vec; + + for (size_t i = 0; i < vec.size(); i++) { + if (i < start) { + new_vec.push_back(vec[i]); + } else if (i == start) { + new_vec.insert(new_vec.end(), toInsert.begin(), toInsert.end()); + } else if (i >= start + deleteCount) { + new_vec.push_back(vec[i]); + } + } + return new_vec; +} + +Status GatherBlockQuantized::ComputeInternal(ComputeContext& context) const { + const auto* x = context.Input(0); + const auto* indices = context.Input(1); + const auto* scales = context.Input(2); + const auto* zero_points = context.Input(3); + + int x_rank = static_cast(x->Shape().NumDimensions()); + int64_t x_dtype = x->GetElementType(); + bool is_signed = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4; + bool is_int8 = x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 || x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8; + + std::optional data_representation_4bit; + std::optional zero_points_representation_4bit; + if (bits_ == 4 && is_int8) { + TensorShape data_representation_4bit_shape{x->Shape()}; + MLDataType new_dtype = (x_dtype == ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) ? 
DataTypeImpl::GetType() : DataTypeImpl::GetType(); + auto memory_info = OrtMemoryInfo{ + "WebGPU_Buffer", + OrtDeviceAllocator, + OrtDevice{OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0}}; + + data_representation_4bit_shape[x_rank - 1] = data_representation_4bit_shape[x_rank - 1] * 2; + data_representation_4bit.emplace( + new_dtype, + data_representation_4bit_shape, + const_cast(x->DataRaw()), + memory_info); + + if (zero_points) { + TensorShape zero_points_representation_4bit_shape{zero_points->Shape()}; + zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] = + zero_points_representation_4bit_shape[zero_points->Shape().NumDimensions() - 1] * 2; + zero_points_representation_4bit.emplace( + new_dtype, + zero_points_representation_4bit_shape, + const_cast(zero_points->DataRaw()), + memory_info); + } + x = data_representation_4bit.has_value() ? &data_representation_4bit.value() : x; + zero_points = zero_points_representation_4bit.has_value() ? &zero_points_representation_4bit.value() : zero_points; + } + + const auto& x_shape = x->Shape(); + + size_t indices_rank = indices->Shape().NumDimensions(); + const auto scales_shape = scales->Shape(); + size_t scales_rank = scales_shape.NumDimensions(); + int gather_axis = (gather_axis_ >= 0) ? gather_axis_ : gather_axis_ + x_rank; + int quantize_axis = (quantize_axis_ >= 0) ? quantize_axis_ : quantize_axis_ + x_rank; + + ORT_RETURN_IF_NOT(x_shape.NumDimensions() == scales_rank, + "data and scales must have the same rank."); + for (size_t i = 0; i < x_shape.NumDimensions(); ++i) { + ORT_RETURN_IF_NOT(i == static_cast(quantize_axis) + ? (x_shape[i] * 1 + block_size_ - 1) / block_size_ == scales_shape[i] + : x_shape[i] == scales_shape[i], + "data and scales do not match shapes."); + } + + TensorShape output_shape = splice(x_shape.AsShapeVector(), gather_axis, 1, indices->Shape().AsShapeVector()); + int64_t output_size = output_shape.Size(); + auto* output_tensor = context.Output(0, output_shape); + + GatherBlockQuantizedProgram program{is_signed, is_int8, indices_rank, gather_axis, bits_, zero_points != nullptr, x_shape, output_shape}; + + program + .AddInputs({{x, ProgramTensorMetadataDependency::Type, ProgramInput::Flatten, (bits_ == 4) ? 8 : 4}}) + .AddIndices(x_shape) + .AddInputs({{indices, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddInputs({{scales, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutput({output_tensor, ProgramTensorMetadataDependency::None}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}}) + .AddUniformVariables({{static_cast(quantize_axis)}}) + .AddUniformVariables({{static_cast(gather_axis)}}) + .AddUniformVariables({{static_cast(block_size_)}}) + .CacheHint(std::to_string(gather_axis), std::to_string(quantize_axis), std::to_string(block_size_)); + + if (zero_points != nullptr) { + ORT_RETURN_IF_NOT(scales_shape == zero_points->Shape(), + "scales and zero_points must have the same shape."); + auto zero_points_shape = zero_points->Shape(); + program.AddInputs({{zero_points, ProgramTensorMetadataDependency::None, ProgramInput::Flatten, (bits_ == 4) ? 
8 : 4}}); + } + + return context.RunProgram(program); +} + +namespace { +const std::vector& GatherBlockQuantizedT1Constraint() { + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +const std::vector& GatherBlockQuantizedTindConstraint() { + static std::vector types{ + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}; + return types; +} +} // namespace + +ONNX_OPERATOR_KERNEL_EX( + GatherBlockQuantized, + kMSDomain, + 1, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", GatherBlockQuantizedT1Constraint()) + .TypeConstraint("T2", WebGpuSupportedFloatTypes()) + .TypeConstraint("Tind", GatherBlockQuantizedTindConstraint()), + GatherBlockQuantized); + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h new file mode 100755 index 0000000000000..cd7392995f4cf --- /dev/null +++ b/onnxruntime/contrib_ops/webgpu/quantization/gather_block_quantized.h @@ -0,0 +1,71 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/program.h" +#include "core/providers/webgpu/webgpu_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace webgpu { + +using namespace onnxruntime::webgpu; +using onnxruntime::webgpu::ComputeContext; + +class GatherBlockQuantizedProgram final : public Program { + public: + GatherBlockQuantizedProgram(const bool is_signed, const bool is_uint8, size_t indices_rank, int gather_axis, int bits, bool has_zeropoint, + TensorShape x_shape, TensorShape output_shape) : Program{"GatherBlockQuantized"}, + is_signed_{is_signed}, + is_uint8_{is_uint8}, + indices_rank_{indices_rank}, + gather_axis_{gather_axis}, + bits_{bits}, + has_zeropoint_{has_zeropoint}, + x_shape_{x_shape}, + output_shape_{output_shape} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"quantize_axis", ProgramUniformVariableDataType::Uint32}, + {"gather_axis", ProgramUniformVariableDataType::Uint32}, + {"block_size", ProgramUniformVariableDataType::Uint32}); + + private: + bool is_signed_; + bool is_uint8_; + size_t indices_rank_; + int gather_axis_; + int bits_; + bool has_zeropoint_; + TensorShape x_shape_; + TensorShape output_shape_; +}; + +class GatherBlockQuantized final : public WebGpuKernel { + public: + GatherBlockQuantized(const OpKernelInfo& info) : WebGpuKernel(info) { + gather_axis_ = static_cast(info.GetAttrOrDefault("gather_axis", 0)); + block_size_ = static_cast(info.GetAttrOrDefault("block_size", 128)); + quantize_axis_ = static_cast(info.GetAttrOrDefault("quantize_axis", 1)); + bits_ = static_cast(info.GetAttrOrDefault("bits", 4)); + + ORT_ENFORCE(bits_ == 4 || bits_ == 8, "'bits' must be 4 or 8."); + ORT_ENFORCE(block_size_ >= 16 && ((block_size_ - 1) & block_size_) == 0, + "'block_size' must be 2's power and not less than 16."); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int gather_axis_; + int quantize_axis_; + int block_size_; + int bits_; +}; + +} // namespace webgpu +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc 
b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index 4136477a1d88c..25cc13b3ea1df 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -14,6 +14,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Bi class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GatherBlockQuantized); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention); // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it @@ -40,6 +41,7 @@ Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 9324fa76ded4f..9b734559776e3 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -18,13 +18,14 @@ #include "core/framework/bfc_arena.h" -using Status = onnxruntime::common::Status; +using namespace onnxruntime; +using Status = common::Status; Status OrtArenaCfg::FromKeyValuePairs(const OrtKeyValuePairs& kvps, OrtArenaCfg& cfg) { cfg = OrtArenaCfg{}; // reset to default values const auto from_string = [](const std::string& key, const std::string& str, auto& value) -> Status { - if (!onnxruntime::ParseStringWithClassicLocale(str, value).IsOK()) { + if (!ParseStringWithClassicLocale(str, value).IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse value for ", key, " from ", str); } @@ -250,7 +251,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA } ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType type, _Outptr_ OrtMemoryInfo** out) { // map the public enum values to internal OrtDevice values @@ -275,7 +276,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo_V2, _In_ const char* name, _In_ en return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid device type specified."); } - *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, device_id, alignment}, + *out = new OrtMemoryInfo(name, type, OrtDevice{dt, mt, vendor_id, narrow(device_id), alignment}, mem_type == OrtDeviceMemoryType_DEFAULT ? 
OrtMemTypeDefault : OrtMemTypeCPU); return nullptr; } @@ -313,3 +314,13 @@ ORT_API_STATUS_IMPL(OrtApis::CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, ORT_API(void, OrtApis::MemoryInfoGetDeviceType, _In_ const OrtMemoryInfo* info, _Out_ OrtMemoryInfoDeviceType* out) { *out = static_cast(info->device.Type()); } + +ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out) { + *out = static_cast(ptr->device.MemType()); + return nullptr; +} + +ORT_API_STATUS_IMPL(OrtApis::MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out) { + *out = ptr->device.Vendor(); + return nullptr; +} diff --git a/onnxruntime/core/framework/plugin_ep_stream.cc b/onnxruntime/core/framework/plugin_ep_stream.cc new file mode 100644 index 0000000000000..1eb6ad4162f33 --- /dev/null +++ b/onnxruntime/core/framework/plugin_ep_stream.cc @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/plugin_ep_stream.h" +#include "core/framework/error_code_helper.h" + +namespace onnxruntime { +namespace plugin_ep { + +// TODO: Is num_consumers meaningful? Unused everywhere currently. +OrtStatus* Stream::CreateNotificationImpl(size_t /*num_consumers*/, std::unique_ptr& result) { + OrtSyncNotificationImpl* notification_impl = nullptr; + ORT_API_RETURN_IF_ERROR(impl_.CreateNotification(&impl_, ¬ification_impl)); + + result = std::make_unique(*this, *notification_impl, logger_); + return nullptr; +} +} // namespace plugin_ep +} // namespace onnxruntime diff --git a/onnxruntime/core/framework/plugin_ep_stream.h b/onnxruntime/core/framework/plugin_ep_stream.h new file mode 100644 index 0000000000000..2b89e76e16b76 --- /dev/null +++ b/onnxruntime/core/framework/plugin_ep_stream.h @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
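Aside: the two OrtMemoryInfo accessors added in allocator.cc above (MemoryInfoGetDeviceMemType and MemoryInfoGetVendorId) are plain getters that always return success. A hypothetical caller-side sketch follows, assuming they are reachable through the OrtApi function table under the same names; error handling and status release are elided since both implementations above return nullptr.

#include <cstdint>
#include "onnxruntime_c_api.h"

// Illustrative only: query where an OrtMemoryInfo's memory lives and which vendor owns the device.
void InspectMemoryInfo(const OrtApi* api, const OrtMemoryInfo* info) {
  OrtDeviceMemoryType mem_type = OrtDeviceMemoryType_DEFAULT;
  uint32_t vendor_id = 0;
  api->MemoryInfoGetDeviceMemType(info, &mem_type);  // DEFAULT vs HOST_ACCESSIBLE
  api->MemoryInfoGetVendorId(info, &vendor_id);      // e.g. 0x10de for NVIDIA, as matched later in this diff
}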
+#pragma once + +#include "core/common/logging/logging.h" +#include "core/framework/stream_handles.h" +#include "core/framework/error_code_helper.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +// OrtSyncStream is an alias in the C API for onnxruntime::Stream +// OrtSyncNotification is an alias in the C API for onnxruntime::synchronize::Notification +struct OrtSyncStream : public onnxruntime::Stream {}; +struct OrtSyncNotification : onnxruntime::synchronize::Notification {}; + +using onnxruntime::logging::Logger; + +#define LOG_AND_RETURN_IF_ORT_ERROR(fn, logger) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + LOGS(logger, ERROR) << "Plug-in EP Error: [" << OrtApis::GetErrorCode(_status) << "] " \ + << OrtApis::GetErrorMessage(_status); \ + OrtApis::ReleaseStatus(_status); \ + return; \ + } \ + } while (0) + +namespace onnxruntime { +namespace plugin_ep { + +class Notification : public synchronize::Notification { + public: + Notification(Stream& stream, OrtSyncNotificationImpl& impl, const Logger& logger) + : synchronize::Notification(stream), impl_{impl}, logger_{logger} { + } + + static void WaitNotificationOnDevice(onnxruntime::Stream* stream, synchronize::Notification& notification) { + auto* this_ptr = static_cast(¬ification); + + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnDevice(&this_ptr->impl_, static_cast(stream)), + this_ptr->logger_); + } + + static void WaitNotificationOnHost(onnxruntime::Stream* /*stream*/, synchronize::Notification& notification) { + auto* this_ptr = static_cast(¬ification); + LOG_AND_RETURN_IF_ORT_ERROR(this_ptr->impl_.WaitOnHost(&this_ptr->impl_), this_ptr->logger_); + } + + void Activate() override { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Activate(&impl_), logger_); + } + + ~Notification() override { + impl_.Release(&impl_); + } + + private: + OrtSyncNotificationImpl& impl_; + const Logger& logger_; +}; + +class Stream : public onnxruntime::Stream { + public: + Stream(const OrtDevice& memory_device, OrtSyncStreamImpl& impl, const logging::Logger& logger) + : onnxruntime::Stream(impl.GetHandle(&impl), memory_device), impl_{impl}, logger_{logger} { + } + + std::unique_ptr CreateNotification(size_t num_consumers) override { + std::unique_ptr plugin_notification; + + auto* ort_status = CreateNotificationImpl(num_consumers, plugin_notification); + if (ort_status != nullptr) { + ORT_THROW("Failed to create Notification: [", OrtApis::GetErrorCode(ort_status), "] ", + OrtApis::GetErrorMessage(ort_status)); + } + + return plugin_notification; + } + + void Flush() override { + LOG_AND_RETURN_IF_ORT_ERROR(impl_.Flush(&impl_), logger_); + } + + Status CleanUpOnRunEnd() override { + auto* ort_status = impl_.OnSessionRunEnd(&impl_); + return ToStatusAndRelease(ort_status); + } + + WaitNotificationFn GetWaitNotificationFn() const override { + return Notification::WaitNotificationOnDevice; + } + + ~Stream() override { + impl_.Release(&impl_); + } + + private: + OrtSyncStream* ToApiStream() { + return static_cast(static_cast(this)); + } + + OrtStatus* CreateNotificationImpl(size_t num_consumers, std::unique_ptr& result); + + OrtSyncStreamImpl& impl_; + const Logger& logger_; +}; +} // namespace plugin_ep +} // namespace onnxruntime + +#undef LOG_AND_RETURN_IF_ORT_ERROR diff --git a/onnxruntime/core/graph/graph_proto_serializer.cc b/onnxruntime/core/graph/graph_proto_serializer.cc index 993020278eb03..0fbcea2719ce8 100644 --- a/onnxruntime/core/graph/graph_proto_serializer.cc +++ 
b/onnxruntime/core/graph/graph_proto_serializer.cc @@ -94,8 +94,8 @@ void GraphViewerToProto(const GraphViewer& graph_view, current_scope_initializer_set.insert(name); auto* p_initializer = graph_proto.add_initializer(); - // Do not save raw or external data into the graph, only the metadata - if (!include_initializer_data && (init->has_raw_data() || init->has_data_location())) { + // Do not save raw data into the graph, only the metadata + if (!include_initializer_data && init->has_raw_data()) { // Set datatype if (init->has_data_type()) { p_initializer->set_data_type(init->data_type()); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 4edf804e48aaa..dbf86e2bb7fc7 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -220,6 +220,12 @@ InlinedVector> GenerateTransformers( AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); switch (level) { + case TransformerLevel::Default: { + if (!session_options.free_dimension_overrides.empty()) { + transformers.emplace_back(std::make_unique( + session_options.free_dimension_overrides)); + } + } break; case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) // so run them first so there is potentially less for the more intensive optimizations like ConstantFolding, diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index ac128011c0b9f..6f2538bcde3b1 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -1085,7 +1085,8 @@ static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, cons if (perm.has_value()) { auto perm_inv = InvertPerm(*perm); std::vector indices = {0}; - HandlerArgs args{ctx, *inp_node, unsqueeze, *perm, perm_inv, indices}; + std::unordered_set dummy_outputs_leading_to_transpose; + HandlerArgs args{ctx, *inp_node, unsqueeze, *perm, perm_inv, indices, dummy_outputs_leading_to_transpose}; const auto new_input = HelpHandleUnsqueeze(args, axes); // Use output from optimization (likely from pushed transpose) node.SetInput(i, new_input); @@ -2391,7 +2392,7 @@ static bool FinalizeReshapeShape(const std::vector& input_shape, / return true; } -static bool HandleReshape(HandlerArgs& args) { +bool HandleReshape(HandlerArgs& args) { // A Reshape can be logically equivalent to a Transpose if all dims with a value > 1 remain in the same order // and do not change size. If so, we can use HandleTransposeImpl to merge them. // e.g. 
Reshape(input {1, 512, 4, 1}, shape {1, 1, 512, 4}) is equivalent to Transpose with perms { 0, 3, 1, 2 } @@ -2700,7 +2701,7 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& } std::vector perm_inv = InvertPerm(perm); - HandlerArgs args = {ctx, transpose, node, perm, perm_inv, input_indices}; + HandlerArgs args = {ctx, transpose, node, perm, perm_inv, input_indices, outputs_leading_to_transpose}; return info->handler_fn(args); } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index f65bd6aa82fbb..9a057cc45a39a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -5,6 +5,7 @@ #include #include +#include #include // implementation details of the transpose optimizer API defined in optimizer_api.h. @@ -24,6 +25,7 @@ struct HandlerArgs { const std::vector& perm_inv; // inverse of perm. // Cached result from calling HandlerInfo.transposible_inputs_fn std::vector& transposible_inputs; + const std::unordered_set& outputs_leading_to_transpose; }; // Each op handler points to a (potentially shared) function for determining which input indices are eligible for @@ -76,6 +78,7 @@ bool HandleSoftHardMax(HandlerArgs& args); // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleReshape(HandlerArgs& args); bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index 824ab20a84668..f861fad96f301 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -4,6 +4,9 @@ #include "core/optimizer/transpose_optimization/ort_transpose_optimization.h" #include + +#include + #include "core/graph/constants.h" #include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" @@ -34,6 +37,93 @@ static bool EPAwareHandleResize(HandlerArgs& args) { constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; +static bool EPAwareHandleReshape(HandlerArgs& args) { + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kQnnExecutionProvider) { + // In some cases, the pattern of Transpose-Reshape-Transpose can be optimized to a single Reshape. + // For example, [N,H,W,C] - [N,C,H,W] - [N,C,HxW] - [N,HxW,C] is functionally equivalent to [N,H,W,C] - [N,HxW,C]. + // In this optimization, we attempt to handle those "channel" Transpose and "spatial" Reshape cases, like the + // example above. + + // Only attempts to push through if Transpose is possibly canceled with following one. + const std::string output_name = std::string(args.node.Outputs()[0]); + if (args.outputs_leading_to_transpose.find(output_name) == args.outputs_leading_to_transpose.end()) { + return HandleReshape(args); + } + + // Get input/output shapes. 
+ auto reshape_input_shape = args.ctx.graph.GetValueInfo(args.node.Inputs()[0])->Shape(); + auto reshape_output_shape = args.ctx.graph.GetValueInfo(args.node.Outputs()[0])->Shape(); + if (!reshape_input_shape.has_value() || !reshape_output_shape.has_value()) { + return HandleReshape(args); + } + + const std::vector& input_shape = *reshape_input_shape; + const std::vector& output_shape = *reshape_output_shape; + const size_t input_rank = input_shape.size(); + const size_t output_rank = output_shape.size(); + + std::vector output_perm; + + // Determine "channel" Transpose by checking perm being channel-first to channel-last or vice versa. + const std::vector perm_3d{0, 2, 1}; + const std::vector perm_4d_cl{0, 2, 3, 1}; + const std::vector perm_4d_cf{0, 3, 1, 2}; + + // Determine "spatial" Reshape by checking the batch and channel dimensions untouched. + const bool batch_preserved = (input_shape[0] == output_shape[0]); + const bool cf_preserved = (input_shape[1] == output_shape[1]); + const bool cl_preserved = (input_shape[input_rank - 1] == output_shape[output_rank - 1]); + + if (args.perm == perm_3d) { + // There is ambiguity to determine the direction solely from this Transpose perm. + // The implementation may result in non-fully optimized pattern since the perm info from the output Transpose is + // mandatory for determination. Leave it as future work as such info is inaccessible in current infra. + if (batch_preserved && cf_preserved) { + output_perm = ChannelFirstToLastPerm(output_rank); + } else if (batch_preserved && cl_preserved) { + output_perm = ChannelLastToFirstPerm(output_rank); + } else { + return HandleReshape(args); + } + } else if (args.perm == perm_4d_cl && batch_preserved && cl_preserved) { + output_perm = ChannelLastToFirstPerm(output_rank); + } else if (args.perm == perm_4d_cf && batch_preserved && cf_preserved) { + output_perm = ChannelFirstToLastPerm(output_rank); + } else { + return HandleReshape(args); + } + + TransposeFirstInput(args.ctx, args.node, args.perm_inv); + + std::vector new_shape; + new_shape.reserve(output_rank); + for (size_t axis = 0; axis < output_rank; ++axis) { + new_shape.push_back(output_shape[static_cast(output_perm[axis])]); + } + + const uint8_t* new_shape_data = reinterpret_cast(new_shape.data()); + const std::string_view new_shape_name = args.ctx.graph.AddInitializer( + api::DataType::INT64, + {gsl::narrow_cast(new_shape.size())}, + std::vector(new_shape_data, new_shape_data + new_shape.size() * sizeof(int64_t))); + + const std::string_view old_shape_name = args.node.Inputs()[1]; + args.node.SetInput(1, new_shape_name); + if (!args.ctx.graph.HasValueConsumers(old_shape_name)) { + args.ctx.graph.RemoveInitializer(old_shape_name); + } + + TransposeOutputs(args.ctx, args.node, InvertPerm(output_perm)); + return true; + } + + // Fallback to default handler. 
+ return HandleReshape(args); +} + +constexpr HandlerInfo ep_aware_reshape_handler = {&FirstInput, &EPAwareHandleReshape, /*transposes_outputs*/ false}; + std::vector QLinearConcatInputs(OptimizerCtx& ctx, api::NodeRef& node) { (void)ctx; std::vector indices; @@ -131,6 +221,7 @@ const HandlerMap& OrtExtendedHandlers() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, {"Resize", ep_aware_resize_handler}, + {"Reshape", ep_aware_reshape_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 7943f56d12741..a691faaffd2a0 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1385,8 +1385,7 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse HashValue hash; cann::GenerateHashValue(input_shape, hash); std::string filename = cann_state->node_name + "_" + std::to_string(hash); - std::string filename_with_suffix = filename + ".om"; - + bool dynamic_shape = false; // TODO(FFFrog): Resource Management // It is very necessary to provide a new mechanism for memory reclamation to avoid inference failure caused by // device memory exhaustion @@ -1395,8 +1394,8 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse modelID = modelIDs_[filename]; } else { std::lock_guard lock(g_mutex); - - if (cann::FileExist(filename_with_suffix)) { + auto filename_with_suffix = cann::RegexMatchFile(filename); + if (!filename_with_suffix.empty()) { CANN_RETURN_IF_ERROR(aclmdlLoadFromFile(filename_with_suffix.c_str(), &modelID)); } else { ge::Graph graph{cann_state->node_name.c_str()}; @@ -1424,11 +1423,17 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) { aclmdlIODims dims; CANN_CALL_THROW(aclmdlGetOutputDims(prepare.modelDesc_, i, &dims)); - std::vector vec{dims.dims, dims.dims + dims.dimCount}; - auto output = ctx.GetOutput(i, vec); - CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare, - const_cast(output.GetTensorRawData()), - aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i)); + + if (cann::is_dynamic_shape(dims)) { + CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare, nullptr, 0); + dynamic_shape = true; + } else { + std::vector vec{dims.dims, dims.dims + dims.dimCount}; + auto output = ctx.GetOutput(i, vec); + CANN_MODEL_PREPARE_OUTPUTBUFFER(prepare, + const_cast(output.GetTensorRawData()), + aclmdlGetOutputSizeByIndex(prepare.modelDesc_, i)); + } } } ORT_CATCH(const std::exception& e) { @@ -1436,8 +1441,28 @@ Status CANNExecutionProvider::Compile(const std::vector& fuse } aclrtStream stream = static_cast(ctx.GetGPUComputeStream()); - CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream)); - + if (dynamic_shape) { + aclrtSynchronizeStream(stream); + CANN_RETURN_IF_ERROR(aclmdlExecute(modelID, prepare.inputSet_, prepare.outputSet_)); + for (size_t i = 0; i < aclmdlGetNumOutputs(prepare.modelDesc_); i++) { + std::vector shape; + aclTensorDesc* desc = aclmdlGetDatasetTensorDesc(prepare.outputSet_, i); + size_t num_dims = aclGetTensorDescNumDims(desc); + shape.reserve(num_dims); + for (size_t j = 0; j < num_dims; j++) { + int64_t dim; + CANN_RETURN_IF_ERROR(aclGetTensorDescDimV2(desc, j, &dim)); + 
shape.push_back(dim); + } + aclDataBuffer* dataBuffer = aclmdlGetDatasetBuffer(prepare.outputSet_, i); + void* src_data = aclGetDataBufferAddr(dataBuffer); + void* dst_data = const_cast(ctx.GetOutput(i, shape).GetTensorRawData()); + size_t count = aclGetTensorDescSize(desc); + CANN_CALL_THROW(aclrtMemcpyAsync(dst_data, count, src_data, count, ACL_MEMCPY_DEVICE_TO_DEVICE, stream)); + } + } else { + CANN_RETURN_IF_ERROR(aclmdlExecuteAsync(modelID, prepare.inputSet_, prepare.outputSet_, stream)); + } return Status::OK(); }; diff --git a/onnxruntime/core/providers/cann/cann_utils.cc b/onnxruntime/core/providers/cann/cann_utils.cc index 5b3f9e6731b34..ae648c7f8feeb 100644 --- a/onnxruntime/core/providers/cann/cann_utils.cc +++ b/onnxruntime/core/providers/cann/cann_utils.cc @@ -224,5 +224,23 @@ void GenerateHashValue(const std::string string, HashValue& hash_value) { hash_value = hash[0] | (uint64_t(hash[1]) << 32); } +bool is_dynamic_shape(const aclmdlIODims& dims) { + return std::find(dims.dims, dims.dims + dims.dimCount, -1) != dims.dims + dims.dimCount; +} + +namespace fs = std::filesystem; +std::string RegexMatchFile(const std::string& file_name) { + fs::path current_dir = fs::current_path(); + std::regex pattern(file_name); + for (const auto& entry : fs::directory_iterator(current_dir)) { + if (entry.is_regular_file()) { + std::string name = entry.path().filename().string(); + if (std::regex_search(name, pattern)) { + return name; + } + } + } + return ""; +} } // namespace cann } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cann/cann_utils.h b/onnxruntime/core/providers/cann/cann_utils.h index 3739924758ea4..4be91eadd9556 100644 --- a/onnxruntime/core/providers/cann/cann_utils.h +++ b/onnxruntime/core/providers/cann/cann_utils.h @@ -9,6 +9,8 @@ #include #include #include +#include +#include #include "core/framework/murmurhash3.h" #include "core/providers/cann/cann_common.h" @@ -124,7 +126,8 @@ Status aclrtblasGemmEx(aclTransType transA, bool FileExist(const std::string& file_name); void GenerateHashValue(const std::string string, HashValue& hash_value); - +bool is_dynamic_shape(const aclmdlIODims& dims); +std::string RegexMatchFile(const std::string& file_name); std::unique_ptr CreateModel(const GraphViewer& graph_viewer, const logging::Logger& logger); } // namespace cann diff --git a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc index f00bf51ae143d..bf6e8c0c7e5cc 100644 --- a/onnxruntime/core/providers/cuda/cuda_provider_factory.cc +++ b/onnxruntime/core/providers/cuda/cuda_provider_factory.cc @@ -7,8 +7,9 @@ #include "core/providers/cuda/cuda_provider_factory_creator.h" #include "core/providers/cuda/cuda_provider_options.h" -#include #include +#include +#include #include @@ -16,6 +17,7 @@ #include "core/providers/cuda/cuda_execution_provider.h" #include "core/providers/cuda/cuda_execution_provider_info.h" #include "core/providers/cuda/cuda_allocator.h" +#include "core/providers/cuda/cuda_stream_handle.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/math/unary_elementwise_ops_impl.h" @@ -307,13 +309,366 @@ CUDA_Provider* GetProvider() { } // namespace onnxruntime -#include "core/framework/error_code_helper.h" +// +// Plug-in EP infrastructure +// + +#include "core/session/abi_devices.h" #include "onnxruntime_config.h" // for ORT_VERSION +struct ErrorHelper { + static const OrtApi* ort_api; + + static OrtStatus* ToOrtStatus(const Status& status) { + if 
(status.IsOK()) { + return nullptr; // no error + } + + return ort_api->CreateStatus(static_cast(status.Code()), + status.ErrorMessage().c_str()); + } +}; + +const OrtApi* ErrorHelper::ort_api = nullptr; + +#define RETURN_IF_ERROR(fn) \ + do { \ + OrtStatus* _status = (fn); \ + if (_status != nullptr) { \ + return _status; \ + } \ + } while (0) + +#define RETURN_IF_STATUS_NOTOK(fn) \ + do { \ + Status _status = (fn); \ + if (!_status.IsOK()) { \ + return ErrorHelper::ToOrtStatus(_status); \ + } \ + } while (0) + +#define CUDA_RETURN_IF_ERROR(expr) RETURN_IF_STATUS_NOTOK(CUDA_CALL(expr)) + +struct CudaOrtAllocator : OrtAllocator { + CudaOrtAllocator(const OrtMemoryInfo* mem_info, const OrtApi& api) : memory_info_{mem_info} { + version = ORT_API_VERSION; + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special behavior for Reserve so use AllocImpl + GetStats = nullptr; // GetStatsImpl. The CUDA allocators don't have stats currently so we can skip. + + const OrtEpApi& ep_api = *api.GetEpApi(); + const OrtMemoryDevice* mem_device = ep_api.MemoryInfo_GetMemoryDevice(mem_info); + uint32_t device_id = ep_api.MemoryDevice_GetDeviceId(mem_device); + const char* name = nullptr; + auto* status = api.MemoryInfoGetName(mem_info, &name); + static_cast(status); // GetName never fails + + if (ep_api.MemoryDevice_GetMemoryType(mem_device) == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + allocator_ = std::make_unique(device_id, name); + } else { + allocator_ = std::make_unique(device_id, name); + } + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + return impl.allocator_->Alloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* this_, void* p) { + auto& impl = *static_cast(this_); + impl.allocator_->Free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const CudaOrtAllocator& impl = *static_cast(this_); + return impl.memory_info_; + } + + private: + const OrtMemoryInfo* memory_info_; + std::unique_ptr allocator_; +}; + +struct CudaDataTransferImpl : OrtDataTransferImpl { + CudaDataTransferImpl(const OrtApi& ort_api_in) + : ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + CanCopy = CanCopyImpl; + CopyTensors = CopyTensorsImpl; + Release = ReleaseImpl; + } + + static bool CanCopyImpl(const OrtDataTransferImpl* this_ptr, + const OrtMemoryDevice* src_memory_device, + const OrtMemoryDevice* dst_memory_device) noexcept { + const auto& impl = *static_cast(this_ptr); + + // logic copied from GPUDataTransfer::CanCopy + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); + auto src_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); + auto dst_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); + + if ((src_type == OrtDevice::GPU && src_vendor_id != OrtDevice::VendorIds::NVIDIA) || + (dst_type == OrtDevice::GPU && dst_vendor_id != OrtDevice::VendorIds::NVIDIA)) { + return false; + } + + // copy must be GPU to GPU or between GPU and CPU + return (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_GPU) || + (src_type == OrtMemoryInfoDeviceType_GPU && dst_type == OrtMemoryInfoDeviceType_CPU) || + (src_type == OrtMemoryInfoDeviceType_CPU && dst_type == OrtMemoryInfoDeviceType_GPU); + } + + 
static OrtStatus* CopyTensorsImpl(OrtDataTransferImpl* this_ptr, + const OrtValue** src_tensors, + OrtValue** dst_tensors, + OrtSyncStream** streams, + size_t num_tensors) noexcept { + auto& impl = *static_cast(this_ptr); + bool need_stream_sync = false; + + for (size_t idx = 0; idx < num_tensors; ++idx) { + const OrtValue* src_tensor = src_tensors[idx]; + OrtValue* dst_tensor = dst_tensors[idx]; + OrtSyncStream* stream = streams ? streams[idx] : nullptr; + + const OrtMemoryDevice *src_device = nullptr, *dst_device = nullptr; + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensor, &src_device)); + RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(dst_tensor, &dst_device)); + + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensor, &bytes)); + + const void* src_data = nullptr; + void* dst_data = nullptr; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensor, &src_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensor, &dst_data)); + + OrtMemoryInfoDeviceType src_type = impl.ep_api.MemoryDevice_GetDeviceType(src_device); + OrtMemoryInfoDeviceType dst_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); + + const bool src_is_gpu_default = src_type == OrtMemoryInfoDeviceType_GPU && + src_mem_type == OrtDeviceMemoryType_DEFAULT; + const bool dst_is_gpu_default = dst_type == OrtMemoryInfoDeviceType_GPU && + dst_mem_type == OrtDeviceMemoryType_DEFAULT; + + cudaStream_t cuda_stream = nullptr; + if (stream) { + cuda_stream = static_cast(impl.ort_api.SyncStream_GetHandle(stream)); + } + + if (dst_is_gpu_default) { + if (src_is_gpu_default) { + // Copy only if the two addresses are different. + if (dst_data != src_data) { + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToDevice)); + + // For device memory to device memory copy, no host-side synchronization is performed by cudaMemcpy. + // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } else { + // copy from pinned or non-pinned CPU memory to GPU + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyHostToDevice, cuda_stream)); + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyHostToDevice)); + + if (src_mem_type != OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // For cudaMemcpy from pageable host memory to device memory, DMA to final destination may not + // have completed. 
+ // see https://docs.nvidia.com/cuda/cuda-runtime-api/api-sync-behavior.html + need_stream_sync = true; + } + } + } + } else if (src_is_gpu_default) { + // copying from GPU to CPU memory, this is blocking + + if (cuda_stream) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dst_data, src_data, bytes, cudaMemcpyDeviceToHost, cuda_stream)); + + } else { + CUDA_RETURN_IF_ERROR(cudaMemcpy(dst_data, src_data, bytes, cudaMemcpyDeviceToHost)); + } + } else { + // copying between CPU accessible memory + + if (dst_data != src_data) { + if (cuda_stream) { + if (src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE) { + // sync the stream first to make sure the data arrived + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + } + + memcpy(dst_data, src_data, bytes); + } + } + } + + if (need_stream_sync) { + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(nullptr)); + } + + return nullptr; + } + + static void ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { + // no-op as we have a single shared instance in OrtEpFactory which is returned from CreateDataTransferImpl, and is + // owned by and freed by the factory. + } + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct CudaSyncNotificationImpl : OrtSyncNotificationImpl { + static OrtStatus* Create(cudaStream_t stream, const OrtApi& ort_api, + std::unique_ptr& notification) { + notification.reset(new CudaSyncNotificationImpl(stream, ort_api)); // can't use make_unique with private ctor + CUDA_RETURN_IF_ERROR(cudaEventCreateWithFlags(¬ification->event_, cudaEventDisableTiming)); + + return nullptr; + } + + static void ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static OrtStatus* ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventRecord(impl.event_, impl.stream_)); + + return nullptr; + } + + static OrtStatus* WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* consumer_stream) noexcept { + auto& impl = *static_cast(this_ptr); + + // setup the consumer stream to wait on our event. 
+ void* consumer_handle = impl.ort_api.SyncStream_GetHandle(consumer_stream); + CUDA_RETURN_IF_ERROR(cudaStreamWaitEvent(static_cast(consumer_handle), impl.event_)); + + return nullptr; + } + + static OrtStatus* WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + CUDA_RETURN_IF_ERROR(cudaEventSynchronize(impl.event_)); + + return nullptr; + } + + ~CudaSyncNotificationImpl() { + cudaEventDestroy(event_); + } + + private: + CudaSyncNotificationImpl(cudaStream_t stream, const OrtApi& ort_api_in) + : stream_{stream}, ort_api{ort_api_in}, ep_api{*ort_api_in.GetEpApi()} { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + Release = ReleaseImpl; + } + + cudaStream_t& stream_; + cudaEvent_t event_; + + const OrtApi& ort_api; + const OrtEpApi& ep_api; +}; + +struct CudaSyncStreamImpl : OrtSyncStreamImpl { + CudaSyncStreamImpl(cudaStream_t&& stream, + const OrtDevice& device, + AllocatorPtr cpu_allocator, + bool release_cpu_buffer_on_cuda_stream, + const OrtApi& ort_api_in) + : stream_{ + stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, /*own*/ true, + /*external_cudnn_handle*/ nullptr, + /*external_cublas_handle*/ nullptr, + // ep_info is used by GetResource which seems to be a somewhat ugly way to make arbitrary info that is + // unrelated to the stream available to a custom op. + // avoiding adding GetResource to OrtSyncStreamImpl as we should have a cleaner setup for custom ops, + // so this argument value isn't used and doesn't matter. + /*ep_info*/ CUDAExecutionProviderInfo{}}, + ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetHandle = GetHandleImpl; + CreateNotification = CreateNotificationImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + static void ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); + } + + static void* GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.stream_.GetHandle(); + } + + static OrtStatus* CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification_impl) noexcept { + auto& impl = *static_cast(this_ptr); + *notification_impl = nullptr; + + std::unique_ptr notification; + cudaStream_t* cuda_stream = static_cast(impl.stream_.GetHandle()); + + RETURN_IF_ERROR(CudaSyncNotificationImpl::Create(*cuda_stream, impl.ort_api, notification)); + *notification_impl = notification.release(); + + return nullptr; + } + + static OrtStatus* FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + impl.stream_.Flush(); + + return nullptr; + } + + static OrtStatus* OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + RETURN_IF_STATUS_NOTOK(impl.stream_.CleanUpOnRunEnd()); + + return nullptr; + } + + private: + // this is a little onion-ish as CudaStream is a onnxruntime::Stream and this is an OrtSyncStreamImpl that will be + // used via plugin_ep::Stream, which is also an onnxruntime::Stream. in a 'real' plugin EP implementation + // CudaStream would go away and the logic it has would be implemented directly here. + CudaStream stream_; + const OrtApi& ort_api; +}; + // OrtEpApi infrastructure to be able to use the CUDA EP as an OrtEpFactory for auto EP selection. 
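For readers less familiar with the CUDA side, CudaSyncNotificationImpl above is thin plumbing over standard CUDA events: Activate records an event on the producing stream, WaitOnDevice makes a consumer stream wait on that event without blocking the host, and WaitOnHost blocks the calling thread. A stripped-down sketch of that pattern, independent of the ORT wrapper types (the function name is illustrative and error checking is elided):

#include <cuda_runtime.h>

// Order work submitted to `consumer` after everything already submitted to `producer`, without a host sync.
void OrderAfterProducer(cudaStream_t producer, cudaStream_t consumer) {
  cudaEvent_t evt = nullptr;
  cudaEventCreateWithFlags(&evt, cudaEventDisableTiming);  // same flag used by CudaSyncNotificationImpl
  cudaEventRecord(evt, producer);                          // "Activate"
  cudaStreamWaitEvent(consumer, evt, 0);                   // "WaitOnDevice"; cudaEventSynchronize(evt) would be "WaitOnHost"
  cudaEventDestroy(evt);                                   // safe: destruction is deferred until pending work completes
}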
struct CudaEpFactory : OrtEpFactory { - CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in} { - ort_version_supported = ORT_API_VERSION; + using MemoryInfoUniquePtr = std::unique_ptr>; + + CudaEpFactory(const OrtApi& ort_api_in) : ort_api{ort_api_in}, + ep_api{*ort_api_in.GetEpApi()}, + data_transfer_impl{ort_api_in} { GetName = GetNameImpl; GetVendor = GetVendorImpl; GetVendorId = GetVendorIdImpl; @@ -321,16 +676,24 @@ struct CudaEpFactory : OrtEpFactory { GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->ep_name.c_str(); + const auto& factory = *static_cast(this_ptr); + return factory.ep_name.c_str(); } static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { - const auto* factory = static_cast(this_ptr); - return factory->vendor.c_str(); + const auto& factory = *static_cast(this_ptr); + return factory.vendor.c_str(); } static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { @@ -349,15 +712,84 @@ struct CudaEpFactory : OrtEpFactory { size_t max_ep_devices, size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; - auto* factory = static_cast(this_ptr); + auto& factory = *static_cast(this_ptr); + + int num_cuda_devices = 0; + cudaGetDeviceCount(&num_cuda_devices); + RETURN_IF_ERROR(factory.CreateMemoryInfoForDevices(num_cuda_devices)); + /* in theory we can match on the LUID in the OrtHardwareDevice metadata, but that requires the CUDA Driver API + std::vector device_to_luid; + device_to_luid.resize(num_cuda_devices); + + for (int i = 0; i < num_cuda_devices; ++i) { + CUdevice device; + cuDeviceGet(&device, i); + + char luid[8]; + unsigned int nodeMask; + if (cuDeviceGetLuid(luid, &nodeMask, device) == CUDA_SUCCESS) { + device_to_luid[i] = *reinterpret_cast(luid); + } + } + */ + + int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; - if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(&device) == 0x10de) { - ORT_API_RETURN_IF_ERROR( - factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); + if (factory.ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && + factory.ort_api.HardwareDevice_VendorId(&device) == 0x10de) { + /* ideally we'd match on LUID here + for now we use an incrementing device id. could be a mismatch if you have multiple different CUDA GPUs. + alternative is to limit to one device only. + + // find the device id. On Windows we have the LUID in the OrtHardwareDevice metadata. 
+ const OrtKeyValuePairs* metadata = factory.ort_api.HardwareDevice_Metadata(&device); + const char* luid_str = factory.ort_api.GetKeyValue(metadata, "LUID"); + + if (!luid_str && num_devices > 1) { + // if there's no LUID we can't match device + return factory.ort_api.CreateStatus(ORT_EP_FAIL, "OrtHardwareDevice does not have LUID"); + } + + char* luid_end = nullptr; + uint64_t luid = std::strtoull(luid_str, &luid_end, 10); + for (; device_id < num_cuda_devices; ++device_id) { + if (device_to_luid[device_id] == luid) { + break; + } + } + + if (device_id == num_cuda_devices) { + std::string msg("Could not match LUID to a CUDA device. LUID="); + msg += luid_str; + + return factory.ort_api.CreateStatus(ORT_EP_FAIL, msg.c_str()); + } + */ + + // create the EP options and add the device id + OrtKeyValuePairs* ep_metadata = nullptr; + OrtKeyValuePairs* ep_options = nullptr; + factory.ort_api.CreateKeyValuePairs(&ep_options); + factory.ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); + + // create the OrtEpDevice + OrtEpDevice* ep_device = nullptr; + RETURN_IF_ERROR(factory.ort_api.GetEpApi()->CreateEpDevice(&factory, &device, ep_metadata, ep_options, + &ep_device)); + + factory.ort_api.ReleaseKeyValuePairs(ep_options); + + const OrtMemoryInfo* gpu_mem_info = factory.gpu_memory_infos[device_id].get(); + const OrtMemoryInfo* host_accessible_mem_info = factory.host_accessible_memory_infos[device_id].get(); + + RETURN_IF_ERROR(factory.ep_api.EpDevice_AddAllocatorInfo(ep_device, gpu_mem_info)); + RETURN_IF_ERROR(factory.ep_api.EpDevice_AddAllocatorInfo(ep_device, host_accessible_mem_info)); + + ep_devices[num_ep_devices++] = ep_device; + + ++device_id; } } @@ -378,10 +810,124 @@ struct CudaEpFactory : OrtEpFactory { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + // this function is free to return the same allocator instance for all calls and make ReleaseAllocator a no-op + // e.g. allocator instance is in unique_ptr in the OrtEpFactory instance. + // ORT will create a shared allocator in the environment and the user can choose to use it in an inference session. + // Otherwise ORT will create an allocator when adding the EP to an inference session. + auto& factory = *static_cast(this_ptr); + + auto cuda_allocator = std::make_unique(memory_info, factory.ort_api); + *allocator = cuda_allocator.release(); + + return nullptr; + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* allocator) noexcept { + delete static_cast(allocator); + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, + OrtDataTransferImpl** data_transfer) noexcept { + auto& factory = *static_cast(this_ptr); + *data_transfer = &factory.data_transfer_impl; + + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + auto device_id = factory.ep_api.MemoryDevice_GetDeviceId(memory_device); + + // the OrtEpFactory could have a cache of stream instances if it wants to avoid creating a new one on every + // call. 
the CudaStreamSyncImpl::Release could return the instance to the cache. + cudaStream_t stream = nullptr; + CUDA_RETURN_IF_ERROR(cudaSetDevice(device_id)); + CUDA_RETURN_IF_ERROR(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); + + // Currently this API is only used for creating a stream that is used outside of a session, as we're using the + // 'real' CUDA IExecutionProvider implementation for the EP. Due to that we need to connect it up to an internal + // onnxruntime::Stream that has the correct settings for the session. + // We do that externally by passing the cudaStream_t in via the "user_compute_stream" provider option. + // + // For use within an inference session in a completely plugin EP we'd need the session's CPU allocator to be + // available, as well as for relevant EP instance specific options such as whether graph capture is enabled + // to be applied. + + const OrtDevice* ort_device = static_cast(memory_device); + // This OrtSyncStream isn't used for running the inference, so we don't need a CPU allocator for + // CPU scratch buffers to be created by operator kernels. + AllocatorPtr null_allocator; + + auto impl = std::make_unique(std::move(stream), *ort_device, nullptr, + /*release_cpu_buffer_on_cuda_stream*/ true, + factory.ort_api); + *ort_stream = impl.release(); + + return nullptr; + } + + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { + gpu_memory_infos.reserve(num_devices); + host_accessible_memory_infos.reserve(num_devices); + + for (int device_id = 0; device_id < num_devices; ++device_id) { + OrtMemoryInfo* mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CUDA", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_DEFAULT, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + + gpu_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + + // HOST_ACCESSIBLE memory should use the non-CPU device type + mem_info = nullptr; + RETURN_IF_ERROR(ort_api.CreateMemoryInfo_V2("CUDA host accessible", OrtMemoryInfoDeviceType_GPU, + /*vendor*/ OrtDevice::VendorIds::NVIDIA, + /* device_id */ device_id, + OrtDeviceMemoryType_HOST_ACCESSIBLE, + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, + &mem_info)); + + host_accessible_memory_infos.emplace_back(MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo)); + } + + return nullptr; + } + const OrtApi& ort_api; + const OrtEpApi& ep_api; const std::string ep_name{kCudaExecutionProvider}; // EP name const std::string vendor{"Microsoft"}; // EP vendor name uint32_t vendor_id{0x1414}; // Microsoft vendor ID + + // per-device memory info + std::vector gpu_memory_infos; + std::vector host_accessible_memory_infos; + + // we use a shared instance for the OrtDataTransferImpl instead of creating a new one on every call to + // CreateDataTransferImpl. 
+ CudaDataTransferImpl data_transfer_impl; + + CudaEpFactory(const CudaEpFactory&) = delete; + CudaEpFactory& operator=(const CudaEpFactory&) = delete; + + CudaEpFactory(CudaEpFactory&&) = default; + CudaEpFactory& operator=(CudaEpFactory&&) = default; }; extern "C" { @@ -391,6 +937,7 @@ extern "C" { OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + ErrorHelper::ort_api = ort_api; // setup our error helper // Factory could use registration_name or define its own EP name. std::unique_ptr factory = std::make_unique(*ort_api); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index c5b6507ac847b..286db9070766d 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -1304,7 +1304,7 @@ std::vector NvExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { - return std::make_unique(CUDA_PINNED, device_id); + return std::make_unique(device_id, CUDA_PINNED); }, narrow(device_id_)); @@ -2281,11 +2281,16 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr if (max_shared_mem_size_ > 0) { trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kTACTIC_SHARED_MEMORY, max_shared_mem_size_); } - // Only set default compute capabilities if user hasn't explicitly configured them - constexpr int kDefaultNumComputeCapabilities = 1; // Default number of compute capabilities for Turing support - if (trt_config->getNbComputeCapabilities() == 0) { - trt_config->setNbComputeCapabilities(kDefaultNumComputeCapabilities); - trt_config->setComputeCapability(nvinfer1::ComputeCapability::kCURRENT, 0); + + // Only set compute capability for Turing + const std::string kTuringComputeCapability{"75"}; + + if (compute_capability_ == kTuringComputeCapability) { + constexpr int kDefaultNumComputeCapabilities = 1; + if (trt_config->getNbComputeCapabilities() == 0) { + trt_config->setNbComputeCapabilities(kDefaultNumComputeCapabilities); + trt_config->setComputeCapability(nvinfer1::ComputeCapability::kSM75, 0); + } } int num_inputs = trt_network->getNbInputs(); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index 0fc3e5443bc28..83edc6ccdd313 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -115,6 +115,23 @@ struct Nv_Provider : Provider { return std::make_shared(info); } + Status CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t /*num_devices*/, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + const ConfigOptions* config_options = &session_options.GetConfigOptions(); + + std::array configs_array = {&provider_options, config_options}; + const void* arg = reinterpret_cast(&configs_array); + auto ep_factory = CreateExecutionProviderFactory(arg); + ep = ep_factory->CreateProvider(session_options, logger); + + return Status::OK(); + } + void Initialize() override { 
InitializeRegistry(); } @@ -133,3 +150,118 @@ ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } } + +#include "core/framework/error_code_helper.h" + +// OrtEpApi infrastructure to be able to use the NvTensorRTRTX EP as an OrtEpFactory for auto EP selection. +struct NvTensorRtRtxEpFactory : OrtEpFactory { + NvTensorRtRtxEpFactory(const OrtApi& ort_api_in, + const char* ep_name, + OrtHardwareDeviceType hw_type) + : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + + // Returns the name for the EP. Each unique factory configuration must have a unique name. + // Ex: a factory that supports NPU should have a different name than a factory that supports GPU. + static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr); + return factory->ep_name.c_str(); + } + + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + const auto* factory = static_cast<const NvTensorRtRtxEpFactory*>(this_ptr); + return factory->vendor.c_str(); + } + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. + // An EP created with this factory is expected to be able to execute a model with *all* supported + // hardware devices at once. A single instance of NvTensorRtRtx EP is not currently set up to partition a model among + // multiple different NvTensorRtRtx backends at once (e.g., npu, cpu, gpu), so this factory instance is set to only + // support one backend: gpu. To support a different backend, like npu, create a different factory instance + // that only supports NPU. + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) { + size_t& num_ep_devices = *p_num_ep_devices; + auto* factory = static_cast<NvTensorRtRtxEpFactory*>(this_ptr); + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == factory->ort_hw_device_type && + factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + OrtKeyValuePairs* ep_options = nullptr; + factory->ort_api.CreateKeyValuePairs(&ep_options); + ORT_API_RETURN_IF_ERROR( + factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, nullptr, ep_options, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; + } + + static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[NvTensorRTRTX EP] EP factory does not support this method."); + } + + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + // no-op as we never create an EP here. + } + + const OrtApi& ort_api; + const std::string ep_name; + const std::string vendor{"NVIDIA"}; + + // NVIDIA vendor ID.
Refer to the ACPI ID registry (search NVIDIA): https://uefi.org/ACPI_ID_List + const uint32_t vendor_id{0x10de}; + const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice +}; + +extern "C" { +// +// Public symbols +// +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + + // Factory could use registration_name or define its own EP name. + auto factory_gpu = std::make_unique<NvTensorRtRtxEpFactory>(*ort_api, + onnxruntime::kNvTensorRTRTXExecutionProvider, + OrtHardwareDeviceType_GPU); + + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + + factories[0] = factory_gpu.release(); + *num_factories = 1; + + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast<NvTensorRtRtxEpFactory*>(factory); + return nullptr; +} +} diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def b/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 65532c31e14bd..28804d2f76492 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -120,7 +120,7 @@ BackendManager::BackendManager(SessionContext& session_context, (session_context_.device_type.find("CPU") != std::string::npos || session_context_.device_type.find("GPU") != std::string::npos || (session_context_.device_type.find("NPU") != std::string::npos && - session_context_.enable_causallm) )) || + session_context_.enable_causallm))) || (subgraph_context_.is_ep_ctx_graph)) { LOGS_DEFAULT(INFO) << "[OpenVINO-EP] Starting backend initialization. " << "Creating backend Dynamic Shapes"; @@ -443,7 +443,7 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, ORT_ENFORCE(status.IsOK(), status.ErrorMessage()); return model_proto; } else if ((session_context_.device_type.find("GPU") != std::string::npos) && - enable_ovep_qdq_optimizer) { + enable_ovep_qdq_optimizer) { // Create a copy of the model std::unique_ptr model; Status status = qdq_scales_fix::Transform(subgraph, logger, model); diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index a2067ce10485c..f6bc5ad599e18 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -133,7 +133,7 @@ class OVInferRequest { auto tensor_ptr = std::make_shared(type, shape, const_cast(ort_ptr)); SetTensor(name, tensor_ptr); cached_binding = {tensor_ptr, ort_ptr}; - } else if (ort_ptr==nullptr) { + } else if (ort_ptr == nullptr) { // a null ort_ptr is expected for a tensor that has 0 elements. // for example, a tensor of shape=[1, 8, 0, 64], which is valid. // So, we check to see if at least one shape entry is 0.
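For context on the hunks above and below: the CUDA, NvTensorRTRTX, and VitisAI provider bridges in this patch all export the same two C symbols, CreateEpFactories and ReleaseEpFactory, with identical signatures. The following is only a minimal sketch of the calling contract a host would follow once it has resolved those symbols (e.g. via dlopen/dlsym or LoadLibrary/GetProcAddress); the helper name EnumerateEpFactories, the "example_ep" registration name, and the header include are illustrative assumptions, not code from this PR.

#include <array>
#include <cstddef>

#include "onnxruntime_c_api.h"  // assumed to declare OrtApiBase, OrtEpFactory, OrtGetApiBase

using CreateEpFactoriesFn = OrtStatus* (*)(const char* registration_name, const OrtApiBase* ort_api_base,
                                           OrtEpFactory** factories, size_t max_factories, size_t* num_factories);
using ReleaseEpFactoryFn = OrtStatus* (*)(OrtEpFactory* factory);

// Calls CreateEpFactories with a small fixed-size buffer, inspects each factory, then releases it.
OrtStatus* EnumerateEpFactories(CreateEpFactoriesFn create_factories, ReleaseEpFactoryFn release_factory) {
  std::array<OrtEpFactory*, 4> factories{};
  size_t num_factories = 0;

  // The registration name is supplied by the caller; a factory may use it or define its own EP name.
  if (OrtStatus* status = create_factories("example_ep", OrtGetApiBase(),
                                           factories.data(), factories.size(), &num_factories)) {
    return status;  // e.g. ORT_INVALID_ARGUMENT when max_factories is too small
  }

  for (size_t i = 0; i < num_factories; ++i) {
    // OrtEpFactory exposes its accessors as function pointers, as wired up in the factory constructors above.
    const char* ep_name = factories[i]->GetName(factories[i]);
    (void)ep_name;
    release_factory(factories[i]);
  }

  return nullptr;
}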
diff --git a/onnxruntime/core/providers/openvino/ov_protobuf_utils.h b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h index 2a6d914ee2920..ba8f910cd9218 100644 --- a/onnxruntime/core/providers/openvino/ov_protobuf_utils.h +++ b/onnxruntime/core/providers/openvino/ov_protobuf_utils.h @@ -6,5 +6,5 @@ namespace onnxruntime { namespace openvino_ep { float get_float_initializer_data(const void* initializer); void set_float_initializer_data(const void* initializer, float data); -} +} // namespace openvino_ep } // namespace onnxruntime diff --git a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp index 571aa57c99f33..c1e4815c206a2 100644 --- a/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp +++ b/onnxruntime/core/providers/openvino/qdq_transformations/qdq_scales_fix.cpp @@ -605,8 +605,7 @@ float get_initializer_value(const Graph& graph, const std::string& initializer_n auto size = get_initializer_size(graph, initializer_name); ORT_ENFORCE(size == 1, "Expected an initializer to be of size 1"); return raw_data[0]; - } - else + } else return get_float_initializer_data(p_initializer); } @@ -775,7 +774,6 @@ bool scale_graph(CustomGraph& gen_graph, return needs_second_run; } - Status copy_model(const GraphViewer& src_graph_viewer, const logging::Logger& logger, std::unique_ptr& model) { model = src_graph_viewer.CreateModel(logger); @@ -938,7 +936,7 @@ Status Transform(const GraphViewer& src_graph_viewer, bool scale_output{false}; auto needs_second_run = scale_graph(g, threshold, ratio, scale_output); if (needs_second_run) - scale_graph(g, threshold * 100, ratio, scale_output); + scale_graph(g, threshold * 100, ratio, scale_output); return status; } } // namespace qdq_scales_fix diff --git a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc index 785177ce37788..6dcb64940dec6 100644 --- a/onnxruntime/core/providers/qnn/qnn_provider_factory.cc +++ b/onnxruntime/core/providers/qnn/qnn_provider_factory.cc @@ -133,6 +133,12 @@ struct QnnEpFactory : OrtEpFactory { GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. @@ -201,6 +207,43 @@ struct QnnEpFactory : OrtEpFactory { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto& factory = *static_cast(this_ptr); + *allocator = nullptr; + + // we don't add allocator info to the OrtEpDevice we return so this should never be called. + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "QNN EP factory does not support CreateAllocator."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this*/, OrtAllocator* /*allocator*/) noexcept { + // we don't support CreateAllocator so this should never be called. 
+ } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // return nullptr to indicate that this EP does not support data transfer. + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** ort_stream) noexcept { + auto& factory = *static_cast(this_ptr); + *ort_stream = nullptr; + + // should never be called as IsStreamAware returns false + return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, + "QNN EP factory does not support CreateSyncStreamForDevice."); + } + const OrtApi& ort_api; const std::string ep_name; // EP name const std::string ep_vendor{"Microsoft"}; // EP vendor name diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc index fbccd7d4a286b..69fcda98dc75f 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc @@ -298,6 +298,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), false /* serialize refitted engine to disk */, detailed_build_log_); @@ -367,6 +369,8 @@ Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph make_secure_path_checks, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, (*trt_engine_).get(), true /* serialize refitted engine to disk */, detailed_build_log_); diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index 3af0143cbf14e..e89b047919cd6 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -54,6 +54,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, bool detailed_build_log) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), @@ -63,6 +65,8 @@ class TensorRTCacheModelHandler { onnx_model_folder_path_(onnx_model_folder_path), onnx_model_bytestream_(onnx_model_bytestream), onnx_model_bytestream_size_(onnx_model_bytestream_size), + onnx_external_data_bytestream_(onnx_external_data_bytestream), + onnx_external_data_bytestream_size_(onnx_external_data_bytestream_size), detailed_build_log_(detailed_build_log) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler); @@ -80,6 +84,8 @@ class TensorRTCacheModelHandler { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_; + size_t onnx_external_data_bytestream_size_; bool detailed_build_log_; }; // TRTCacheModelHandler } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 
1121775bf5ef7..64be445b4c15c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1395,6 +1395,14 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv "When providing either 'trt_onnx_bytestream_size' or " "'trt_onnx_bytestream' both have to be provided")); } + onnx_external_data_bytestream_ = info.external_data_bytestream; + onnx_external_data_bytestream_size_ = info.external_data_bytestream_size; + if ((onnx_external_data_bytestream_ != nullptr && onnx_external_data_bytestream_size_ == 0) || + (onnx_external_data_bytestream_ == nullptr && onnx_external_data_bytestream_size_ != 0)) { + ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "When providing either 'trt_external_data_bytestream_size' or " + "'trt_external_data_bytestream' both have to be provided")); + } timing_cache_enable_ = info.timing_cache_enable; force_timing_cache_match_ = info.force_timing_cache; detailed_build_log_ = info.detailed_build_log; @@ -1435,6 +1443,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv engine_hw_compatible_ = info.engine_hw_compatible; op_types_to_exclude_ = info.op_types_to_exclude; preview_features_ = ParseTrtPreviewFeatures(info.preview_features); + load_user_initializer_ = info.load_user_initializer; } else { try { const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations); @@ -1836,7 +1845,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv << ", trt_cache_prefix: " << cache_prefix_ << ", trt_engine_hw_compatible: " << engine_hw_compatible_ << ", trt_onnx_model_bytestream_size_: " << onnx_model_bytestream_size_ - << ", trt_op_types_to_exclude: " << op_types_to_exclude_; + << ", trt_onnx_external_data_bytestream_size: " << onnx_external_data_bytestream_size_ + << ", trt_op_types_to_exclude: " << op_types_to_exclude_ + << ", trt_load_user_initializer: " << load_user_initializer_; } TensorrtExecutionProvider::~TensorrtExecutionProvider() { @@ -2318,7 +2329,24 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. - graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + // Save Initializer Data. + + std::vector userWeights; + + // Keep inits in memory instead of writing to ModelProto. 
+ if (load_user_initializer_) { + auto allInitializers = graph_viewer->GetAllInitializedTensors(); + + for (auto entry : allInitializers) { + auto* tp = entry.second; + if (tp->has_raw_data()) { + userWeights.push_back( + TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + } + } + } + + graph_viewer->ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; @@ -2343,9 +2371,22 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + bool is_model_supported = false; #if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 - auto is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + if (load_user_initializer_) { + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : userWeights) { + trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + } + is_model_supported = trt_parser->parseModelProto(); + } else { + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); + } +#else + is_model_supported = trt_parser->supportsModelV2(string_buf.data(), string_buf.size(), model_path_); +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 // Note: Calling getNbSubgraphs or getSubgraphNodes before calling supportsModelV2 results in undefined behavior. 
auto num_subgraphs = trt_parser->getNbSubgraphs(); @@ -2363,7 +2404,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } #else trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_); -#endif +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 1) || NV_TENSORRT_MAJOR > 10 SubGraphCollection_t next_nodes_list; const std::vector& subgraph_node_index = graph_viewer->GetNodesInTopologicalOrder(1 /*priority-based topological sort*/); @@ -2804,11 +2845,15 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log) { #if NV_TENSORRT_MAJOR >= 10 bool refit_from_file = onnx_model_bytestream == nullptr && onnx_model_bytestream_size == 0; + bool refit_with_external_data = onnx_external_data_bytestream != nullptr && onnx_external_data_bytestream_size != 0; + bool refit_complete = false; std::filesystem::path onnx_model_path{onnx_model_folder_path}; if (refit_from_file) { if (!onnx_model_filename.empty()) { @@ -2845,17 +2890,134 @@ common::Status TensorrtExecutionProvider::RefitEngine(std::string onnx_model_fil auto refitter = std::unique_ptr(nvinfer1::createInferRefitter(*trt_engine, trt_logger)); auto parser_refitter = std::unique_ptr( nvonnxparser::createParserRefitter(*refitter, trt_logger)); - if (refit_from_file) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); - if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // New refit APIs + if (refit_with_external_data) { + // A valid model bytestream must be passed. + if (refit_from_file) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + "TensorRT EP's refit with external data must be called with a valid ONNX model bytestream"); } - } else { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array"; - if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + + if (!parser_refitter->loadModelProto(onnx_model_bytestream, onnx_model_bytestream_size, nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not load model from provided onnx_model_bytestream"); + } + + // Extract weight information from the Refitter. + int required_weights = refitter->getAllWeights(0, nullptr); + std::vector refit_names(required_weights); + refitter->getAllWeights(required_weights, refit_names.data()); + + // Vectors to keep track of data pointers. + std::vector names; + names.reserve(required_weights); + std::vector bytes; + bytes.reserve(required_weights); + std::vector sizes; + sizes.reserve(required_weights); + + if (refit_with_external_data) { + auto onnx_model = ModelProto::Create(); + TensorProtos* allInitializers_byte_stream; + + // Reconstruct onnx model view. 
+ const auto onnx_model_view = std::string((const char*)onnx_model_bytestream, + onnx_model_bytestream_size); + if (!onnx_model->ParseFromString(onnx_model_view)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "The provided ONNX bytestream to refit could not be parsed."); + } + + // Extract graph and initializer information. + auto const& graph = onnx_model->mutable_graph(); + allInitializers_byte_stream = graph->mutable_initializer(); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializers that were found " << allInitializers_byte_stream->size(); + + // Loop through all initializers + for (int initializer_idx = 0; initializer_idx < allInitializers_byte_stream->size(); ++initializer_idx) { + auto& proto = allInitializers_byte_stream->at(initializer_idx); + auto& proto_name = proto.name(); + bool weight_is_refittable = std::find(refit_names.begin(), refit_names.end(), proto_name) != refit_names.end(); + if (weight_is_refittable) { + if (proto.has_data_location()) { + if (proto.data_location() == TensorProto_DataLocation_EXTERNAL) { + // Default values for reading into external_data blob. + int64_t offset = 0; + size_t length = 0; + auto external_data = proto.mutable_external_data(); + const std::string kOffset = "offset", kLength = "length"; + for (int entry_idx = 0; entry_idx < external_data->size(); ++entry_idx) { + auto current_key = external_data->at(entry_idx).mutable_key(); + auto current_value = external_data->at(entry_idx).mutable_value(); + if (*current_key == kOffset && !current_value->empty()) { + offset = std::stoll(*current_value); + } else if (*current_key == kLength && !current_value->empty()) { + length = std::stoul(*current_value); + } + } + names.push_back(proto.name()); + bytes.push_back(static_cast(onnx_external_data_bytestream) + offset); + sizes.push_back(length); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] Proto: " + proto_name + " expected to have external datalocation, but default datalocation was provided instead."); + } + } else { + if (!proto.has_raw_data()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "[TensorRT EP] Proto: " + proto_name + " has no raw data"); + } + auto& raw_data = proto.raw_data(); + names.push_back(proto.name()); + bytes.push_back(raw_data.c_str()); + sizes.push_back(raw_data.size()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Initializer with name: " << proto_name << " was not marked as refittable"; + } + } + } + + // Load extracted initializers into the parser + if (!names.empty()) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Number of initializers submitted to refitter " << names.size(); + for (size_t i = 0; i < names.size(); i++) { + bool refloadInit = parser_refitter->loadInitializer(names[i].c_str(), bytes[i], sizes[i]); + if (!refloadInit) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } + } + } + // Perform refit. + if (!parser_refitter->refitModelProto()) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestraem"); + "TensorRT EP's IParserRefitter refitModelProto() failed with the provided external data bytestream."); + } + refit_complete = true; + } +#else + // Refitting with external data is not supported prior to TensorRT 10.13. Log a warning in this case for the user. 
+ if (refit_with_external_data) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Refitting with an onnx_external_data_bytestream is only supported on TensorRT versions >= 10.13! This parameter will be ignored for refitting, and the resulting refitted engine may be incorrect."; + } +#endif // (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + // If new refit flow was not completed, then fallback to refit_from_file. + if (!refit_complete) { + if (refit_from_file) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from file on disk: " << onnx_model_path.string(); + if (!parser_refitter->refitFromFile(onnx_model_path.string().c_str())) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in: " + onnx_model_path.string()); + } + } else { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Refitting from byte array"; + if (!parser_refitter->refitFromBytes(onnx_model_bytestream, onnx_model_bytestream_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP's IParserRefitter could not refit deserialized weight-stripped engine with weights contained in the provided bytestream"); + } } } if (refitter->refitCudaEngine()) { @@ -2926,11 +3088,26 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto model = graph_body_viewer.CreateModel(*GetLogger()); auto model_proto = model->ToProto(); + auto userWeights = std::make_unique>(); + + if (load_user_initializer_) { + auto allInitializers = graph_body_viewer.GetAllInitializedTensors(); + + for (auto entry : allInitializers) { + auto name = entry.first; + auto* tp = entry.second; + if (tp->has_raw_data()) { + userWeights->push_back( + TensorrtUserWeights{tp->name(), tp->raw_data(), (int64_t)tp->raw_data().size()}); + } + } + } + // ORT's default topological sort is using reversed DFS. // When creating model proto from graph viewer, let ORT use priority-based topological sort based on node index. // The reason is, in some cases, for example ResNet50, using default topological sort will end up with generating // the model proto that has different node ordering compared to original onnx model. 
- graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/); + graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true, 1 /*priority-based topological sort*/, !load_user_initializer_ /*include_initializer_data*/); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string string_buf; model_proto->SerializeToString(string_buf); @@ -2952,7 +3129,20 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView auto trt_network = std::unique_ptr(trt_builder->createNetworkV2(network_flags)); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); auto trt_parser = tensorrt_ptr::unique_pointer(nvonnxparser::createParser(*trt_network, trt_logger)); + +#if (NV_TENSORRT_MAJOR == 10 && NV_TENSORRT_MINOR > 12) || NV_TENSORRT_MAJOR > 10 + if (load_user_initializer_) { + trt_parser->loadModelProto(string_buf.data(), string_buf.size(), model_path_); + for (auto const& userWeight : *userWeights) { + trt_parser->loadInitializer(userWeight.name.c_str(), static_cast(userWeight.data.c_str()), userWeight.size); + } + trt_parser->parseModelProto(); + } else { + trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); + } +#else trt_parser->parse(string_buf.data(), string_buf.size(), model_path_); +#endif if (max_workspace_size_ > 0) { trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_); } @@ -3489,6 +3679,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView false /* path check for security */, onnx, onnx_size, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, trt_engine.get(), true /* serialize refitted engine to disk */, detailed_build_log_); @@ -3545,6 +3737,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView engines_.emplace(fused_node.Name(), std::move(trt_engine)); contexts_.emplace(fused_node.Name(), std::move(trt_context)); networks_.emplace(fused_node.Name(), std::move(trt_network)); + weights_.emplace(fused_node.Name(), std::move(userWeights)); input_info_[fused_node.Name()].push_back(input_indexes); output_info_[fused_node.Name()].push_back(output_indexes); output_info_[fused_node.Name()].push_back(output_types); @@ -3597,7 +3790,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_, builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics, cuda_graph_enable_, cache_prefix_, cache_suffix, engine_hw_compatible_, - preview_features_}; + preview_features_, &weights_[context->node_name]}; *state = p.release(); return 0; }; @@ -3975,6 +4168,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView false /* path check for security */, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, trt_engine, true /* serialize refitted engine to disk */, detailed_build_log_); @@ -4208,6 +4403,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con onnx_model_folder_path_, onnx_model_bytestream_, onnx_model_bytestream_size_, + onnx_external_data_bytestream_, + onnx_external_data_bytestream_size_, detailed_build_log_); auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer); if (status 
!= Status::OK()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 7e02cf7590f66..dba17f7822eac 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -157,6 +157,13 @@ class OutputAllocator : public nvinfer1::IOutputAllocator { */ using ShapeRangesMap = std::unordered_map>>>; +// Struct to hold user weights when ModelProtos are serialized with data. +struct TensorrtUserWeights { + std::string name{}; + std::string data{}; + int64_t size{}; +}; + // Information to construct kernel function state. struct TensorrtFuncState { AllocateFunc test_allocate_func = nullptr; @@ -205,6 +212,7 @@ struct TensorrtFuncState { std::string cache_suffix; bool engine_hw_compatible = false; std::vector preview_features; + std::unique_ptr>* userWeights = nullptr; }; // Minimum information to construct kernel function state for direct engine load code path @@ -287,6 +295,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool path_check, const void* onnx_model_bytestream, size_t onnx_model_bytestream_size, + const void* onnx_external_data_bytestream, + size_t onnx_external_data_bytestream_size, nvinfer1::ICudaEngine* trt_engine, bool serialize_refitted_engine, bool detailed_build_log); @@ -314,6 +324,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::string onnx_model_folder_path_; const void* onnx_model_bytestream_; size_t onnx_model_bytestream_size_; + const void* onnx_external_data_bytestream_ = nullptr; + size_t onnx_external_data_bytestream_size_ = 0; bool build_heuristics_enable_ = false; bool sparsity_enable_ = false; int builder_optimization_level_ = 3; @@ -340,6 +352,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { bool engine_hw_compatible_ = false; std::string op_types_to_exclude_; std::vector preview_features_; + bool load_user_initializer_ = false; // The format is as for TENSORRT_VERSION: (MAJOR * 100 + MINOR) * 100 + PATCH int32_t trt_version_; @@ -380,6 +393,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; + std::unordered_map>> weights_; // User provided weights. 
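// A minimal sketch (not taken from this patch) of how the per-node weights kept in weights_ are expected to
// reach TensorRT through TensorrtFuncState::userWeights. The helper name FeedUserWeightsToRefitter is
// hypothetical, and it assumes the TensorRT ONNX parser headers already used by this provider; the
// loadInitializer call mirrors the parser/refitter usage elsewhere in this change and requires TensorRT >= 10.13.
static void FeedUserWeightsToRefitter(const TensorrtFuncState& state,
                                      nvonnxparser::IParserRefitter& parser_refitter) {
  if (state.userWeights != nullptr && *state.userWeights != nullptr) {
    for (const TensorrtUserWeights& weights : **state.userWeights) {
      // name/data/size were captured from each initializer's raw data when the ModelProto was
      // serialized without initializer data.
      parser_refitter.loadInitializer(weights.name.c_str(),
                                      static_cast<const void*>(weights.data.data()),
                                      static_cast<size_t>(weights.size));
    }
  }
}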
// for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture cudnnHandle_t external_cudnn_handle_ = nullptr; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc index 1a515c37f7ecb..17457b11f8cef 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc @@ -57,8 +57,11 @@ constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model"; constexpr const char* kEngineHwCompatible = "trt_engine_hw_compatible"; constexpr const char* kONNXBytestream = "trt_onnx_bytestream"; constexpr const char* kONNXBytestreamSize = "trt_onnx_bytestream_size"; +constexpr const char* kExternalDataBytestream = "trt_external_data_bytestream"; +constexpr const char* kExternalDataBytestreamSize = "trt_external_data_bytestream_size"; constexpr const char* kOpTypesToExclude = "trt_op_types_to_exclude"; constexpr const char* kPreviewFeatures = "trt_preview_features"; +constexpr const char* kGraphIncludeInitializer = "trt_load_user_initializer"; } // namespace provider_option_names } // namespace tensorrt @@ -67,6 +70,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions TensorrtExecutionProviderInfo info{}; void* user_compute_stream = nullptr; void* onnx_bytestream = nullptr; + void* external_data_bytestream = nullptr; ORT_THROW_IF_ERROR( ProviderOptionsParser{} .AddValueParser( @@ -138,13 +142,24 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions return Status::OK(); }) .AddAssignmentToReference(tensorrt::provider_option_names::kONNXBytestreamSize, info.onnx_bytestream_size) + .AddValueParser( + tensorrt::provider_option_names::kExternalDataBytestream, + [&external_data_bytestream](const std::string& value_str) -> Status { + size_t address; + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, address)); + external_data_bytestream = reinterpret_cast(address); + return Status::OK(); + }) + .AddAssignmentToReference(tensorrt::provider_option_names::kExternalDataBytestreamSize, info.external_data_bytestream_size) .AddAssignmentToReference(tensorrt::provider_option_names::kOpTypesToExclude, info.op_types_to_exclude) .AddAssignmentToReference(tensorrt::provider_option_names::kPreviewFeatures, info.preview_features) + .AddAssignmentToReference(tensorrt::provider_option_names::kGraphIncludeInitializer, info.load_user_initializer) .Parse(options)); // add new provider option here. 
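// A minimal usage sketch (not part of this change) showing how a caller could populate the new options parsed
// above. Pointer-typed options are passed as decimal address strings, matching the existing trt_onnx_bytestream
// handling, and trt_load_user_initializer is parsed as a boolean. MakeTrtExternalDataOptions is a hypothetical
// helper name; ProviderOptions is the string key/value map already used by this file, and <string>/<cstddef>
// are assumed to be available.
ProviderOptions MakeTrtExternalDataOptions(const void* external_data_bytes, size_t external_data_size) {
  ProviderOptions trt_options;
  trt_options["trt_external_data_bytestream"] = std::to_string(reinterpret_cast<size_t>(external_data_bytes));
  trt_options["trt_external_data_bytestream_size"] = std::to_string(external_data_size);
  trt_options["trt_load_user_initializer"] = "1";  // keep initializer data out of the serialized ModelProto
  return trt_options;
}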
info.user_compute_stream = user_compute_stream; info.has_user_compute_stream = (user_compute_stream != nullptr); info.onnx_bytestream = onnx_bytestream; + info.external_data_bytestream = external_data_bytestream; return info; } @@ -195,8 +210,11 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.engine_hw_compatible)}, {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(info.onnx_bytestream)}, {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.onnx_bytestream_size)}, + {tensorrt::provider_option_names::kExternalDataBytestream, MakeStringWithClassicLocale(info.external_data_bytestream)}, + {tensorrt::provider_option_names::kExternalDataBytestreamSize, MakeStringWithClassicLocale(info.external_data_bytestream_size)}, {tensorrt::provider_option_names::kOpTypesToExclude, MakeStringWithClassicLocale(info.op_types_to_exclude)}, {tensorrt::provider_option_names::kPreviewFeatures, MakeStringWithClassicLocale(info.preview_features)}, + {tensorrt::provider_option_names::kGraphIncludeInitializer, MakeStringWithClassicLocale(info.load_user_initializer)}, }; return options; } @@ -262,7 +280,10 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor {tensorrt::provider_option_names::kEngineHwCompatible, MakeStringWithClassicLocale(info.trt_engine_hw_compatible)}, {tensorrt::provider_option_names::kONNXBytestream, MakeStringWithClassicLocale(reinterpret_cast(info.trt_onnx_bytestream))}, {tensorrt::provider_option_names::kONNXBytestreamSize, MakeStringWithClassicLocale(info.trt_onnx_bytestream_size)}, + {tensorrt::provider_option_names::kExternalDataBytestream, MakeStringWithClassicLocale(reinterpret_cast(info.trt_external_data_bytestream))}, + {tensorrt::provider_option_names::kExternalDataBytestreamSize, MakeStringWithClassicLocale(info.trt_external_data_bytestream_size)}, {tensorrt::provider_option_names::kOpTypesToExclude, kOpTypesToExclude_}, + {tensorrt::provider_option_names::kGraphIncludeInitializer, MakeStringWithClassicLocale(info.trt_load_user_initializer)}, }; return options; } @@ -368,7 +389,10 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options trt_provider_options_v2.trt_engine_hw_compatible = internal_options.engine_hw_compatible; trt_provider_options_v2.trt_onnx_bytestream = internal_options.onnx_bytestream; trt_provider_options_v2.trt_onnx_bytestream_size = internal_options.onnx_bytestream_size; + trt_provider_options_v2.trt_external_data_bytestream = internal_options.external_data_bytestream; + trt_provider_options_v2.trt_external_data_bytestream_size = internal_options.external_data_bytestream_size; trt_provider_options_v2.trt_op_types_to_exclude = copy_string_if_needed(internal_options.op_types_to_exclude); trt_provider_options_v2.trt_preview_features = copy_string_if_needed(internal_options.preview_features); + trt_provider_options_v2.trt_load_user_initializer = internal_options.load_user_initializer; } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h index a7c3624674dc6..f0bd653de471a 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h @@ -37,6 +37,8 @@ struct TensorrtExecutionProviderInfo { std::string 
onnx_model_folder_path{""}; const void* onnx_bytestream{nullptr}; size_t onnx_bytestream_size{0}; + const void* external_data_bytestream{nullptr}; + size_t external_data_bytestream_size{0}; bool engine_decryption_enable{false}; std::string engine_decryption_lib_path{""}; bool force_sequential_engine_build{false}; @@ -63,6 +65,7 @@ struct TensorrtExecutionProviderInfo { bool engine_hw_compatible{false}; std::string op_types_to_exclude{""}; std::string preview_features{""}; + bool load_user_initializer{false}; static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index da1c2514bf6a2..71ea66b0be89f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -119,8 +119,11 @@ struct Tensorrt_Provider : Provider { info.engine_hw_compatible = options.trt_engine_hw_compatible != 0; info.onnx_bytestream = options.trt_onnx_bytestream; info.onnx_bytestream_size = options.trt_onnx_bytestream_size; + info.external_data_bytestream = options.trt_external_data_bytestream; + info.external_data_bytestream_size = options.trt_external_data_bytestream_size; info.op_types_to_exclude = options.trt_op_types_to_exclude == nullptr ? "" : options.trt_op_types_to_exclude; info.preview_features = options.trt_preview_features == nullptr ? "" : options.trt_preview_features; + info.load_user_initializer = options.trt_load_user_initializer != 0; return std::make_shared(info); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 60f115ca50da4..14f12c906f11a 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -42,6 +42,35 @@ using namespace onnxruntime; #define LIBRARY_EXTENSION ".so" #endif +/// @brief Gets the path of directory containing the dynamic library that contains the address. +/// @param address An address of a function or variable in the dynamic library. +/// @return The path of the directory containing the dynamic library, or an empty string if the path cannot be determined. 
+static onnxruntime::PathString GetDynamicLibraryLocationByAddress(const void* address) { +#ifdef _WIN32 + HMODULE moduleHandle; + if (!::GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(address), &moduleHandle)) { + return {}; + } + std::wstring buffer; + for (std::uint32_t size{70}; size < 4096; size *= 2) { + buffer.resize(size, L'\0'); + const std::uint32_t requiredSize = ::GetModuleFileNameW(moduleHandle, buffer.data(), size); + if (requiredSize == 0) { + break; + } + if (requiredSize == size) { + continue; + } + buffer.resize(requiredSize); + return {std::move(buffer)}; + } +#else + std::ignore = address; +#endif + return {}; +} + vaip_core::OrtApiForVaip* create_org_api_hook(); struct OrtVitisAIEpAPI { void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); @@ -74,8 +103,20 @@ struct OrtVitisAIEpAPI { // this dll is already linked to the executable, normally a test program handle_ = reinterpret_cast(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll"))); if (!handle_) { + // First try loading with full path + auto library_filename = PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); - ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + if (std::filesystem::exists(full_path)) { + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(full_path, true, &handle_)); + } else { + // Identify the path of the current dynamic library, and expect that onnxruntime_vitisai_ep is in the same directory. + PathString current_path = GetDynamicLibraryLocationByAddress(reinterpret_cast(create_org_api_hook)); + if (!current_path.empty()) { + const std::filesystem::path parent_path = std::filesystem::path{std::move(current_path)}.parent_path(); + PathString module_relative_full_path = PathString(parent_path / library_filename); + ORT_THROW_IF_ERROR(env.LoadDynamicLibrary(module_relative_full_path, true, &handle_)); + } + } } #else auto full_path = env.GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); diff --git a/onnxruntime/core/providers/vitisai/symbols.def b/onnxruntime/core/providers/vitisai/symbols.def index 4ec2f7914c208..3afed01da1966 100644 --- a/onnxruntime/core/providers/vitisai/symbols.def +++ b/onnxruntime/core/providers/vitisai/symbols.def @@ -1,2 +1,4 @@ EXPORTS GetProvider + CreateEpFactories + ReleaseEpFactory diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 6849bcfc21f88..1ef63588a1685 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -57,9 +57,6 @@ std::unique_ptr VitisAIProviderFactory::CreateProvider(const } } - // Store pointer to session options as done in SessionOptionsAppendExecutionProvider_VitisAI - provider_options["session_options"] = std::to_string((uintptr_t)(void*)&session_options); - auto ep_instance = std::make_unique(provider_options); ep_instance->SetLogger(reinterpret_cast(&session_logger)); return ep_instance; @@ -89,8 +86,101 @@ struct VitisAI_Provider : Provider { void Initialize() override { initialize_vitisai_ep(); } // Called right before unloading the shared library void Shutdown() override { deinitialize_vitisai_ep(); } + + Status 
CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata*/, + size_t /*num_devices*/, + ProviderOptions& provider_options, + const OrtSessionOptions& session_options, + const OrtLogger& logger, + std::unique_ptr& ep) override { + auto ep_factory = CreateExecutionProviderFactory(&provider_options); + ep = ep_factory->CreateProvider(session_options, logger); + return Status::OK(); + } } g_provider; +struct VitisAIEpFactory : OrtEpFactory { + VitisAIEpFactory(const OrtApi& ort_api_in) + : ort_api{ort_api_in} { + ort_version_supported = ORT_API_VERSION; + GetName = GetNameImpl; + GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; + CreateEp = CreateEpImpl; + ReleaseEp = ReleaseEpImpl; + } + + static const char* GetNameImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ep_name; + } + + static const char* GetVendorImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return vendor; + } + + static uint32_t GetVendorIdImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return hardware_vendor_id; + } + + static const char* ORT_API_CALL GetVersionImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return ORT_VERSION; + } + + static OrtStatus* GetSupportedDevicesImpl(OrtEpFactory* ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + VitisAIEpFactory* factory = static_cast(ep_factory); + + for (size_t i = 0; i < num_devices; ++i) { + const OrtHardwareDevice* hardware_device = devices[i]; + const std::uint32_t vendor_id = factory->ort_api.HardwareDevice_VendorId(hardware_device); + const OrtHardwareDeviceType device_type = factory->ort_api.HardwareDevice_Type(hardware_device); + + if ((vendor_id != VitisAIEpFactory::hardware_vendor_id) || + (device_type != OrtHardwareDeviceType_NPU)) { + continue; + } + + if (num_ep_devices == max_ep_devices) { + return factory->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "Not enough space to return EP devices."); + } + + auto status = factory->ort_api.GetEpApi()->CreateEpDevice(factory, hardware_device, nullptr, nullptr, + &ep_devices[num_ep_devices++]); + if (status != nullptr) { + return status; + } + } + return nullptr; + } + + static OrtStatus* CreateEpImpl(OrtEpFactory* /*this_ptr*/, + _In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) noexcept { + return CreateStatus(ORT_INVALID_ARGUMENT, "VitisAI EP factory does not support this method."); + } + + static void ReleaseEpImpl(OrtEpFactory*, OrtEp*) noexcept { + // no-op as we never create an EP here. 
+ } + + const OrtApi& ort_api; + static constexpr const char* const ep_name{kVitisAIExecutionProvider}; + static constexpr std::uint32_t hardware_vendor_id{0x1022}; + static constexpr const char* const vendor{"AMD"}; +}; + } // namespace onnxruntime extern "C" { @@ -98,4 +188,21 @@ extern "C" { ORT_API(onnxruntime::Provider*, GetProvider) { return &onnxruntime::g_provider; } + +OrtStatus* CreateEpFactories(const char* /*registration_name*/, const OrtApiBase* ort_api_base, + OrtEpFactory** factories, size_t max_factories, size_t* num_factories) { + const OrtApi* ort_api = ort_api_base->GetApi(ORT_API_VERSION); + if (max_factories < 1) { + return ort_api->CreateStatus(ORT_INVALID_ARGUMENT, + "Not enough space to return EP factory. Need at least one."); + } + factories[0] = std::make_unique(*ort_api).release(); + *num_factories = 1; + return nullptr; +} + +OrtStatus* ReleaseEpFactory(OrtEpFactory* factory) { + delete static_cast(factory); + return nullptr; +} } diff --git a/onnxruntime/core/providers/webgpu/shader_variable.cc b/onnxruntime/core/providers/webgpu/shader_variable.cc index 79175370529e0..c197e227e2a8c 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.cc +++ b/onnxruntime/core/providers/webgpu/shader_variable.cc @@ -268,8 +268,8 @@ void ShaderVariableHelper::Impl(std::ostream& ss) const { // Implementation of "fn get_{name}_by_indices" if (usage_ & ShaderUsage::UseGetByIndices) { if (rank_ >= 2) { - SS_APPEND(ss, "fn get_", name_, "_by_indices(indices: ", IndicesType(), ")->", ValueType(), " {\n"); - SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices)"), ";\n"); + SS_APPEND(ss, "fn get_", name_, "_by_indices(indices_fnarg: ", IndicesType(), ")->", ValueType(), " {\n"); + SS_APPEND(ss, " return ", GetByOffset("i2o_" + name_ + "(indices_fnarg)"), ";\n"); SS_APPEND(ss, "}\n"); } } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 50e361ede221e..d53812f8f06e7 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -84,7 +84,7 @@ inline std::string GetTensorName(const ConstPointerContainer -inline std::vector GetNarrowedIntfromInt64(gsl::span int64_vec) { +inline std::vector GetNarrowedIntFromInt64(gsl::span int64_vec) { std::vector vec; vec.reserve(int64_vec.size()); std::transform(int64_vec.begin(), int64_vec.end(), diff --git a/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h b/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h index 2bdb29b5a9cce..a0251406fc36b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h +++ b/onnxruntime/core/providers/webnn/builders/impl/attention_helper.h @@ -9,8 +9,8 @@ namespace webnn { /* ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape) - Abbreviatios: B is batch_size, S is query sequence_length, kv_S is key/value sequence length, - N is number of attention heads, H is head size, W is hidden_size + Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length, + N is number of attention heads, H is head size, W is hidden_size query key | | diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index ae52e2cd5d936..7d075ea81777f 100644 --- 
a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -49,7 +49,7 @@ class BaseOpBuilder : public IOpBuilder { // with opset version 7 or above for opset domain 'ai.onnx'. // WebNN EP ignores node support for opset less than 7 by // default as which will be fallback earlier by ONNX Runtime. - // We still set the mininal supported opset to 1 as we couldn't + // We still set the minimal supported opset to 1 as we couldn't // get the model opset version at this stage. virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 23; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index e0bfb3bd682e8..f75b6f41f7f9c 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -78,7 +78,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, if (output_padding.size() == 1 && is_conv1d) { output_padding.push_back(0); } - options.set("outputPadding", emscripten::val::array(GetNarrowedIntfromInt64(output_padding))); + options.set("outputPadding", emscripten::val::array(GetNarrowedIntFromInt64(output_padding))); // If output shape is explicitly provided, compute the pads. // Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER. @@ -87,7 +87,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, auto_pad_type, pads_out, output_shape, !is_nhwc)); if (output_shape[0] != -1 && output_shape[1] != -1) { - options.set("outputSizes", emscripten::val::array(GetNarrowedIntfromInt64(output_shape))); + options.set("outputSizes", emscripten::val::array(GetNarrowedIntFromInt64(output_shape))); } pads = pads_out; } else { @@ -97,13 +97,13 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const auto group = helper.Get("group", static_cast(1)); options.set("groups", group); - options.set("strides", emscripten::val::array(GetNarrowedIntfromInt64(strides))); - options.set("dilations", emscripten::val::array(GetNarrowedIntfromInt64(dilations))); + options.set("strides", emscripten::val::array(GetNarrowedIntFromInt64(strides))); + options.set("dilations", emscripten::val::array(GetNarrowedIntFromInt64(dilations))); // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. const std::vector padding{pads[0], pads[2], pads[1], pads[3]}; - options.set("padding", emscripten::val::array(GetNarrowedIntfromInt64(padding))); + options.set("padding", emscripten::val::array(GetNarrowedIntFromInt64(padding))); // Add bias if present. if (input_defs.size() > 2 && op_type != "ConvInteger") { @@ -123,7 +123,7 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, const auto& shape = tensor.dims(); std::vector dims = - GetNarrowedIntfromInt64(std::vector(std::begin(shape), std::end(shape))); + GetNarrowedIntFromInt64(std::vector(std::begin(shape), std::end(shape))); if (is_conv1d) { // Support conv1d by prepending a 1 size dimension. 
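The renamed GetNarrowedIntFromInt64 helper used throughout these WebNN builder changes narrows ONNX's int64 shapes and attributes down to the 32-bit integers WebNN expects. A minimal standalone sketch of such a checked narrowing helper is shown below; it is illustrative only and not the exact ORT implementation, which operates on gsl::span and uses ORT's own narrowing utilities.

    #include <cstdint>
    #include <limits>
    #include <stdexcept>
    #include <vector>

    // Narrow int64 values to int32, throwing if any value is out of range.
    static std::vector<int32_t> NarrowInt64ToInt32(const std::vector<int64_t>& values) {
      std::vector<int32_t> result;
      result.reserve(values.size());
      for (int64_t v : values) {
        if (v < std::numeric_limits<int32_t>::min() ||
            v > std::numeric_limits<int32_t>::max()) {
          throw std::out_of_range("int64 value does not fit into int32");
        }
        result.push_back(static_cast<int32_t>(v));
      }
      return result;
    }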
@@ -172,21 +172,21 @@ Status AddInitializerInNewLayout(ModelBuilder& model_builder, h * w_t + w; - uint32_t nnapi_idx; + uint32_t wnn_idx; if (is_conv == 1) { // L_0231 - nnapi_idx = out * h_t * w_t * in_t + - h * w_t * in_t + - w * in_t + - in; + wnn_idx = out * h_t * w_t * in_t + + h * w_t * in_t + + w * in_t + + in; } else { // L_1230 for depthwise conv weight - nnapi_idx = in * h_t * w_t * out_t + - h * w_t * out_t + - w * out_t + - out; + wnn_idx = in * h_t * w_t * out_t + + h * w_t * out_t + + w * out_t + + out; } for (size_t i = 0; i < element_size; i++) { - buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i]; + buffer[element_size * wnn_idx + i] = src[element_size * onnx_idx + i]; } } } @@ -234,7 +234,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N } else { input_shape.push_back(1); } - std::vector new_shape = GetNarrowedIntfromInt64(input_shape); + std::vector new_shape = GetNarrowedIntFromInt64(input_shape); common_options.set("label", node.Name() + "_reshape_input"); input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape), common_options); @@ -283,7 +283,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Reshape weight to 4D for conv1d. if (!is_nhwc || !is_constant_weight) { // The weight_shape has been appended 1's, reshape weight operand. - std::vector new_shape = GetNarrowedIntfromInt64(weight_shape); + std::vector new_shape = GetNarrowedIntFromInt64(weight_shape); common_options.set("label", node.Name() + "_reshape_filter"); filter = model_builder.GetBuilder().call("reshape", filter, @@ -338,7 +338,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector w_zero_point_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[3], w_zero_point_shape, logger), "Cannot get shape of w_zero_point"); w_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, - GetNarrowedIntfromInt64(w_zero_point_shape)); + GetNarrowedIntFromInt64(w_zero_point_shape)); } else { w_zero_point = model_builder.CreateOrGetConstant(x_type, 0); w_scale = x_scale; @@ -363,7 +363,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& output_defs = node.OutputDefs(); std::vector output_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[0], output_shape, logger), "Cannot get output shape"); - std::vector new_shape = GetNarrowedIntfromInt64(output_shape); + std::vector new_shape = GetNarrowedIntFromInt64(output_shape); common_options.set("label", node.Name() + "_reshape_output"); output = model_builder.GetBuilder().call("reshape", output, diff --git a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc index 37a00fcb12abd..d3879c76e3f9e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/dropout_op_builder.cc @@ -53,7 +53,7 @@ Status DropoutOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (output_defs.size() > 1) { std::vector mask_shape; ORT_RETURN_IF_NOT(GetShape(*output_defs[1], mask_shape, logger), "Cannot get mask output's shape"); - std::vector dims = GetNarrowedIntfromInt64(mask_shape); + std::vector dims = GetNarrowedIntFromInt64(mask_shape); emscripten::val one_constant = model_builder.CreateOrGetConstant( ONNX_NAMESPACE::TensorProto_DataType_BOOL, 1, dims); diff --git 
a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index e9c03420de445..202f1ee6db746 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -56,7 +56,7 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(output_shape)); + emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntFromInt64(output_shape)); emscripten::val output = model_builder.GetBuilder().call("expand", input, output_shape_arr, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 7af17fdc5db78..6b1ad638a9d1a 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -55,14 +55,14 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // If the input A is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. if (a_shape.size() == 1) { a_shape.insert(a_shape.begin(), 1); - emscripten::val a_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(a_shape)); + emscripten::val a_shape_arr = emscripten::val::array(GetNarrowedIntFromInt64(a_shape)); common_options.set("label", node.Name() + "_reshape_a"); a = model_builder.GetBuilder().call("reshape", a, a_shape_arr, common_options); } // If the input B is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. if (b_shape.size() == 1) { b_shape.push_back(1); - emscripten::val b_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(b_shape)); + emscripten::val b_shape_arr = emscripten::val::array(GetNarrowedIntFromInt64(b_shape)); common_options.set("label", node.Name() + "_reshape_b"); b = model_builder.GetBuilder().call("reshape", b, b_shape_arr, common_options); } @@ -74,7 +74,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // If A or B input is 1-D, we need to reshape the output back to its original shape. if (a_shape.size() == 1 || b_shape.size() == 1) { common_options.set("label", node.Name() + "_reshape_output"); - emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(output_shape)); + emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntFromInt64(output_shape)); output = model_builder.GetBuilder().call("reshape", output, output_shape_arr, @@ -95,7 +95,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // The scale input should have the same shape as the zero point input. a_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, - GetNarrowedIntfromInt64(a_zero_point_shape)); + GetNarrowedIntFromInt64(a_zero_point_shape)); } else { // If a_zero_point is not provided, create default scalar for zero_point and scale inputs. 
a_zero_point = model_builder.CreateOrGetConstant(a_type, 0); @@ -115,7 +115,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N ORT_RETURN_IF_NOT(GetShape(*input_defs[3], b_zero_point_shape, logger), "Cannot get shape of b_zero_point"); b_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f, - GetNarrowedIntfromInt64(b_zero_point_shape)); + GetNarrowedIntFromInt64(b_zero_point_shape)); } else { b_zero_point = model_builder.CreateOrGetConstant(a_type, 0); b_scale = model_builder.CreateOrGetConstant(ONNX_NAMESPACE::TensorProto_DataType_FLOAT, 1.0f); @@ -143,7 +143,7 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // If A or B input is 1-D, we need to reshape the output back to its original shape. if (a_shape.size() == 1 || b_shape.size() == 1) { common_options.set("label", node.Name() + "_reshape_output"); - emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntfromInt64(output_shape)); + emscripten::val output_shape_arr = emscripten::val::array(GetNarrowedIntFromInt64(output_shape)); output = model_builder.GetBuilder().call("reshape", output, output_shape_arr, diff --git a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc index 75ce80462544e..0b927075402fe 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc @@ -62,9 +62,9 @@ std::vector repeat_sequence(int32_t sequence_length, int32_t kv_num_hea } /** GroupQueryAttention SubGraph. - Abbreviatios: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length - N is number of attention heads, kv_N is number of attention heads for kv, H is head size - G is group size, and G=N/kv_N, W=N*H, h=Sqrt(H). + Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length + N is number of attention heads, kv_N is number of attention heads for kv, H is head size + G is group size, and G=N/kv_N, W=N*H, h=Sqrt(H). GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length Notes: cos_cache, sin_cache inputs are not supported. If the data type of the inputs (qkv and past kv) is float16, we cast them to float32 to ensure data precision. diff --git a/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc index 0e26339511d27..d435750221e70 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/mha_op_builder.cc @@ -31,8 +31,8 @@ class MultiHeadAttentionOpBuilder : public BaseOpBuilder { }; /** MultiHeadAttention SubGraph. - Abbreviatios: B is batch_size, S is sequence_length, W is hidden_size - N is number of attention heads, H is head size + Abbreviations: B is batch_size, S is sequence_length, W is hidden_size + N is number of attention heads, H is head size Notes: If the datatype of the inputs (qkv and past kv) is float16, we cast them to float32 to ensure data precision. 
query key value diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 9fb643f055ef3..2851590a48620 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -234,7 +234,7 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder output = model_builder.GetBuilder().call("instanceNormalization", input, options); // Reshape back to the original output shape for 3D input. if (input_shape.size() != 4) { - std::vector output_shape = GetNarrowedIntfromInt64(input_shape); + std::vector output_shape = GetNarrowedIntFromInt64(input_shape); emscripten::val reshape_output_options = emscripten::val::object(); reshape_output_options.set("label", node.Name() + "reshape_output"); output = model_builder.GetBuilder().call("reshape", diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 5d921c5176a64..7bcf56b380eb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -77,7 +77,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Add Padding. - // Usually using autopadding is more efficient than using explicit padding. + // Usually using auto padding is more efficient than using explicit padding. // Try to see if we can map explicit padding to auto padding. const auto onnx_strides = helper.Get("strides", std::vector{1, 1}); const auto onnx_pads = helper.Get("pads", std::vector{0, 0, 0, 0}); @@ -94,7 +94,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, auto_pad_type, pads_out, !is_nhwc)); - pads = GetNarrowedIntfromInt64(pads_out); + pads = GetNarrowedIntFromInt64(pads_out); } // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], // while WebNN's padding is [beginning_height, ending_height, beginning_width, ending_width]. diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index 053c41773db40..d6fd0b1ebac3e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -99,7 +99,7 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, if (!has_zero_point) { if (zero_point_shape.empty()) { // zero_point has the same shape as the scale tensor. - zero_point_shape = GetNarrowedIntfromInt64(scale_shape); + zero_point_shape = GetNarrowedIntFromInt64(scale_shape); } // Create a zero constant with the same shape as the scale tensor. 
// The zero value has been pre-processed in the CreateOrGetConstant function, diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index ca5fb5150aa5b..b0a01b0ed63aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -223,7 +223,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, ORT_RETURN_IF_NOT(GetResizeSizesAndAxes(model_builder.GetGraphViewer(), node, sizes, axes, is_nhwc, input_shape, logger), "Error getting Resize sizes"); - webnn_sizes = GetNarrowedIntfromInt64(sizes); + webnn_sizes = GetNarrowedIntFromInt64(sizes); options.set("sizes", emscripten::val::array(webnn_sizes)); } else { ORT_RETURN_IF_NOT(GetResizeScalesAndAxes(model_builder.GetGraphViewer(), node, scales, axes, is_nhwc, logger), @@ -231,7 +231,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("scales", emscripten::val::array(scales)); } - std::vector webnn_axes = GetNarrowedIntfromInt64(axes); + std::vector webnn_axes = GetNarrowedIntFromInt64(axes); options.set("axes", emscripten::val::array(webnn_axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc index 37071b1030e11..4395c2854dcfb 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/rotaryEmbedding_op_builder.cc @@ -334,13 +334,13 @@ Status RotaryEmbeddingOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_build if (input_is_4d) { // The output is in 4D shape, we need to transpose it back to the original shape. // Reuse the transpose_options' permutation because the original permutation also - // happens to be its own inverse. (inserve({0, 2, 1, 3} == {0, 2, 1, 3}) + // happens to be its own inverse. (inverse({0, 2, 1, 3} == {0, 2, 1, 3}) transpose_options.set("label", node_name + "_transpose_output"); output = wnn_builder.call("transpose", output, transpose_options); } else { // The output is in 3D shape, we need to reshape it back to the original shape. // The output shape is same as the input shape. 
- const std::vector output_shape = GetNarrowedIntfromInt64(input_shape); + const std::vector output_shape = GetNarrowedIntFromInt64(input_shape); emscripten::val reshape_output_options = emscripten::val::object(); reshape_output_options.set("label", node_name + "_reshape_output"); output = wnn_builder.call( diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index 5efbfe932c602..253a791158bdc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -114,8 +114,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const emscripten::val output = reverse_output; if (is_slice_required) { - std::vector starts = GetNarrowedIntfromInt64(compute_metadata.starts_); - std::vector steps = GetNarrowedIntfromInt64(compute_metadata.steps_); + std::vector starts = GetNarrowedIntFromInt64(compute_metadata.starts_); + std::vector steps = GetNarrowedIntFromInt64(compute_metadata.steps_); std::vector sizes(rank); std::transform(compute_metadata.ends_.cbegin(), compute_metadata.ends_.cend(), compute_metadata.starts_.cbegin(), sizes.begin(), [](int64_t i, int64_t j) { return SafeInt(i - j); }); diff --git a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc index 7e34e35ebac16..4a1e58feb10a2 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/squeeze_unsqueeze_op_builder.cc @@ -85,7 +85,7 @@ Status SqueezeUnsqueezeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_buil emscripten::val output = emscripten::val::undefined(); // Use WebNN's reshape to implement Squeeze/Unsqueeze. - std::vector new_shape = GetNarrowedIntfromInt64(input_shape); + std::vector new_shape = GetNarrowedIntFromInt64(input_shape); // Sort axes_data in ascending order. 
std::sort(axes_data.begin(), axes_data.end()); if (op_type == "Squeeze") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 452071f469c4f..0d0021d09a077 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -40,7 +40,7 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); - std::vector permutation = GetNarrowedIntfromInt64(perm); + std::vector permutation = GetNarrowedIntFromInt64(perm); options.set("permutation", emscripten::val::array(permutation)); emscripten::val output = model_builder.GetBuilder().call("transpose", input, options); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 372f9b2fd273a..472da627e4272 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -184,7 +184,7 @@ Status ModelBuilder::RegisterConstant(const onnx::TensorProto& tensor, emscripte std::vector int32_data; if (should_convert_int64_to_int32) { try { - int32_data = GetNarrowedIntfromInt64( + int32_data = GetNarrowedIntFromInt64( gsl::span(reinterpret_cast(tensor_ptr), num_elements)); LOGS(logger, VERBOSE) << "Initializer '" << tensor.name() << "' is converted from int64 to int32."; } catch (const std::exception& e) { diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index d2cd0639affd0..a9e88d355482b 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -120,7 +120,7 @@ class ModelBuilder { // Create or retrieve one of the following: // - A WebNN constant MLOperand filled with the specified value, data type, and shape. // - A WebNN scalar constant MLOperand with the specified value and data type. -// For scalar constant, it is workaround for builer.constant(type, value) method since +// For scalar constant, it is workaround for builder.constant(type, value) method since // it has not been implemented now. // https://webmachinelearning.github.io/webnn/#api-mlgraphbuilder-constant-type-value // diff --git a/onnxruntime/core/session/abi_devices.h b/onnxruntime/core/session/abi_devices.h index 67253a83ab490..50469126996b2 100644 --- a/onnxruntime/core/session/abi_devices.h +++ b/onnxruntime/core/session/abi_devices.h @@ -64,7 +64,11 @@ struct OrtEpDevice { OrtKeyValuePairs ep_metadata; OrtKeyValuePairs ep_options; - OrtEpFactory* ep_factory; + mutable OrtEpFactory* ep_factory; const OrtMemoryInfo* device_memory_info{nullptr}; const OrtMemoryInfo* host_accessible_memory_info{nullptr}; + + // the user provides const OrtEpDevice instances, but the OrtEpFactory API takes non-const instances for all + // get/create methods to be as flexible as possible. this helper converts to a non-const factory instance. 
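The mutable ep_factory member together with the GetMutableFactory() accessor shown just below lets code that only holds a const OrtEpDevice still call the OrtEpFactory entry points, which take a non-const factory pointer. A usage-style sketch of that pattern, relying on the internal OrtEpDevice/OrtEpFactory definitions from this patch (the wrapper function name is hypothetical):

    // Illustrative only; assumes the internal OrtEpDevice/OrtEpFactory types from this patch.
    void CreateAllocatorForDevice(const OrtEpDevice& device,
                                  const OrtMemoryInfo* memory_info,
                                  OrtAllocator** allocator) {
      OrtEpFactory* factory = device.GetMutableFactory();
      // OrtEpFactory methods are C-style function pointers taking the factory as the first argument.
      factory->CreateAllocator(factory, memory_info, /*allocator_options*/ nullptr, allocator);
    }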
+ OrtEpFactory* GetMutableFactory() const { return ep_factory; } }; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index cefcee8f408d7..c5fc4e7ccf76f 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -436,17 +436,14 @@ ORT_API_STATUS_IMPL(OrtApis::ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ Ort } case OrtOpAttrType::ORT_OP_ATTR_STRING: { const auto& s = attr->s(); - if (len < s.size() + 1) { + if (len < s.size()) { ret = OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Size of data not large enough to hold the string."); } else { char* output_c = reinterpret_cast(data); - for (char c : s) { - *output_c++ = c; - } - *output_c = '\0'; + memcpy(output_c, s.data(), s.size()); } - *out = s.size() + 1; + *out = s.size(); break; } case OrtOpAttrType::ORT_OP_ATTR_STRINGS: { diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 8acf1df06b46d..493c0a106074c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -9,6 +9,7 @@ #include "core/framework/allocator.h" #include "core/framework/allocator_utils.h" #include "core/framework/error_code_helper.h" +#include "core/framework/plugin_data_transfer.h" #include "core/graph/constants.h" #include "core/graph/op.h" #include "core/platform/device_discovery.h" @@ -163,8 +164,8 @@ Status Environment::UnregisterAllocator(const OrtMemoryInfo& mem_info) { Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found) { auto it = FindExistingAllocator(shared_allocators_, mem_info); - - if (error_if_not_found && it == shared_allocators_.end()) { + const bool found_shared_allocator = it != shared_allocators_.end(); + if (!found_shared_allocator && error_if_not_found) { return Status(ONNXRUNTIME, INVALID_ARGUMENT, "No allocator for this device has been registered for sharing."); } @@ -174,12 +175,18 @@ Status Environment::UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool // so when we remove that from shared_allocators_ we release the OrtAllocator instance. 
// shared_ort_allocators_ are internal only so never an error if there's no match - auto it2 = FindExistingAllocator(shared_ort_allocators_, mem_info); - if (it2 != shared_ort_allocators_.end()) { + if (auto it2 = FindExistingAllocator(shared_ort_allocators_, mem_info); it2 != shared_ort_allocators_.end()) { shared_ort_allocators_.erase(it2); } - shared_allocators_.erase(it); + // also remove an arena wrapped allocator from an EP if the user called CreateSharedAllocator to create one + if (auto it3 = arena_ort_allocators_.find(&mem_info); it3 != arena_ort_allocators_.end()) { + arena_ort_allocators_.erase(it3); + } + + if (found_shared_allocator) { + shared_allocators_.erase(it); + } return Status::OK(); } @@ -420,7 +427,10 @@ Status Environment::CreateAndRegisterAllocatorV2(const std::string& provider_typ provider_type + " is not implemented in CreateAndRegisterAllocatorV2()"}; } -Environment::~Environment() = default; +Environment::~Environment() { + // need to make sure all the OrtAllocator instances are released prior to any plugin EPs being freed + shared_allocators_.clear(); +} Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator) { std::lock_guard lock{mutex_}; @@ -444,6 +454,29 @@ Status Environment::GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocat } #if !defined(ORT_MINIMAL_BUILD) + +// +// Plugin EP support +// + +namespace { +Status CreateDataTransferForFactory(OrtEpFactory& ep_factory, + std::unique_ptr& data_transfer) { + OrtDataTransferImpl* data_transfer_impl = nullptr; + OrtStatus* status = ep_factory.CreateDataTransfer(&ep_factory, &data_transfer_impl); + if (status != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "Error creating data transfer: ", ToStatusAndRelease(status).ToString()); + } + + if (data_transfer_impl != nullptr) { + data_transfer = std::make_unique(*data_transfer_impl); + } + + return Status::OK(); +} +} // namespace + Status Environment::RegisterExecutionProviderLibrary(const std::string& registration_name, std::unique_ptr ep_library, const std::vector& internal_factories) { @@ -477,6 +510,16 @@ Status Environment::RegisterExecutionProviderLibrary(const std::string& registra } } + for (auto* factory : ep_info->factories) { + std::unique_ptr data_transfer; + ORT_RETURN_IF_ERROR(CreateDataTransferForFactory(*factory, data_transfer)); + + if (data_transfer) { + ep_info->data_transfers.push_back(data_transfer.get()); // store so we can unregister in the unload + ORT_RETURN_IF_ERROR(data_transfer_mgr_.RegisterDataTransfer(std::move(data_transfer))); + } + } + for (const auto& internal_factory : internal_factories) { internal_ep_factories_.insert(internal_factory); } @@ -534,6 +577,10 @@ Status Environment::UnregisterExecutionProviderLibrary(const std::string& ep_nam // something goes wrong in any of the following steps.. 
ep_libraries_.erase(ep_name); + for (auto* data_transfer : ep_info->data_transfers) { + ORT_RETURN_IF_ERROR(data_transfer_mgr_.UnregisterDataTransfer(data_transfer)); + } + for (auto* internal_factory : ep_info->internal_factories) { internal_ep_factories_.erase(internal_factory); } @@ -590,7 +637,22 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator_out, bool replace_existing) { - // if we're replacing an existing allocator we don't care who added it + // NOTE: memory_info is guaranteed to come from the OrtEpDevice when this is called + + // we need to remove from shared_ort_allocators_ first in case the entry in shared_allocators_ owns the pointer in + // shared_ort_allocators_. + if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); + it != shared_ort_allocators_.end()) { + shared_ort_allocators_.erase(it); + } + + // if a previous call created an arena wrapped allocator for the EP's memory_info we also need to remove that + if (auto it = arena_ort_allocators_.find(&memory_info); it != arena_ort_allocators_.end()) { + arena_ort_allocators_.erase(it); + } + + // we only want one shared allocator for an OrtDevice in the shared_allocators_ so that it's deterministic which + // one will be used for an inference session. ignore the name so that is the case. if (auto it = FindExistingAllocator(shared_allocators_, memory_info, /*match_name*/ false); it != shared_allocators_.end()) { if (!replace_existing) { @@ -600,12 +662,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, shared_allocators_.erase(it); } - // clear out any exact match in the internal shared allocators - if (auto it = FindExistingAllocator(shared_ort_allocators_, memory_info, /*match_name*/ true); - it != shared_ort_allocators_.end()) { - shared_ort_allocators_.erase(it); - } - OrtAllocator* allocator = nullptr; auto* ort_status = ep_device.ep_factory->CreateAllocator(ep_device.ep_factory, &memory_info, allocator_options, &allocator); @@ -613,10 +669,6 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, return ToStatusAndRelease(ort_status); } - if (allocator_out != nullptr) { - *allocator_out = allocator; - } - auto ort_allocator = OrtAllocatorUniquePtr(allocator, [&ep_device](OrtAllocator* allocator) { ep_device.ep_factory->ReleaseAllocator(ep_device.ep_factory, allocator); @@ -625,15 +677,13 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, AllocatorPtr shared_allocator; if (allocator_type == OrtArenaAllocator) { - // wrap with arena + // wrap with ORT arena OrtArenaCfg arena_cfg; if (allocator_options != nullptr) { auto status = OrtArenaCfg::FromKeyValuePairs(*allocator_options, arena_cfg); } - // pending Stream support being added to plugin EP API in separate PR - // ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); - bool stream_aware_arena = false; + bool stream_aware_arena = ep_device.ep_factory->IsStreamAware(ep_device.ep_factory); AllocatorCreationInfo alloc_creation_info{ [&ort_allocator](int) -> std::unique_ptr { @@ -646,13 +696,27 @@ Status Environment::CreateSharedAllocatorImpl(const OrtEpDevice& ep_device, }; shared_allocator = CreateAllocator(alloc_creation_info); + + // need an OrtAllocator to return to the user so we need yet another layer. 
+ // we pass in a copy of the AllocatorPtr (which is a shared_ptr) in order to maintain the overall condition that + // shared_allocators_ is the main owner of the allocator and the last place we delete from when removing + // from shared_ort_allocators_, arena_ort_allocators_ and shared_allocators_. + auto arena_ort_allocator = std::make_unique(AllocatorPtr(shared_allocator)); + allocator = arena_ort_allocator.get(); + + // store the entry using the EPs memory info for easier lookup when removing + arena_ort_allocators_.insert({&memory_info, std::move(arena_ort_allocator)}); } else { + shared_ort_allocators_.insert(allocator); shared_allocator = std::make_shared(std::move(ort_allocator)); } - shared_ort_allocators_.insert(allocator); shared_allocators_.push_back(std::move(shared_allocator)); + if (allocator_out != nullptr) { + *allocator_out = allocator; + } + return Status::OK(); } @@ -702,13 +766,13 @@ Status Environment::EpInfo::Create(std::unique_ptr library_in, std::u instance.internal_factories = internal_factories; ORT_RETURN_IF_ERROR(instance.library->Load()); - const auto& factories = instance.library->GetFactories(); + instance.factories = instance.library->GetFactories(); // OrtHardwareDevice instances to pass to GetSupportedDevices. sorted by type to be slightly more structured. // the set of hardware devices is static so this can also be static. const static std::vector sorted_devices = SortDevicesByType(); - for (auto* factory_ptr : factories) { + for (auto* factory_ptr : instance.factories) { ORT_ENFORCE(factory_ptr != nullptr, "Factory pointer was null. EpLibrary should prevent this. Library:", instance.library->RegistrationName()); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/ep_api.cc index ad965845041f7..6fdd47c537cb3 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/ep_api.cc @@ -12,6 +12,7 @@ #include "core/framework/ort_value.h" #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/tensor.h" #include "core/graph/ep_api_types.h" #include "core/session/abi_devices.h" diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/ep_api_utils.h index a0904c32011a7..77528565eced7 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/ep_api_utils.h @@ -45,6 +45,34 @@ struct ForwardToFactory { session_options, logger, ep); } + static OrtStatus* ORT_API_CALL CreateAllocator(_In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_ OrtAllocator** allocator) noexcept { + return static_cast(this_ptr)->CreateAllocator(memory_info, allocator_options, allocator); + } + + static void ORT_API_CALL ReleaseAllocator(_In_ OrtEpFactory* this_ptr, _In_ OrtAllocator* allocator) noexcept { + return static_cast(this_ptr)->ReleaseAllocator(allocator); + } + + static OrtStatus* ORT_API_CALL CreateDataTransfer(_In_ OrtEpFactory* this_ptr, + _Outptr_opt_ OrtDataTransferImpl** data_transfer) noexcept { + return static_cast(this_ptr)->CreateDataTransfer(data_transfer); + } + + static bool ORT_API_CALL IsStreamAware(_In_ const OrtEpFactory* this_ptr) noexcept { + return static_cast(this_ptr)->IsStreamAware(); + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDevice(_In_ OrtEpFactory* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_opt_ 
OrtSyncStreamImpl** stream) noexcept { + // ignore the OrtEp input as we won't ever have one for internal EPs + return static_cast(this_ptr)->CreateSyncStreamForDevice(memory_device, stream_options, stream); + } + static void ORT_API_CALL ReleaseEp(OrtEpFactory* this_ptr, OrtEp* ep) noexcept { static_cast(this_ptr)->ReleaseEp(ep); } diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/ep_factory_internal.cc index fa4ef2515ca92..9804aa6a5c42d 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/ep_factory_internal.cc @@ -14,14 +14,8 @@ namespace onnxruntime { using Forward = ForwardToFactory; -EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, - GetSupportedFunc&& get_supported_func, - CreateFunc&& create_func) - : ep_name_{ep_name}, - vendor_{vendor}, - vendor_id_{vendor_id}, - get_supported_func_{std::move(get_supported_func)}, - create_func_{create_func} { +EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) + : impl_{std::move(impl)} { ort_version_supported = ORT_API_VERSION; OrtEpFactory::GetName = Forward::GetFactoryName; @@ -31,20 +25,17 @@ EpFactoryInternal::EpFactoryInternal(const std::string& ep_name, const std::stri OrtEpFactory::GetSupportedDevices = Forward::GetSupportedDevices; OrtEpFactory::CreateEp = Forward::CreateEp; OrtEpFactory::ReleaseEp = Forward::ReleaseEp; + OrtEpFactory::CreateAllocator = Forward::CreateAllocator; + OrtEpFactory::ReleaseAllocator = Forward::ReleaseAllocator; + OrtEpFactory::CreateDataTransfer = Forward::CreateDataTransfer; + OrtEpFactory::IsStreamAware = Forward::IsStreamAware; + OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } const char* EpFactoryInternal::GetVersion() const noexcept { return ORT_VERSION; } -OrtStatus* EpFactoryInternal::GetSupportedDevices(const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept { - return get_supported_func_(this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); -} - OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, size_t /*num_devices*/, @@ -54,25 +45,23 @@ OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); } -OrtStatus* EpFactoryInternal::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "EpFactoryInternal currently only supports one device at a time."); +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. 
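The prefix stripping implemented just below can be read as follows, assuming GetProviderOptionPrefix produces a prefix of the form "ep.<EP name>." (the exact format is defined by OrtSessionOptions and not shown in this hunk):

    // Illustrative input/output for GetOptionsFromSessionOptions:
    //   session config entry:      "ep.SomeEp.device_id" = "1"
    //   returned ProviderOptions:  { "device_id": "1" }
    //   entries without the "ep.SomeEp." prefix are left out.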
+ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } } - return create_func_(this, devices, ep_metadata_pairs, num_devices, session_options, session_logger, ep); -} - -void EpFactoryInternal::ReleaseEp(OrtEp* /*ep*/) { - // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); + return ep_options; } InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/ep_factory_internal.h index ee08e2233c529..ae450efa394e8 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/ep_factory_internal.h @@ -10,43 +10,100 @@ #include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" namespace onnxruntime { +class EpFactoryInternal; class EpLibraryInternal; struct SessionOptions; -class EpFactoryInternal : public OrtEpFactory { +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { public: - // factory is non-const as a pointer to the factory is added to OrtEpDevice and needs to be non-const - // to provide flexibility to the CreateEp call to modify internal state. 
- using GetSupportedFunc = std::function; - - using CreateFunc = std::function* ep)>; - - EpFactoryInternal(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id, - GetSupportedFunc&& get_supported_func, - CreateFunc&& create_func); + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } const char* GetName() const noexcept { return ep_name_.c_str(); } const char* GetVendor() const noexcept { return vendor_.c_str(); } uint32_t GetVendorId() const noexcept { return vendor_id_; } const char* GetVersion() const noexcept; + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; + +// this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. 
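The comment above captures the split this refactor introduces: OrtEpFactory remains a plain C struct of function pointers, EpFactoryInternal derives from it without adding any virtual functions (so the object can be handed out as an OrtEpFactory* inside OrtEpDevice), and all per-EP polymorphism moves into EpFactoryInternalImpl behind a pimpl. A generic sketch of that shape, with illustrative names rather than ORT's:

    #include <memory>

    struct CApiFactory {                 // C ABI view: plain function pointers, no vtable
      const char* (*GetName)(const CApiFactory* self);
    };

    class FactoryImpl {                  // polymorphic behaviour lives here
     public:
      virtual ~FactoryImpl() = default;
      virtual const char* GetName() const = 0;
    };

    class FactoryWrapper : public CApiFactory {   // no virtuals: layout stays C-compatible
     public:
      explicit FactoryWrapper(std::unique_ptr<FactoryImpl> impl) : impl_{std::move(impl)} {
        GetName = [](const CApiFactory* self) {
          return static_cast<const FactoryWrapper*>(self)->impl_->GetName();
        };
      }

     private:
      std::unique_ptr<FactoryImpl> impl_;
    };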
+class EpFactoryInternal : public OrtEpFactory { + public: + EpFactoryInternal(std::unique_ptr impl); + + const char* GetName() const noexcept { return impl_->GetName(); } + const char* GetVendor() const noexcept { return impl_->GetVendor(); } + uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } + const char* GetVersion() const noexcept; + OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, _Inout_ OrtEpDevice** ep_devices, _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept; + _Out_ size_t* num_ep_devices) noexcept { + return impl_->GetSupportedDevices(*this, devices, num_devices, ep_devices, max_ep_devices, num_ep_devices); + } // we don't implement this. CreateIExecutionProvider should be used. OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -60,19 +117,43 @@ class EpFactoryInternal : public OrtEpFactory { _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, _In_ size_t num_devices, _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ std::unique_ptr* ep); + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) { + return impl_->CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, session_options, logger, ep); + } + + OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* memory_info, + _In_opt_ const OrtKeyValuePairs* allocator_options, + _Outptr_ OrtAllocator** allocator) noexcept { + return impl_->CreateAllocator(memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(_In_ OrtAllocator* allocator) noexcept { + return impl_->ReleaseAllocator(allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + return impl_->CreateDataTransfer(data_transfer); + } + + bool IsStreamAware() const { + return impl_->IsStreamAware(); + } + + OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); + } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* ep); + void ReleaseEp(OrtEp* /*ep*/) { + // we never create an OrtEp so we should never be trying to release one + ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); + } private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID - const GetSupportedFunc get_supported_func_; // function to return supported devices - const CreateFunc create_func_; // function to create the EP instance - - std::vector> eps_; // EP instances created by this factory + std::unique_ptr impl_; }; // IExecutionProviderFactory for EpFactoryInternal that is required for SessionOptionsAppendExecutionProvider_V2 diff --git a/onnxruntime/core/session/ep_library.cc b/onnxruntime/core/session/ep_library.cc deleted file mode 100644 index a5aa6e925311c..0000000000000 --- a/onnxruntime/core/session/ep_library.cc +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#include "core/session/ep_library.h" - -#include "core/framework/provider_options.h" -#include "core/session/abi_session_options_impl.h" - -namespace onnxruntime { -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. -ProviderOptions EpLibrary::GetOptionsFromSessionOptions(const std::string& ep_name, - const SessionOptions& session_options) { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(ep_name.c_str()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/ep_library.h index 648bec022580f..24ab74e1c77fc 100644 --- a/onnxruntime/core/session/ep_library.h +++ b/onnxruntime/core/session/ep_library.h @@ -26,9 +26,5 @@ class EpLibrary { virtual ~EpLibrary() = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(EpLibrary); - - protected: - static ProviderOptions GetOptionsFromSessionOptions(const std::string& ep_name, - const SessionOptions& session_options); }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc index ce5736f601b45..986ccb1fa17fc 100644 --- a/onnxruntime/core/session/ep_library_internal.cc +++ b/onnxruntime/core/session/ep_library_internal.cc @@ -4,6 +4,7 @@ #include "core/session/ep_library_internal.h" #include "core/framework/error_code_helper.h" +#include "core/framework/ortmemoryinfo.h" #include "core/framework/session_options.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/abi_devices.h" @@ -21,33 +22,38 @@ #endif namespace onnxruntime { -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - const auto get_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, &ep_devices[num_ep_devices++])); } } return nullptr; - }; - - const auto create_cpu_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> 
OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "CPU EP factory currently only supports one device at a time."); @@ -58,62 +64,93 @@ std::unique_ptr EpLibraryInternal::CreateCpuEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; + } +}; - std::string ep_name = kCpuExecutionProvider; - auto cpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - get_supported, create_cpu_ep); - return std::make_unique(std::move(cpu_factory)); +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #if defined(USE_DML) -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - static const std::string ep_name = kDmlExecutionProvider; - const auto is_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { std::unique_ptr ep_options; - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is associated with - // a specific device. + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. // How would we know what options should not allow user overrides if set in OrtEpDevice? 
+ int32_t device_id = 0; // If no device_id was found default to 0 if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { ep_options = std::make_unique(); - ep_options->Add("device_id", it->second.c_str()); + device_id = std::stoi(it->second); } - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices++]); + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. + // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); if (api_status != nullptr) { return api_status; } + + ++num_ep_devices; } } return nullptr; - }; - - const auto create_dml_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + *ep = nullptr; + if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "DML EP factory currently only supports one device at a time."); } - auto ep_options = GetOptionsFromSessionOptions(ep_name, session_options->value); + auto ep_options = GetOptionsFromSessionOptions(session_options->value); auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, ep_options); @@ -121,45 +158,73 @@ std::unique_ptr EpLibraryInternal::CreateDmlEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; + } + + OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept override { + // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That + // requires pulling lots of things out of the DML EP to get the D3D12 device and create a + // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp + //*allocator = device_allocators[memory_info->device.Id()].get(); + *allocator = nullptr; + return nullptr; + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. 
+ *data_transfer = nullptr; + return nullptr; + } - auto dml_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_dml_ep); + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; - return std::make_unique(std::move(dml_factory)); +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #endif #if defined(USE_WEBGPU) -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - static const std::string ep_name = kWebGpuExecutionProvider; - - const auto is_supported = [](OrtEpFactory* factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) -> OrtStatus* { +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override { size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { // TODO: any metadata or options to add? - ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(factory, &device, nullptr, nullptr, + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, &ep_devices[num_ep_devices++])); } } return nullptr; - }; - - const auto create_webgpu_ep = [](OrtEpFactory* /*factory*/, - const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) -> OrtStatus* { + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + *ep = nullptr; + if (num_devices != 1) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "WebGPU EP factory currently only supports one device at a time."); @@ -170,12 +235,28 @@ std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { (*ep)->SetLogger(session_logger->ToInternal()); return nullptr; - }; - - auto webgpu_factory = std::make_unique(ep_name, "Microsoft", OrtDevice::VendorIds::MICROSOFT, - is_supported, create_webgpu_ep); + } + + /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. 
+ OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; + } + */ +}; - return std::make_unique(std::move(webgpu_factory)); +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); } #endif diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc index 70937bdc5d3e8..ae553891beaa7 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ b/onnxruntime/core/session/ep_library_provider_bridge.cc @@ -13,6 +13,81 @@ #include "core/session/ep_factory_internal.h" namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. 
+ for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; + } + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); + } + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + Status EpLibraryProviderBridge::Load() { std::lock_guard lock{mutex_}; @@ -34,47 +109,9 @@ Status EpLibraryProviderBridge::Load() { // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. for (const auto& factory : ep_library_plugin_->GetFactories()) { - const auto is_supported_fn = [&factory](OrtEpFactory* ep_factory_internal, // from factory_ptrs_ - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) -> OrtStatus* { - ORT_API_RETURN_IF_ERROR(factory->GetSupportedDevices(factory, devices, num_devices, ep_devices, max_ep_devices, - num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. 
- for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = ep_factory_internal; - } - } + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); - return nullptr; - }; - - const auto create_fn = [this, &factory](OrtEpFactory* /*ep_factory_internal from factory_ptrs_*/, - const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* logger, std::unique_ptr* ep) { - // get the provider options - auto ep_options = GetOptionsFromSessionOptions(factory->GetName(factory), session_options->value); - auto& provider = provider_library_->Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *logger, *ep); - - return ToOrtStatus(status); - }; - - auto internal_factory = std::make_unique(factory->GetName(factory), - factory->GetVendor(factory), - factory->GetVendorId(factory), - is_supported_fn, - create_fn); factory_ptrs_.push_back(internal_factory.get()); internal_factory_ptrs_.push_back(internal_factory.get()); factories_.push_back(std::move(internal_factory)); diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/ep_library_provider_bridge.h index 3c7f083df227e..0717ccd957de7 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/ep_library_provider_bridge.h @@ -49,7 +49,7 @@ class EpLibraryProviderBridge : public EpLibrary { std::unique_ptr provider_library_; // provider bridge EP library // EpLibraryPlugin that provides the CreateEpFactories and ReleaseEpFactory implementations. - // we wrap the OrtEpFactory instances it contains to pass through GetSupportedDevices calls, and + // we wrap the OrtEpFactory instances it contains to pass through function calls, and // implement EpFactoryInternal::CreateIExecutionProvider by calling Provider::CreateIExecutionProvider. std::unique_ptr ep_library_plugin_; diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc index 878a5384dfee7..52cf6c62c9702 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.cc @@ -12,6 +12,7 @@ #include "core/framework/error_code_helper.h" #include "core/framework/model_metadef_id_generator.h" #include "core/framework/plugin_data_transfer.h" +#include "core/framework/plugin_ep_stream.h" #include "core/graph/ep_api_types.h" #include "core/graph/model_editor_api_types.h" #include "core/session/abi_devices.h" @@ -244,11 +245,12 @@ Status PluginExecutionProvider::FusedNodeState::AddFusedNode(const Node& fused_n /// Note that the EP plugin uses the model editor API to create the OrtNode instances. /// /// Name of the plugin EP. +/// fused nodes provided by ORT. /// EPContext nodes provided by the plugin EP. /// Output parameter set to the resulting array of EPContext nodes. /// Output parameter that stores the NodeArgs used by the EPContext nodes. /// A status indicating success or an error. 
-static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector plugin_ep_context_nodes, +static Status ConvertEpContextNodes(const std::string& ep_name, const std::vector& fused_nodes, const std::vector plugin_ep_context_nodes, /*out*/ std::vector>& result_nodes, /*out*/ std::vector>& result_node_args) { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -260,8 +262,10 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto std::vector> ep_context_node_args_holder; ep_context_nodes_holder.reserve(plugin_ep_context_nodes.size()); - + int index = -1; for (const OrtNode* ort_node : plugin_ep_context_nodes) { + ++index; + auto& fused_node_filtered_graph = fused_nodes[index].filtered_graph; ORT_RETURN_IF_NOT(ort_node != nullptr, ep_name, ": OrtEp::Compile() returned a NULL EPContext node."); const ModelEditorNode* editor_node = ModelEditorNode::ToInternal(ort_node); @@ -276,13 +280,17 @@ static Status ConvertEpContextNodes(const std::string& ep_name, const std::vecto output_node_args.reserve(editor_node->output_names.size()); for (const std::string& input_name : editor_node->input_names) { - auto node_arg = std::make_unique(input_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type. + auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(input_name); + const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr; + auto node_arg = std::make_unique(input_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available. input_node_args.push_back(node_arg.get()); ep_context_node_args_holder.push_back(std::move(node_arg)); } for (const std::string& output_name : editor_node->output_names) { - auto node_arg = std::make_unique(output_name, /*p_arg_type*/ nullptr); // Graph.Resolve() sets type. + auto node_arg_on_fused_graph = fused_node_filtered_graph.get().GetNodeArg(output_name); + const ONNX_NAMESPACE::TypeProto* p_arg_type = node_arg_on_fused_graph ? node_arg_on_fused_graph->TypeAsProto() : nullptr; + auto node_arg = std::make_unique(output_name, p_arg_type); // Graph.Resolve() cannot set type because EP Context OP does not have proper shape inference function available. output_node_args.push_back(node_arg.get()); ep_context_node_args_holder.push_back(std::move(node_arg)); } @@ -422,7 +430,7 @@ Status PluginExecutionProvider::Compile(const std::vector& fu // We store the converted Node and NodeArg instances as members to ensure they can be returned to the ORT graph // partitioner via a call to IExecutionProvider::GetEpContextNodes(). if (generate_ep_ctx_model_) { - ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), plugin_ep_context_nodes, + ORT_RETURN_IF_ERROR(ConvertEpContextNodes(Type(), fused_nodes_and_graphs, plugin_ep_context_nodes, /*out*/ ep_context_nodes_, /*out*/ ep_context_node_args_)); } @@ -556,7 +564,11 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { for (const auto* memory_info : allocator_mem_infos_) { OrtAllocator* ort_allocator_ptr = nullptr; - OrtStatus* ort_status = ep_factory_.CreateAllocator(&ep_factory_, memory_info, nullptr, &ort_allocator_ptr); + // prefer OrtEp function if available, otherwise fall back to using the OrtEpFactory implementation. + OrtStatus* ort_status = ort_ep_->CreateAllocator + ? 
ort_ep_->CreateAllocator(ort_ep_.get(), memory_info, &ort_allocator_ptr) + : ep_factory_.CreateAllocator(&ep_factory_, memory_info, /*options*/ nullptr, + &ort_allocator_ptr); // throw or log? start with throw if (ort_status != nullptr) { @@ -574,4 +586,39 @@ std::vector PluginExecutionProvider::CreatePreferredAllocators() { return allocators; } +void PluginExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry& registry, + AllocatorMap& /*allocators*/) const { + if (!ep_factory_.IsStreamAware(&ep_factory_)) { + return; + } + + for (const auto* mem_info : allocator_mem_infos_) { + if (mem_info->device.UsesCpuMemory()) { + // CPU memory does not need a stream + continue; + } + + auto device_type = mem_info->device.Type(); + + registry.RegisterCreateStreamFn( + device_type, + [mem_info, this](const OrtDevice& device) { + OrtSyncStreamImpl* stream = nullptr; + const OrtMemoryDevice* memory_device = static_cast(&mem_info->device); + + // prefer OrtEp function if available, otherwise fall back to using the OrtEpFactory implementation. + OrtStatus* status = ort_ep_->CreateSyncStreamForDevice + ? ort_ep_->CreateSyncStreamForDevice(ort_ep_.get(), memory_device, &stream) + : ep_factory_.CreateSyncStreamForDevice(&ep_factory_, memory_device, + /*stream_options*/ nullptr, &stream); + + ORT_ENFORCE(status == nullptr && stream != nullptr, + "Error creating sync stream for device: ", ToStatusAndRelease(status).ToString()); + return std::make_unique(device, *stream, *GetLogger()); + }); + + registry.RegisterWaitFn(device_type, device_type, plugin_ep::Notification::WaitNotificationOnDevice); + registry.RegisterWaitFn(device_type, OrtDevice::CPU, plugin_ep::Notification::WaitNotificationOnHost); + } +} } // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/ep_plugin_provider_interfaces.h index 3ba3118fcaa36..728f959ad67cb 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.h +++ b/onnxruntime/core/session/ep_plugin_provider_interfaces.h @@ -94,6 +94,8 @@ class PluginExecutionProvider : public IExecutionProvider { std::unique_ptr GetDataTransfer() const override; + void RegisterStreamHandlers(IStreamCommandHandleRegistry&, AllocatorMap&) const override; + // create per-session allocators // longer term we should prefer shared allocators in Environment and only create per-session allocators as // needed based on matching against allocator_mem_infos_. diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f147242da668f..25cabd256e318 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1349,7 +1349,8 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool ORT_RETURN_IF_ERROR_SESSIONID_(apply_transformer_once(ensure_unique_dq_for_node_unit, *session_logger_, graph)); } - // apply execution provider independent level 1 graph optimizations. + // apply execution provider independent level 0 and 1 graph optimizations. + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Default, *session_logger_)); ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.ApplyTransformers(graph, TransformerLevel::Level1, *session_logger_)); // if saving model to ORT format we only assign nodes a custom EP can handle and don't compile them. 
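A minimal caller-side sketch of what the new TransformerLevel::Default pass enables, assuming the existing public C API entries CreateSessionOptions, SetSessionGraphOptimizationLevel and AddFreeDimensionOverrideByName; the dimension name "batch" and the helper name are illustrative only, and error handling is omitted.

    // Sketch: with the Default-level pass, a free dimension override registered on the session
    // options is expected to be applied even when graph optimizations are disabled (level 0).
    #include <onnxruntime_c_api.h>

    OrtSessionOptions* MakeOptionsWithOverrideSketch() {
      const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
      OrtSessionOptions* so = nullptr;
      ort->CreateSessionOptions(&so);
      ort->SetSessionGraphOptimizationLevel(so, ORT_DISABLE_ALL);   // optimization level 0
      ort->AddFreeDimensionOverrideByName(so, "batch", 1);          // still honored via the Default-level pass
      return so;
    }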
@@ -3325,6 +3326,92 @@ std::pair InferenceSession::GetModelOutput return std::make_pair(common::Status::OK(), &model_->MainGraph().GetOutputs()); } +common::Status InferenceSession::GetInputOutputMemoryInfo(SessionInputOutputType type, + InlinedVector& memory_info) const { + memory_info.clear(); + + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair result; + switch (type) { + case SessionInputOutputType::kInput: + result = GetModelInputs(); + break; + case SessionInputOutputType::kOutput: + result = GetModelOutputs(); + break; + case SessionInputOutputType::kOverridableInitializer: + // add if/when needed + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "GetInputOutputMemoryInfo for kOverridableInitializer is not implemented."); + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Unexpected SessionInputOutputType of ", static_cast(type)); + } + + ORT_RETURN_IF_ERROR(result.first); + + const auto& def_list = *result.second; + memory_info.reserve(def_list.size()); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + if (type == SessionInputOutputType::kOutput) { + ORT_RETURN_IF_ERROR(session_state_->GetOutputNodeInfo(def->Name(), node_info_vec)); + } else { + ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + } + + // all entries are for the same OrtDevice so use the first one. + // we need to get an OrtMemoryInfo* that will remain valid, so we get the allocator for the OrtDevice + // from the session state and use its OrtMemoryInfo. + auto allocator = session_state_->GetAllocator(*node_info_vec.front().device); + memory_info.push_back(&allocator->Info()); + } + + return Status::OK(); +} + +common::Status InferenceSession::GetEpDeviceForInputs(InlinedVector& ep_devices) const { + ep_devices.clear(); + +#if defined(ORT_MINIMAL_BUILD) + return common::Status(common::ONNXRUNTIME, common::FAIL, + "GetEpDeviceForInputs is not available in a minimal build."); +#else + if (!is_inited_) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Session has not been initialized."); + } + + std::pair inputs = GetModelInputs(); + + ORT_RETURN_IF_ERROR(inputs.first); + + const auto& def_list = *inputs.second; + ep_devices.reserve(def_list.size()); + + const auto& available_eps = environment_.GetOrtEpDevices(); + + for (const auto* def : def_list) { + InlinedVector node_info_vec; + ORT_RETURN_IF_ERROR(session_state_->GetInputNodeInfo(def->Name(), node_info_vec)); + + // if we have a lot of inputs or there are a lot of execution providers it may be worth creating a map + // instead of doing a linear search each time. + const auto& ep_name = node_info_vec.front().p_node->GetExecutionProviderType(); + auto it = std::find_if(available_eps.begin(), available_eps.end(), [&ep_name](const OrtEpDevice* entry) { + return entry->ep_name == ep_name; + }); + + ep_devices.push_back(it != available_eps.end() ? 
*it : nullptr); + } + + return Status::OK(); +#endif +} + common::Status InferenceSession::NewIOBinding(std::unique_ptr* io_binding) { { std::lock_guard l(session_mutex_); @@ -3645,34 +3732,50 @@ common::Status InferenceSession::AddPredefinedTransformers( RecordRuntimeOptimizationProducedNodeOpSchemaFn record_runtime_optimization_produced_op_schema_fn, const logging::Logger& logger) const { const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); - for (int i = static_cast(TransformerLevel::Level1); i <= static_cast(TransformerLevel::MaxLevel); i++) { + for (int i = static_cast(TransformerLevel::Default); i <= static_cast(TransformerLevel::MaxLevel); i++) { TransformerLevel level = static_cast(i); - if (graph_optimization_level >= level) { - // Generate and register transformers for level - auto transformers_to_register = [&]() { - const bool use_full_build_optimizations = - level == TransformerLevel::Level1 || - minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; - - if (use_full_build_optimizations) { + std::function>()> transformers_to_register; + + // Enable free dimension override even when the graph optimization level is 0. + // If the optimization level is above 0, the override will be applied during level 1 optimization. + if (level == TransformerLevel::Default) { + if (graph_optimization_level == TransformerLevel::Default) { + transformers_to_register = [&]() { return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, optimizers_to_disable_, GetIntraOpThreadPoolToUse()); - } else { - const auto sat_context = - minimal_build_optimization_handling == - MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations - ? SatApplyContextVariant{SatRuntimeOptimizationSaveContext{ - record_runtime_optimization_produced_op_schema_fn}} - : SatApplyContextVariant{SatDirectApplicationContext{}}; - return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, - logger, - optimizers_to_disable_, - GetIntraOpThreadPoolToUse()); - } - }(); + }; + } + } else { + if (graph_optimization_level >= level) { + // Generate and register transformers for level + transformers_to_register = [&]() { + const bool use_full_build_optimizations = + level == TransformerLevel::Level1 || + minimal_build_optimization_handling == MinimalBuildOptimizationHandling::ApplyFullBuildOptimizations; + + if (use_full_build_optimizations) { + return optimizer_utils::GenerateTransformers(level, session_options_, cpu_ep, logger, + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); + } else { + const auto sat_context = + minimal_build_optimization_handling == + MinimalBuildOptimizationHandling::SaveMinimalBuildRuntimeOptimizations + ? 
SatApplyContextVariant{SatRuntimeOptimizationSaveContext{ + record_runtime_optimization_produced_op_schema_fn}} + : SatApplyContextVariant{SatDirectApplicationContext{}}; + return optimizer_utils::GenerateTransformersForMinimalBuild(level, session_options_, sat_context, cpu_ep, + logger, + optimizers_to_disable_, + GetIntraOpThreadPoolToUse()); + } + }; + } + } - for (auto& entry : transformers_to_register) { + if (transformers_to_register) { // Ensure the lambda is initialized before invoking it + for (auto& entry : transformers_to_register()) { ORT_RETURN_IF_ERROR(transformer_manager.Register(std::move(entry), level)); } } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 4e25187ff1b47..8bea15c169ed4 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -465,6 +465,25 @@ class InferenceSession { */ std::pair GetModelOutputs() const; + enum class SessionInputOutputType : uint8_t { + kInput = 0, + kOutput = 1, + kOverridableInitializer = 2 + }; + + /** + * Get the OrtMemoryInfo for the inputs or outputs of the model. + * + * This is required for a user to know the location of the input/output when autoep selection is enabled. + */ + common::Status GetInputOutputMemoryInfo(SessionInputOutputType type, + InlinedVector& memory_info) const; + /** + * Get the OrtEpDevice (if available) for the inputs of the model. + * + * This is required for a user to know the location of the input/output when autoep selection is enabled. + */ + common::Status GetEpDeviceForInputs(InlinedVector& memory_info) const; /** * Get the current number of in-progress concurrent Run calls. */ diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index db2a62c77d1bc..64b93ee0f592a 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -22,6 +22,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/onnxruntime_typeinfo.h" #include "core/framework/ort_value.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/tensor.h" #include "core/framework/tensor_type_and_shape.h" #include "core/framework/tensorprotoutils.h" @@ -32,6 +33,7 @@ #include "core/graph/model_editor_api_types.h" #include "core/graph/ep_api_types.h" #include "core/providers/get_execution_providers.h" +#include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" @@ -298,7 +300,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateTensorAsOrtValue, _Inout_ OrtAllocator* alloc API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, +ORT_API_STATUS_IMPL(OrtApis::CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, + _In_ const int64_t* dense_shape, size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out) { API_IMPL_BEGIN #if !defined(DISABLE_SPARSE_TENSORS) @@ -1312,9 +1315,17 @@ ORT_API_STATUS_IMPL(OrtApis::GetStringTensorElement, _In_ const OrtValue* value, using DefListResult = std::pair; using GetDefListFn = DefListResult (*)(const ::onnxruntime::InferenceSession*); -const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetModelInputs(); }; -const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) 
-> DefListResult { return session->GetModelOutputs(); }; -const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { return session->GetOverridableInitializers(); }; +const auto get_inputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetModelInputs(); +}; + +const auto get_outputs_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetModelOutputs(); +}; + +const auto get_overridable_initializers_fn = [](const ::onnxruntime::InferenceSession* session) -> DefListResult { + return session->GetOverridableInitializers(); +}; static ORT_STATUS_PTR GetNodeDefListCountHelper(const OrtSession* sess, GetDefListFn get_fn, size_t* out) { API_IMPL_BEGIN @@ -3187,6 +3198,148 @@ ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { return OrtExecutionProviderApi::GetEpApi(); } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_values) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_values) { + API_IMPL_BEGIN + if (ort_session == nullptr || inputs_ep_devices == nullptr || num_values == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid argument provided to SessionGetEpDeviceForInputs."); + } + + auto session = reinterpret_cast(ort_session); + + InlinedVector ep_devices; + + ORT_API_RETURN_IF_STATUS_NOT_OK(session->GetEpDeviceForInputs(ep_devices)); + + auto num_found = ep_devices.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of inputs ", num_found, " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + inputs_ep_devices[i] = (i < num_found) ? ep_devices[i] : nullptr; + } + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** ort_stream) { + API_IMPL_BEGIN + if (ep_device == nullptr || ort_stream == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device and stream must be provided."); + } + + const OrtDevice* device = ep_device->device_memory_info ? 
&ep_device->device_memory_info->device : nullptr; + + if (device == nullptr || device->MemType() != OrtDevice::MemType::DEFAULT) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "ep_device does not use DEFAULT memory of a non-CPU device."); + } + + const auto* factory = ep_device->ep_factory; + if (!factory->IsStreamAware(factory)) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "The execution provider does not support streams."); + } + + // get the stream implementation from the EP factory + OrtSyncStreamImpl* stream_impl = nullptr; + ORT_API_RETURN_IF_ERROR(factory->CreateSyncStreamForDevice(ep_device->GetMutableFactory(), + static_cast(device), // alias + stream_options, &stream_impl)); + + if (stream_impl == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Failed to get a stream implementation from the EP factory."); + } + + // create the wrapper class that uses the EP implementation + auto stream = std::make_unique(ep_device->device_memory_info->device, + *stream_impl, LoggingManager::DefaultLogger()); + + // cast to base type, and to API alias type + *ort_stream = static_cast(static_cast(stream.release())); + + return nullptr; + + API_IMPL_END +} + +ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* stream) { + return stream->GetHandle(); +} + +ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* ort_stream) { + // convert from API alias to internal type + auto* stream = static_cast(ort_stream); + + // the only way for the user to get a non-const OrtSyncStream is from CreateSyncStreamForEpDevice, + // so we can safely cast to the plugin_ep::Stream type. + std::unique_ptr ep_stream(reinterpret_cast(stream)); +} + +ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors) { + API_IMPL_BEGIN + if (env == nullptr || src_tensors == nullptr || dst_tensors == nullptr || num_tensors == 0) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Invalid arguments provided to CopyTensors."); + } + + const OrtMemoryInfo* src_memory_info = nullptr; + const OrtMemoryInfo* dst_memory_info = nullptr; + + const auto validate_and_get_mem_info = + [](const OrtValue* const* values, size_t num_values, const OrtMemoryInfo*& mem_info) -> OrtStatus* { + for (size_t i = 0; i < num_values; ++i) { + const OrtValue* value = values[i]; + if (value == nullptr || !value->IsTensor() || !value->IsAllocated()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "OrtValue must contain Tensor with data."); + } + + if (i == 0) { + mem_info = &value->Get().Location(); + } else if (*mem_info != value->Get().Location()) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "All OrtValue instances must have the same OrtMemoryInfo"); + } + } + + return nullptr; + }; + + ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(src_tensors, num_tensors, src_memory_info)); + ORT_API_RETURN_IF_ERROR(validate_and_get_mem_info(const_cast(dst_tensors), num_tensors, + dst_memory_info)); + + auto& data_transfer_mgr = env->GetEnvironment().GetDataTransferManager(); + const auto* data_transfer = data_transfer_mgr.GetDataTransfer(src_memory_info->device, dst_memory_info->device); + + if (data_transfer == nullptr) { + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "Data transfer implementation between source and destination device was not found."); + } + + std::vector pairs; + 
pairs.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; ++i) { + pairs.push_back({ + src_tensors[i]->Get(), + *dst_tensors[i]->GetMutable(), + stream, + }); + } + + ORT_API_RETURN_IF_STATUS_NOT_OK(data_transfer->CopyTensors(pairs)); + + return nullptr; + + API_IMPL_END +} + #else // defined(ORT_MINIMAL_BUILD) ORT_API_STATUS_IMPL(OrtApis::RegisterExecutionProviderLibrary, _In_ OrtEnv* /*env*/, const char* /*registration_name*/, const ORTCHAR_T* /*path*/) { @@ -3218,7 +3371,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS size_t /*num_ep_options*/) { API_IMPL_BEGIN return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); - API_IMPL_END } @@ -3227,6 +3379,41 @@ ORT_API(const OrtEpApi*, OrtApis::GetEpApi) { return nullptr; } +ORT_API_STATUS_IMPL(OrtApis::SessionGetEpDeviceForInputs, _In_ const OrtSession* /*ort_session*/, + _Out_writes_(num_values) const OrtEpDevice** /*inputs_ep_devices*/, + _In_ size_t /*num_values*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* /*ep_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_ OrtSyncStream** /*ort_stream*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + API_IMPL_END +} + +ORT_API(void*, OrtApis::SyncStream_GetHandle, _In_ OrtSyncStream* /*stream*/) { + fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n"); + return nullptr; +} + +ORT_API(void, OrtApis::ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* /*ort_stream*/) { + fprintf(stderr, "OrtSyncStream is not supported in a minimal build.\n"); +} + +ORT_API_STATUS_IMPL(OrtApis::CopyTensors, _In_ const OrtEnv* /*env*/, + _In_reads_(num_tensors) const OrtValue* const* /*src_tensors*/, + _In_reads_(num_tensors) OrtValue* const* /*dst_tensors*/, + _In_opt_ OrtSyncStream* /*stream*/, + _In_ size_t /*num_tensors*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "This API in not supported in a minimal build."); + API_IMPL_END +} + #endif // !defined(ORT_MINIMAL_BUILD) // OrtEpDevice accessors @@ -3270,8 +3457,61 @@ ORT_API(const OrtHardwareDevice*, OrtApis::EpDevice_Device, _In_ const OrtEpDevi return ep_device->device; } -ORT_API(const OrtMemoryInfo*, OrtApis::EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device) { - return ep_device->device_memory_info; +ORT_API(const OrtMemoryInfo*, OrtApis::EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType memory_type) { + switch (memory_type) { + case OrtDeviceMemoryType_DEFAULT: + return ep_device->device_memory_info; + case OrtDeviceMemoryType_HOST_ACCESSIBLE: + return ep_device->host_accessible_memory_info; + default: + return nullptr; + } +} + +namespace { +OrtStatus* GetInputOutputMemoryInfo(const OrtSession* ort_session, + InferenceSession::SessionInputOutputType type, + const OrtMemoryInfo** memory_info, + _In_ size_t num_values) { + auto session = reinterpret_cast(ort_session); + + InlinedVector mem_info; + ORT_API_RETURN_IF_STATUS_NOT_OK( + session->GetInputOutputMemoryInfo(InferenceSession::SessionInputOutputType::kInput, mem_info)); + + auto num_found = mem_info.size(); + if (num_found > num_values) { + auto msg = MakeString("Number of ", + type == InferenceSession::SessionInputOutputType::kOutput ? 
"outputs " : "inputs ", + mem_info.size(), " exceeds the provided size of ", num_values); + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, msg.c_str()); + } + + for (size_t i = 0; i < num_values; ++i) { + memory_info[i] = (i < num_found) ? mem_info[i] : nullptr; + } + + return nullptr; +} +} // namespace + +ORT_API_STATUS_IMPL(OrtApis::SessionGetMemoryInfoForInputs, _In_ const OrtSession* ort_session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs) { + API_IMPL_BEGIN + return GetInputOutputMemoryInfo(ort_session, InferenceSession::SessionInputOutputType::kInput, + inputs_memory_info, num_inputs); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs) { + API_IMPL_BEGIN + return GetInputOutputMemoryInfo(session, InferenceSession::SessionInputOutputType::kOutput, + outputs_memory_info, num_outputs); + API_IMPL_END } static constexpr OrtApiBase ort_api_base = { @@ -3708,6 +3948,8 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::AllocatorGetStats, &OrtApis::CreateMemoryInfo_V2, + &OrtApis::MemoryInfoGetDeviceMemType, + &OrtApis::MemoryInfoGetVendorId, &OrtApis::ValueInfo_GetValueProducer, &OrtApis::ValueInfo_GetValueNumConsumers, @@ -3764,6 +4006,16 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::GetTensorData, &OrtApis::GetSessionOptionsConfigEntries, + + &OrtApis::SessionGetMemoryInfoForInputs, + &OrtApis::SessionGetMemoryInfoForOutputs, + &OrtApis::SessionGetEpDeviceForInputs, + + &OrtApis::CreateSyncStreamForEpDevice, + &OrtApis::SyncStream_GetHandle, + &OrtApis::ReleaseSyncStream, + + &OrtApis::CopyTensors, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. 
diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ab927006c320..5f68c56cb044e 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -604,10 +604,13 @@ ORT_API_STATUS_IMPL(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ ORT_API_STATUS_IMPL(AllocatorGetStats, _In_ const OrtAllocator* ptr, _Outptr_ OrtKeyValuePairs** out); ORT_API_STATUS_IMPL(CreateMemoryInfo_V2, _In_ const char* name, _In_ enum OrtMemoryInfoDeviceType device_type, - _In_ uint32_t vendor_id, _In_ int16_t device_id, _In_ enum OrtDeviceMemoryType mem_type, + _In_ uint32_t vendor_id, _In_ int32_t device_id, _In_ enum OrtDeviceMemoryType mem_type, _In_ size_t alignment, enum OrtAllocatorType allocator_type, _Outptr_ OrtMemoryInfo** out); +ORT_API_STATUS_IMPL(MemoryInfoGetDeviceMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtDeviceMemoryType* out); +ORT_API_STATUS_IMPL(MemoryInfoGetVendorId, _In_ const OrtMemoryInfo* ptr, _Out_ uint32_t* out); + // OrtValueInfo ORT_API_STATUS_IMPL(ValueInfo_GetValueProducer, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtNode** producer_node, _Out_opt_ size_t* producer_output_index); @@ -685,7 +688,8 @@ ORT_API_STATUS_IMPL(Node_GetEpName, _In_ const OrtNode* node, _Outptr_result_may ORT_API_STATUS_IMPL(GetRunConfigEntry, _In_ const OrtRunOptions* options, _In_z_ const char* config_key, _Outptr_result_maybenull_z_ const char** config_value); -ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device); +ORT_API(const OrtMemoryInfo*, EpDevice_MemoryInfo, _In_ const OrtEpDevice* ep_device, + _In_ OrtDeviceMemoryType memory_type); ORT_API_STATUS_IMPL(CreateSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDevice* ep_device, _In_ OrtDeviceMemoryType mem_type, _In_ OrtAllocatorType allocator_type, @@ -700,4 +704,30 @@ ORT_API_STATUS_IMPL(ReleaseSharedAllocator, _In_ OrtEnv* env, _In_ const OrtEpDe ORT_API_STATUS_IMPL(GetTensorData, _In_ const OrtValue* value, _Outptr_ const void** out); ORT_API_STATUS_IMPL(GetSessionOptionsConfigEntries, _In_ const OrtSessionOptions* options, _Outptr_ OrtKeyValuePairs** out); + +ORT_API_STATUS_IMPL(SessionGetMemoryInfoForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtMemoryInfo** inputs_memory_info, + _In_ size_t num_inputs); + +ORT_API_STATUS_IMPL(SessionGetMemoryInfoForOutputs, _In_ const OrtSession* session, + _Out_writes_(num_outputs) const OrtMemoryInfo** outputs_memory_info, + _In_ size_t num_outputs); + +ORT_API_STATUS_IMPL(SessionGetEpDeviceForInputs, _In_ const OrtSession* session, + _Out_writes_(num_inputs) const OrtEpDevice** inputs_ep_devices, + _In_ size_t num_inputs); + +ORT_API_STATUS_IMPL(CreateSyncStreamForEpDevice, _In_ const OrtEpDevice* ep_device, + _In_opt_ const OrtKeyValuePairs* stream_options, + _Outptr_ OrtSyncStream** stream); + +ORT_API(void*, SyncStream_GetHandle, _In_ OrtSyncStream* stream); + +ORT_API(void, ReleaseSyncStream, _Frees_ptr_opt_ OrtSyncStream* stream); + +ORT_API_STATUS_IMPL(CopyTensors, _In_ const OrtEnv* env, + _In_reads_(num_tensors) const OrtValue* const* src_tensors, + _In_reads_(num_tensors) OrtValue* const* dst_tensors, + _In_opt_ OrtSyncStream* stream, + _In_ size_t num_tensors); } // namespace OrtApis diff --git a/onnxruntime/core/session/standalone_op_invoker.cc b/onnxruntime/core/session/standalone_op_invoker.cc index 2706448d831cc..4ac5ef5d89269 100644 --- a/onnxruntime/core/session/standalone_op_invoker.cc +++ 
b/onnxruntime/core/session/standalone_op_invoker.cc @@ -345,7 +345,7 @@ onnxruntime::Status CreateOpAttr(const char* name, const void* data, int len, Or attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_FLOATS); break; case OrtOpAttrType::ORT_OP_ATTR_STRING: - attr->set_s(std::string{str}); + attr->set_s(std::string{str, static_cast(len)}); attr->set_type(ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_STRING); break; case OrtOpAttrType::ORT_OP_ATTR_STRINGS: diff --git a/onnxruntime/test/autoep/library/ep.cc b/onnxruntime/test/autoep/library/ep.cc index 44b3f9a213abf..287eba05a0595 100644 --- a/onnxruntime/test/autoep/library/ep.cc +++ b/onnxruntime/test/autoep/library/ep.cc @@ -13,6 +13,7 @@ #include #include "ep_factory.h" +#include "ep_stream_support.h" /// /// Example implementation of ONNX Mul. Does not handle many things like broadcasting. @@ -163,6 +164,8 @@ ExampleEp::ExampleEp(ExampleEpFactory& factory, const std::string& name, const C GetCapability = GetCapabilityImpl; Compile = CompileImpl; ReleaseNodeComputeInfos = ReleaseNodeComputeInfosImpl; + CreateAllocator = CreateAllocatorImpl; // optional. can be nullptr + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // optional. can be nullptr auto status = ort_api.Logger_LogMessage(&logger_, OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, @@ -419,12 +422,16 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes std::array attributes = {}; DeferOrtRelease defer_release_attrs(attributes.data(), attributes.size(), ort_api.ReleaseOpAttr); - RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", "binary_data", 1, ORT_OP_ATTR_STRING, &attributes[0])); + std::string ep_ctx = "binary_data"; + RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_cache_context", ep_ctx.c_str(), static_cast(ep_ctx.length()), + ORT_OP_ATTR_STRING, &attributes[0])); RETURN_IF_ERROR(ort_api.CreateOpAttr("main_context", &is_main_context, 1, ORT_OP_ATTR_INT, &attributes[1])); RETURN_IF_ERROR(ort_api.CreateOpAttr("embed_mode", &embed_mode, 1, ORT_OP_ATTR_INT, &attributes[2])); RETURN_IF_ERROR(ort_api.CreateOpAttr("ep_sdk_version", "1", 1, ORT_OP_ATTR_STRING, &attributes[3])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, 1, ORT_OP_ATTR_STRING, &attributes[4])); - RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), 1, ORT_OP_ATTR_STRING, &attributes[5])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("partition_name", fused_node_name, static_cast(strlen(fused_node_name)), + ORT_OP_ATTR_STRING, &attributes[4])); + RETURN_IF_ERROR(ort_api.CreateOpAttr("source", this->name_.c_str(), static_cast(this->name_.length()), + ORT_OP_ATTR_STRING, &attributes[5])); RETURN_IF_ERROR(model_editor_api.CreateNode("EPContext", "com.microsoft", fused_node_name, input_names.data(), input_names.size(), @@ -435,6 +442,43 @@ OrtStatus* ExampleEp::CreateEpContextNodes(gsl::span fused_nodes return nullptr; } + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept { + // A per-session allocator could be created here. + // Logging of any issues should use ep->logger_ which is the session logger. + + ExampleEp* ep = static_cast(this_ptr); + + // for simplicity in this example we use the factory implementation. 
+ return ep->factory_.CreateAllocator(&ep->factory_, memory_info, nullptr, allocator); +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEp::CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept { + // A per-session OrtSyncStreamImpl can be created here if the session options affect the implementation. + // Logging of any issues should use logger_ which is the session logger. + + ExampleEp* ep = static_cast(this_ptr); + + // we only create streams for the default device memory. + if (auto mem_type = ep->factory_.ep_api.MemoryDevice_GetMemoryType(memory_device); + mem_type != OrtDeviceMemoryType_DEFAULT) { + std::string error = "Invalid OrtMemoryDevice. Expected OrtDeviceMemoryType_DEFAULT(0). Got "; + error += std::to_string(mem_type); + return ep->ort_api.CreateStatus(ORT_INVALID_ARGUMENT, error.c_str()); + } + + auto sync_stream = std::make_unique(ep->factory_, ep, nullptr); + *stream = sync_stream.release(); + + return nullptr; +} + // // Implementation of ExampleNodeComputeInfo // diff --git a/onnxruntime/test/autoep/library/ep.h b/onnxruntime/test/autoep/library/ep.h index dfebcc52a0caf..fa6eb24c5cc04 100644 --- a/onnxruntime/test/autoep/library/ep.h +++ b/onnxruntime/test/autoep/library/ep.h @@ -30,12 +30,23 @@ class ExampleEp : public OrtEp, public ApiPtrs { private: static const char* ORT_API_CALL GetNameImpl(const OrtEp* this_ptr) noexcept; - static OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, - OrtEpGraphSupportInfo* graph_support_info) noexcept; + + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryInfo* memory_info, + _Outptr_result_maybenull_ OrtAllocator** allocator) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(_In_ OrtEp* this_ptr, + _In_ const OrtMemoryDevice* memory_device, + _Outptr_ OrtSyncStreamImpl** stream) noexcept; + + static OrtStatus* ORT_API_CALL GetCapabilityImpl(OrtEp* this_ptr, const OrtGraph* graph, + OrtEpGraphSupportInfo* graph_support_info) noexcept; + static OrtStatus* ORT_API_CALL CompileImpl(_In_ OrtEp* this_ptr, _In_ const OrtGraph** graphs, _In_ const OrtNode** fused_nodes, _In_ size_t count, _Out_writes_all_(count) OrtNodeComputeInfo** node_compute_infos, _Out_writes_(count) OrtNode** ep_context_nodes) noexcept; + static void ORT_API_CALL ReleaseNodeComputeInfosImpl(OrtEp* this_ptr, OrtNodeComputeInfo** node_compute_infos, size_t num_node_compute_infos) noexcept; diff --git a/onnxruntime/test/autoep/library/ep_allocator.h b/onnxruntime/test/autoep/library/ep_allocator.h index 5e6f81fd0aa9e..624b4fcb484cd 100644 --- a/onnxruntime/test/autoep/library/ep_allocator.h +++ b/onnxruntime/test/autoep/library/ep_allocator.h @@ -5,16 +5,35 @@ #include "example_plugin_ep_utils.h" +// from onnxruntime/core/framework/allocator_stats.h +struct AllocatorStats { + int64_t num_allocs; // Number of allocations. + int64_t num_reserves; // Number of reserves. (Number of calls to Reserve() in arena-based allocators) + int64_t bytes_in_use; // Number of bytes in use. + int64_t total_allocated_bytes; // The total number of allocated bytes by the allocator. + int64_t max_bytes_in_use; // The maximum bytes in use. + int64_t max_alloc_size; // The max single allocation seen. + int64_t bytes_limit; // The upper limit what the allocator can allocate, if such a limit + // is known. Certain allocator may return 0 to indicate the limit is + // unknown. 
+}; + struct CustomAllocator : OrtAllocator { - CustomAllocator(const OrtMemoryInfo* mem_info) : memory_info{mem_info} { + CustomAllocator(const OrtMemoryInfo* mem_info, const ApiPtrs& api_ptrs_in) + : memory_info{mem_info}, api_ptrs{api_ptrs_in} { + version = ORT_API_VERSION; Alloc = AllocImpl; Free = FreeImpl; Info = InfoImpl; - Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = GetStatsImpl; // this can be set to nullptr if you don't want to implement it } - static void* ORT_API_CALL AllocImpl(struct OrtAllocator* /*this_*/, size_t size) { - // CustomAllocator& impl = *static_cast(this_); + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + CustomAllocator& impl = *static_cast(this_); + ++impl.stats.num_allocs; + impl.stats.max_alloc_size = std::max(size, impl.stats.max_alloc_size); + return malloc(size); } @@ -29,6 +48,24 @@ struct CustomAllocator : OrtAllocator { return impl.memory_info; } + static OrtStatus* ORT_API_CALL GetStatsImpl(const struct OrtAllocator* this_, OrtKeyValuePairs** out) noexcept { + const CustomAllocator& impl = *static_cast(this_); + + OrtKeyValuePairs* kvps; + impl.api_ptrs.ort_api.CreateKeyValuePairs(&kvps); + + // if you wish to return stats the values in GetStatus should be formatted like this: + // https://github.com/microsoft/onnxruntime/blob/2f878c60296de169a8a523e692d3d65893f7c133/onnxruntime/core/session/allocator_adapters.cc#L75-L85 + + impl.api_ptrs.ort_api.AddKeyValuePair(kvps, "NumAllocs", std::to_string(impl.stats.num_allocs).c_str()); + impl.api_ptrs.ort_api.AddKeyValuePair(kvps, "MaxAllocSize", std::to_string(impl.stats.max_alloc_size).c_str()); + + *out = kvps; + return nullptr; + } + private: const OrtMemoryInfo* memory_info; + const ApiPtrs api_ptrs; + AllocatorStats stats{}; }; diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.cc b/onnxruntime/test/autoep/library/ep_data_transfer.cc index 48f97fe88ec44..ca8d0cf089fc6 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.cc +++ b/onnxruntime/test/autoep/library/ep_data_transfer.cc @@ -7,12 +7,10 @@ #include /*static*/ -bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr, +bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept { - static constexpr uint32_t VendorId = 0xBE57; // Example vendor ID for demonstration purposes. - - auto& impl = *static_cast(this_ptr); + const auto& impl = *static_cast(this_ptr); bool src_is_our_device = impl.ep_api.MemoryDevice_AreEqual(src_memory_device, impl.device_mem_info); bool dst_is_our_device = impl.ep_api.MemoryDevice_AreEqual(dst_memory_device, impl.device_mem_info); @@ -24,30 +22,37 @@ bool ORT_API_CALL ExampleDataTransfer::CanCopyImpl(void* this_ptr, // and the vendor and device IDs as needed. 
OrtMemoryInfoDeviceType src_device_type = impl.ep_api.MemoryDevice_GetDeviceType(src_memory_device); OrtMemoryInfoDeviceType dst_device_type = impl.ep_api.MemoryDevice_GetDeviceType(dst_memory_device); - // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device); - // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device); - // uint32_t src_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(src_memory_device); - // uint32_t dst_device_vendor_id = impl.ep_api.MemoryDevice_GetVendorId(dst_memory_device); - // uint32_t src_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(src_memory_device); - // uint32_t dst_device_device_id = impl.ep_api.MemoryDevice_GetDeviceId(dst_memory_device); + OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_memory_device); + OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_memory_device); + // we can copy to/from CPU or CPU accessible memory if (src_is_our_device) { - // check device type and vendor to see if compatible - return (dst_device_type == OrtMemoryInfoDeviceType_CPU); + return (dst_device_type == OrtMemoryInfoDeviceType_CPU || dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); } if (dst_is_our_device) { - // check device type and vendor to see if compatible - return (src_device_type == OrtMemoryInfoDeviceType_CPU); + return (src_device_type == OrtMemoryInfoDeviceType_CPU || src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE); } return false; } +namespace { +void CopyImpl(const void* src_data, void* dst_data, size_t bytes, OrtSyncStream* stream) { + // in our example setup this is really CPU to CPU + + if (stream) { + // EP can do an async copy using the stream. e.g. an NVIDIA EP would provide the stream to cudaMemcpyAsync + } + + memcpy(dst_data, src_data, bytes); +} +} // namespace + // function to copy one or more tensors. // implementation can optionally use async copy if a stream is available for the input. /*static*/ -OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, +OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, @@ -56,11 +61,10 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, auto src_tensors = gsl::make_span(src_tensors_ptr, num_tensors); auto dst_tensors = gsl::make_span(dst_tensors_ptr, num_tensors); - auto streams = gsl::make_span(streams_ptr, num_tensors); for (size_t i = 0; i < num_tensors; ++i) { - // NOTE: Stream support will be a separate PR. ignore teh streams_ptr values for now - + // the implementation for a 'real' EP would be something along these lines. 
+ // See CudaDataTransferImpl in onnxruntime\core\providers\cuda\cuda_provider_factory.cc const OrtMemoryDevice* src_device = nullptr; const OrtMemoryDevice* dst_device = nullptr; RETURN_IF_ERROR(impl.ep_api.Value_GetMemoryDevice(src_tensors[i], &src_device)); @@ -71,13 +75,16 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, // OrtDeviceMemoryType src_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(src_device); // OrtDeviceMemoryType dst_mem_type = impl.ep_api.MemoryDevice_GetMemoryType(dst_device); - // bool copy_involves_pinned_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || - // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; + // bool copy_involves_host_accessible_memory = src_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE || + // dst_mem_type == OrtDeviceMemoryType_HOST_ACCESSIBLE; const void* src_data = nullptr; void* dst_data = nullptr; + size_t bytes; + RETURN_IF_ERROR(impl.ort_api.GetTensorData(src_tensors[i], &src_data)); RETURN_IF_ERROR(impl.ort_api.GetTensorMutableData(dst_tensors[i], &dst_data)); + RETURN_IF_ERROR(impl.ort_api.GetTensorSizeInBytes(src_tensors[i], &bytes)); if (dst_device_type == OrtMemoryInfoDeviceType_GPU) { if (src_device_type == OrtMemoryInfoDeviceType_GPU) { @@ -88,15 +95,18 @@ OrtStatus* ORT_API_CALL ExampleDataTransfer::CopyTensorsImpl(void* this_ptr, } else if (src_device_type == OrtMemoryInfoDeviceType_GPU) { // GPU -> CPU } else { - // CPU -> CPU involves copy to/from pinned memory and a synchronize may be required first + // CPU -> CPU. may involve copy a to/from host accessible memory and a synchronize may be required first } + + // but in our example EP it's simpler as it's really a (fake) CPU to CPU copy + CopyImpl(src_data, dst_data, bytes, streams_ptr ? streams_ptr[i] : nullptr); } return nullptr; } /*static*/ -void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(void* /*this_ptr*/) noexcept { +void ORT_API_CALL ExampleDataTransfer::ReleaseImpl(OrtDataTransferImpl* /*this_ptr*/) noexcept { // In our setup the factory owns a shared ExampleDataTransfer instance so it will do the cleanup, and we ignore // the call to Release from the plugin_ep::DataTransfer dtor (see /onnxruntime/core/framework/plugin_data_transfer.h) // diff --git a/onnxruntime/test/autoep/library/ep_data_transfer.h b/onnxruntime/test/autoep/library/ep_data_transfer.h index d73b9e457b844..da74d42b4affe 100644 --- a/onnxruntime/test/autoep/library/ep_data_transfer.h +++ b/onnxruntime/test/autoep/library/ep_data_transfer.h @@ -7,28 +7,26 @@ struct ExampleDataTransfer : OrtDataTransferImpl, ApiPtrs { ExampleDataTransfer(ApiPtrs api_ptrs, - const OrtMemoryDevice* device_mem_info_, - const OrtMemoryDevice* shared_mem_info_ = nullptr) - : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_}, shared_mem_info{shared_mem_info_} { + const OrtMemoryDevice* device_mem_info_) + : ApiPtrs(api_ptrs), device_mem_info{device_mem_info_} { CanCopy = CanCopyImpl; CopyTensors = CopyTensorsImpl; Release = ReleaseImpl; } - static bool ORT_API_CALL CanCopyImpl(void* this_ptr, + static bool ORT_API_CALL CanCopyImpl(const OrtDataTransferImpl* this_ptr, const OrtMemoryDevice* src_memory_device, const OrtMemoryDevice* dst_memory_device) noexcept; // function to copy one or more tensors. // implementation can optionally use async copy if a stream is available for the input. 
- static OrtStatus* ORT_API_CALL CopyTensorsImpl(void* this_ptr, + static OrtStatus* ORT_API_CALL CopyTensorsImpl(OrtDataTransferImpl* this_ptr, const OrtValue** src_tensors_ptr, OrtValue** dst_tensors_ptr, OrtSyncStream** streams_ptr, size_t num_tensors) noexcept; - static void ORT_API_CALL ReleaseImpl(void* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(OrtDataTransferImpl* this_ptr) noexcept; private: - const OrtMemoryDevice* device_mem_info; - const OrtMemoryDevice* shared_mem_info; + const OrtMemoryDevice* device_mem_info; // device our EP runs on }; diff --git a/onnxruntime/test/autoep/library/ep_factory.cc b/onnxruntime/test/autoep/library/ep_factory.cc index 19a44008b8c97..97cdea15187d9 100644 --- a/onnxruntime/test/autoep/library/ep_factory.cc +++ b/onnxruntime/test/autoep/library/ep_factory.cc @@ -8,6 +8,7 @@ #include "ep.h" #include "ep_allocator.h" #include "ep_data_transfer.h" +#include "ep_stream_support.h" ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) : ApiPtrs(apis), ep_name_{ep_name} { @@ -27,35 +28,27 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) CreateDataTransfer = CreateDataTransferImpl; - // for the sake of this example we specify a CPU allocator with no arena and 1K alignment (arbitrary) - // as well as GPU and GPU shared memory. the actual EP implementation would typically define two at most for a - // device (one for device memory and one for shared memory for data transfer between device and CPU) + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; // setup the OrtMemoryInfo instances required by the EP. + // We pretend the device the EP is running on is GPU. OrtMemoryInfo* mem_info = nullptr; - auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP CPU", OrtMemoryInfoDeviceType_CPU, + auto* status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0xBE57, /* device_id */ 0, OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 1024, - OrtAllocatorType::OrtDeviceAllocator, // no arena + /*alignment*/ 0, + OrtAllocatorType::OrtDeviceAllocator, &mem_info); assert(status == nullptr); // should never fail. + default_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - cpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); + // create data transfer for the device + const OrtMemoryDevice* device = ep_api.MemoryInfo_GetMemoryDevice(default_memory_info_.get()); + data_transfer_impl_ = std::make_unique(apis, device); - // - // GPU allocator OrtMemoryInfo for example purposes - mem_info = nullptr; - status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU", OrtMemoryInfoDeviceType_GPU, - /*vendor*/ 0xBE57, /* device_id */ 0, - OrtDeviceMemoryType_DEFAULT, - /*alignment*/ 0, - OrtAllocatorType::OrtDeviceAllocator, - &mem_info); - assert(status == nullptr); // should never fail. - default_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // HOST_ACCESSIBLE memory should use the non-CPU device type + // HOST_ACCESSIBLE memory example. use the non-CPU device type so it's clear which device the memory is also + // accessible from. we infer from the type of HOST_ACCESSIBLE that it's CPU accessible. 
mem_info = nullptr; status = ort_api.CreateMemoryInfo_V2("ExampleEP GPU pinned", OrtMemoryInfoDeviceType_GPU, /*vendor*/ 0xBE57, /* device_id */ 0, @@ -63,17 +56,7 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis) /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator, &mem_info); - assert(status == nullptr); // should never fail. - host_accessible_gpu_memory_info_ = MemoryInfoUniquePtr(mem_info, ort_api.ReleaseMemoryInfo); - - // if we were to use GPU we'd create it like this - data_transfer_impl_ = std::make_unique( - apis, - ep_api.MemoryInfo_GetMemoryDevice(default_gpu_memory_info_.get()), // device memory - ep_api.MemoryInfo_GetMemoryDevice(host_accessible_gpu_memory_info_.get()) // shared memory - ); - - data_transfer_impl_.reset(); // but we're CPU only so we return nullptr for the IDataTransfer. + ort_api.ReleaseMemoryInfo(mem_info); } /*static*/ @@ -137,9 +120,8 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::GetSupportedDevicesImpl(OrtEpFactory* } // register the allocator info required by the EP. - // in this example we register CPU info which is unnecessary unless you need to override the default ORT allocator - // for a non-CPU EP this would be device info (GPU/NPU) and possible host accessible info. - RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->cpu_memory_info_.get())); + // registering OrtMemoryInfo for host accessible memory would be done in an additional call. + RETURN_IF_ERROR(factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, factory->default_memory_info_.get())); ep_devices[num_ep_devices++] = ep_device; } @@ -218,32 +200,17 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateAllocatorImpl(OrtEpFactory* this auto& factory = *static_cast(this_ptr); *allocator = nullptr; - // NOTE: The factory implementation can return a shared OrtAllocator* instead of creating a new instance on each call. - // To do this just make ReleaseAllocatorImpl a no-op. - - // NOTE: If OrtMemoryInfo has allocator type (call MemoryInfoGetType) of OrtArenaAllocator, an ORT BFCArena - // will be added to wrap the returned OrtAllocator. The EP is free to implement its own arena, and if it - // wants to do this the OrtMemoryInfo MUST be created with an allocator type of OrtDeviceAllocator. - - // NOTE: The OrtMemoryInfo pointer should only ever be coming straight from an OrtEpDevice, and pointer based - // matching should work. - if (memory_info == factory.cpu_memory_info_.get()) { - // create a CPU allocator. use the basic OrtAllocator for this example. - auto cpu_allocator = std::make_unique(memory_info); - *allocator = cpu_allocator.release(); - } else if (memory_info == factory.default_gpu_memory_info_.get()) { - // create a GPU allocator - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); - } else if (memory_info == factory.host_accessible_gpu_memory_info_.get()) { - // create a pinned/shared memory allocator. Use the real device type (i.e. GPU/NPU) and id and a memory type of - // OrtDeviceMemoryType_HOST_ACCESSIBLE. - return factory.ort_api.CreateStatus(ORT_NOT_IMPLEMENTED, "Example is not implemented."); - } else { + if (memory_info != factory.default_memory_info_.get()) { return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, "INTERNAL ERROR! Unknown memory info provided to CreateAllocator. 
" "Value did not come directly from an OrtEpDevice returned by this factory."); } + // NOTE: The factory implementation is free to return a shared OrtAllocator* instance instead of creating a new + // allocator on each call. To do this have an allocator instance as an OrtEpFactory class member and make + // ReleaseAllocatorImpl a no-op. + auto cpu_allocator = std::make_unique(memory_info, factory); + *allocator = cpu_allocator.release(); return nullptr; } @@ -260,3 +227,25 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateDataTransferImpl(OrtEpFactory* t return nullptr; } + +/*static*/ +bool ORT_API_CALL ExampleEpFactory::IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return true; // the example EP implements stream synchronization. +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept { + auto& factory = *static_cast(this_ptr); + *stream = nullptr; + + // we only need stream synchronization on the device stream + if (factory.ep_api.MemoryDevice_GetMemoryType(memory_device) == OrtDeviceMemoryType_DEFAULT) { + auto sync_stream = std::make_unique(factory, /*OrtEp**/ nullptr, stream_options); + *stream = sync_stream.release(); + } + + return nullptr; +} diff --git a/onnxruntime/test/autoep/library/ep_factory.h b/onnxruntime/test/autoep/library/ep_factory.h index 72fa1c1301841..261730b8adf83 100644 --- a/onnxruntime/test/autoep/library/ep_factory.h +++ b/onnxruntime/test/autoep/library/ep_factory.h @@ -52,6 +52,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* this_ptr, OrtDataTransferImpl** data_transfer) noexcept; + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* this_ptr) noexcept; + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* memory_device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -59,12 +66,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { // CPU allocator so we can control the arena behavior. optional as ORT always provides a CPU allocator if needed. using MemoryInfoUniquePtr = std::unique_ptr>; - MemoryInfoUniquePtr cpu_memory_info_; - - // for example purposes. if the EP used GPU, and pinned/shared memory was required for data transfer, these are the - // OrtMemoryInfo instance required for that. - MemoryInfoUniquePtr default_gpu_memory_info_; - MemoryInfoUniquePtr host_accessible_gpu_memory_info_; + MemoryInfoUniquePtr default_memory_info_; std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory }; diff --git a/onnxruntime/test/autoep/library/ep_stream_support.cc b/onnxruntime/test/autoep/library/ep_stream_support.cc new file mode 100644 index 0000000000000..a948fe1bfce1e --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_stream_support.cc @@ -0,0 +1,88 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "ep_stream_support.h" + +// +// StreamImpl implementation +// + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** notification) noexcept { + auto& impl = *static_cast(this_ptr); + *notification = std::make_unique(impl).release(); + return nullptr; +} + +/*static*/ +void* ORT_API_CALL StreamImpl::GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + return impl.handle_; +} + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::FlushImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL StreamImpl::OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* /*this_ptr*/) noexcept { + return nullptr; +} + +// callback for EP library to release any internal state +/*static*/ +void ORT_API_CALL StreamImpl::ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} + +// +// Notification support +// + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + static_cast(impl); + + // e.g. + // CUDA: cudaEventRecord + // CANN: aclrtRecordEvent + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept { + auto& impl = *static_cast(this_ptr); + void* handle = impl.ort_api.SyncStream_GetHandle(stream); + static_cast(handle); + + // Setup the event or similar that will be activated on notification. + // See CudaNotification or CannNotification for examples. + // + // e.g. + // CUDA: cudaStreamWaitEvent(static_cast(device_stream.GetHandle()), event_) + // CANN: aclrtStreamWaitEvent(static_cast(device_stream.GetHandle()), event_) + // + // `event_` should be a member that is created in the ctor. + // The stream handle should come from the StreamImpl instance and can be the real type so no static_cast is needed. + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL NotificationImpl::WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + auto& impl = *static_cast(this_ptr); + static_cast(impl); + + // e.g. + // CUDA: cudaEventSynchronize(event_) + // CANN: aclrtSynchronizeEvent(event_) + return nullptr; +} + +/*static*/ +void ORT_API_CALL NotificationImpl::ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept { + delete static_cast(this_ptr); +} diff --git a/onnxruntime/test/autoep/library/ep_stream_support.h b/onnxruntime/test/autoep/library/ep_stream_support.h new file mode 100644 index 0000000000000..10c4804722f8b --- /dev/null +++ b/onnxruntime/test/autoep/library/ep_stream_support.h @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "onnxruntime_c_api.h" +#include "example_plugin_ep_utils.h" + +// +// Class implementing Stream support for synchronization. 
+// +class StreamImpl : public OrtSyncStreamImpl, public ApiPtrs { + public: + StreamImpl(ApiPtrs apis, const OrtEp* ep, const OrtKeyValuePairs* /*stream_options*/) + : ApiPtrs(apis), ep_{ep} { + ort_version_supported = ORT_API_VERSION; + CreateNotification = CreateNotificationImpl; + GetHandle = GetHandleImpl; + Flush = FlushImpl; + OnSessionRunEnd = OnSessionRunEndImpl; + Release = ReleaseImpl; + } + + private: + static OrtStatus* ORT_API_CALL CreateNotificationImpl(_In_ OrtSyncStreamImpl* this_ptr, + _Outptr_ OrtSyncNotificationImpl** sync_notification) noexcept; + static void* ORT_API_CALL GetHandleImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL FlushImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL OnSessionRunEndImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncStreamImpl* this_ptr) noexcept; + + void* handle_{nullptr}; // use the real stream type, like cudaStream_t or aclrtStream, etc. + + // EP instance if the stream is being created internally for inferencing. + // nullptr when the stream is created outside of an inference session for data copies. + const OrtEp* ep_; +}; + +// +// Class implementing synchronization notification support. +// +class NotificationImpl : public OrtSyncNotificationImpl, public ApiPtrs { + public: + NotificationImpl(ApiPtrs apis) : ApiPtrs(apis) { + ort_version_supported = ORT_API_VERSION; + Activate = ActivateImpl; + Release = ReleaseImpl; + WaitOnDevice = WaitOnDeviceImpl; + WaitOnHost = WaitOnHostImpl; + } + + private: + static OrtStatus* ORT_API_CALL ActivateImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static OrtStatus* ORT_API_CALL WaitOnDeviceImpl(_In_ OrtSyncNotificationImpl* this_ptr, + _In_ OrtSyncStream* stream) noexcept; + static OrtStatus* ORT_API_CALL WaitOnHostImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + static void ORT_API_CALL ReleaseImpl(_In_ OrtSyncNotificationImpl* this_ptr) noexcept; + + void* event_{NULL}; // placeholder. e.g. CANN uses aclrtEvent, CUDA uses cudaEvent_t +}; diff --git a/onnxruntime/test/autoep/test_allocators.cc b/onnxruntime/test/autoep/test_allocators.cc new file mode 100644 index 0000000000000..84b6e284ccb8e --- /dev/null +++ b/onnxruntime/test/autoep/test_allocators.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include +#include + +#include "core/framework/allocator.h" +#include "core/session/abi_key_value_pairs.h" +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +namespace { +struct DummyAllocator : OrtAllocator { + DummyAllocator(const OrtMemoryInfo* mem_info) + : memory_info{mem_info} { + Alloc = AllocImpl; + Free = FreeImpl; + Info = InfoImpl; + Reserve = AllocImpl; // no special reserve logic and most likely unnecessary unless you have your own arena + GetStats = nullptr; // this can be set to nullptr if not implemented + } + + static void* ORT_API_CALL AllocImpl(struct OrtAllocator* this_, size_t size) { + auto& impl = *static_cast(this_); + ++impl.stats.num_allocs; + impl.stats.max_alloc_size = std::max(size, impl.stats.max_alloc_size); + + return malloc(size); + } + + static void ORT_API_CALL FreeImpl(struct OrtAllocator* /*this_*/, void* p) { + return free(p); + } + + static const struct OrtMemoryInfo* ORT_API_CALL InfoImpl(const struct OrtAllocator* this_) { + const DummyAllocator& impl = *static_cast(this_); + return impl.memory_info; + } + + private: + const OrtMemoryInfo* memory_info; + AllocatorStats stats{}; +}; +} // namespace + +// validate CreateSharedAllocator allows adding an arena to the shared allocator +TEST(SharedAllocators, AddArenaToSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + + const auto* ep_memory_info = c_api.EpDevice_MemoryInfo(example_ep.get(), OrtDeviceMemoryType_DEFAULT); + + // validate there is a shared allocator + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, ep_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + // call CreateSharedAllocator to replace with arena based allocator. arena is configured with kvps + OrtKeyValuePairs allocator_options; + auto initial_chunk_size = "25600"; // arena allocates in 256 byte amounts + allocator_options.Add(OrtArenaCfg::ConfigKeyNames::InitialChunkSizeBytes, initial_chunk_size); + + ASSERT_ORTSTATUS_OK(c_api.CreateSharedAllocator(*ort_env, example_ep.get(), + OrtDeviceMemoryType_DEFAULT, OrtArenaAllocator, &allocator_options, + &allocator)); + + // first allocation should init the arena to the initial chunk size + void* mem = allocator->Alloc(allocator, 16); + allocator->Free(allocator, mem); + + // stats should prove the arena was used + OrtKeyValuePairs* allocator_stats = nullptr; + ASSERT_ORTSTATUS_OK(allocator->GetStats(allocator, &allocator_stats)); + + using ::testing::Contains; + using ::testing::Pair; + const auto& stats = allocator_stats->Entries(); + EXPECT_THAT(stats, Contains(Pair("NumAllocs", "1"))); + EXPECT_THAT(stats, Contains(Pair("NumArenaExtensions", "1"))); + EXPECT_THAT(stats, Contains(Pair("TotalAllocated", initial_chunk_size))); + + // optional. 
ORT owns the allocator but we want to test the release implementation + ASSERT_ORTSTATUS_OK(c_api.ReleaseSharedAllocator(*ort_env, example_ep.get(), OrtDeviceMemoryType_DEFAULT)); +} + +TEST(SharedAllocators, GetSharedAllocator) { + const OrtApi& c_api = Ort::GetApi(); + + // default CPU allocator should be available. + // create a memory info with a different name to validate the shared allocator lookup ignores the name + OrtMemoryInfo* test_cpu_memory_info = nullptr; + ASSERT_ORTSTATUS_OK(c_api.CreateMemoryInfo_V2("dummy", OrtMemoryInfoDeviceType_CPU, 0, 0, + OrtDeviceMemoryType_DEFAULT, 0, OrtDeviceAllocator, + &test_cpu_memory_info)); + + const auto get_allocator_and_check_name = [&](const std::string& expected_name) { + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, test_cpu_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + + const OrtMemoryInfo* ort_cpu_memory_info = nullptr; + ASSERT_ORTSTATUS_OK(c_api.AllocatorGetInfo(allocator, &ort_cpu_memory_info)); + const char* allocator_name; + ASSERT_ORTSTATUS_OK(c_api.MemoryInfoGetName(ort_cpu_memory_info, &allocator_name)); + ASSERT_EQ(expected_name, allocator_name); // Default ORT CPU allocator + }; + + // check we get the default ORT CPU allocator initially + get_allocator_and_check_name(onnxruntime::CPU); + + // register custom allocator and make sure that is accessible by exact match + DummyAllocator dummy_alloc{test_cpu_memory_info}; + c_api.RegisterAllocator(*ort_env, &dummy_alloc); + + // GetSharedAllocator should now match the custom allocator + get_allocator_and_check_name("dummy"); + + // unregister custom allocator + ASSERT_ORTSTATUS_OK(c_api.UnregisterAllocator(*ort_env, test_cpu_memory_info)); + + // there should always be a CPU allocator available + get_allocator_and_check_name(onnxruntime::CPU); + + c_api.ReleaseMemoryInfo(test_cpu_memory_info); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_utils.cc b/onnxruntime/test/autoep/test_autoep_utils.cc new file mode 100644 index 0000000000000..7045ccca2f576 --- /dev/null +++ b/onnxruntime/test/autoep/test_autoep_utils.cc @@ -0,0 +1,57 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include "test/autoep/test_autoep_utils.h" + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" + +namespace onnxruntime { +namespace test { + +Utils::ExamplePluginInfo Utils::example_ep_info; + +void Utils::GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device) { + const OrtApi& c_api = Ort::GetApi(); + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices; + ASSERT_ORTSTATUS_OK(c_api.GetEpDevices(env, &ep_devices, &num_devices)); + + auto it = std::find_if(ep_devices, ep_devices + num_devices, + [&c_api, &ep_name](const OrtEpDevice* ep_device) { + // example ep uses registration name as ep name + return c_api.EpDevice_EpName(ep_device) == ep_name; + }); + + if (it == ep_devices + num_devices) { + ep_device = nullptr; + } else { + ep_device = *it; + } +} + +void Utils::RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& registered_ep) { + const OrtApi& c_api = Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(c_api.RegisterExecutionProviderLibrary(env, + example_ep_info.registration_name.c_str(), + example_ep_info.library_path.c_str())); + const OrtEpDevice* example_ep = nullptr; + GetEp(env, example_ep_info.registration_name, example_ep); + ASSERT_NE(example_ep, nullptr); + + registered_ep = RegisteredEpDeviceUniquePtr(example_ep, [&env, c_api](const OrtEpDevice* /*ep*/) { + c_api.UnregisterExecutionProviderLibrary(env, example_ep_info.registration_name.c_str()); + }); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_utils.h b/onnxruntime/test/autoep/test_autoep_utils.h new file mode 100644 index 0000000000000..2dd7b5f0428e2 --- /dev/null +++ b/onnxruntime/test/autoep/test_autoep_utils.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +namespace onnxruntime { +namespace test { + +using RegisteredEpDeviceUniquePtr = std::unique_ptr>; + +struct Utils { + struct ExamplePluginInfo { + const std::filesystem::path library_path = +#if _WIN32 + "example_plugin_ep.dll"; +#else + "libexample_plugin_ep.so"; +#endif + const std::string registration_name = "example_ep"; + }; + + static ExamplePluginInfo example_ep_info; + + // get the OrtEpDevice for an arbitrary EP from the environment + static void GetEp(Ort::Env& env, const std::string& ep_name, const OrtEpDevice*& ep_device); + + // Register the example EP library, get the OrtEpDevice for it, and return a unique pointer that will + // automatically unregister the EP library. + static void RegisterAndGetExampleEp(Ort::Env& env, RegisteredEpDeviceUniquePtr& example_ep); +}; +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/autoep/test_data_transfer.cc b/onnxruntime/test/autoep/test_data_transfer.cc new file mode 100644 index 0000000000000..cc09699b754b6 --- /dev/null +++ b/onnxruntime/test/autoep/test_data_transfer.cc @@ -0,0 +1,81 @@ + + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/common/random_generator.h" +#include "test/util/include/api_asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +TEST(OrtEpLibrary, DataTransfer) { + const OrtApi& c_api = Ort::GetApi(); + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* ep_device = example_ep.get(); + + const OrtMemoryInfo* device_memory_info = c_api.EpDevice_MemoryInfo(ep_device, OrtDeviceMemoryType_DEFAULT); + + // create a tensor using the default CPU allocator + Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{2, 3, 4}; // shape doesn't matter + const size_t num_elements = 2 * 3 * 4; + + RandomValueGenerator random{}; + std::vector input_data = random.Gaussian(shape, 0.0f, 2.f); + Ort::Value cpu_tensor = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + + // create an on-device Tensor using the example EPs alternative CPU allocator. + // it has a different vendor to the default ORT CPU allocator so we can copy between them even though both are + // really CPU based. + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetSharedAllocator(*ort_env, device_memory_info, &allocator)); + ASSERT_NE(allocator, nullptr); + Ort::Value device_tensor = Ort::Value::CreateTensor(allocator, shape.data(), shape.size()); + + std::vector src_tensor_ptrs{cpu_tensor}; + std::vector dst_tensor_ptrs{device_tensor}; + + ASSERT_ORTSTATUS_OK(c_api.CopyTensors(*ort_env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), nullptr, + src_tensor_ptrs.size())); + + const float* src_data = nullptr; + const float* dst_data = nullptr; + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(cpu_tensor, reinterpret_cast(&src_data))); + ASSERT_ORTSTATUS_OK(c_api.GetTensorData(device_tensor, reinterpret_cast(&dst_data))); + + size_t bytes; + ASSERT_ORTSTATUS_OK(c_api.GetTensorSizeInBytes(cpu_tensor, &bytes)); + ASSERT_EQ(bytes, num_elements * sizeof(float)); + + ASSERT_NE(src_data, dst_data) << "Should have copied between two different memory locations"; + + auto src_span = gsl::make_span(src_data, num_elements); + auto dst_span = gsl::make_span(dst_data, num_elements); + + EXPECT_THAT(src_span, ::testing::ContainerEq(dst_span)); + + // must release this before we unload the EP and the allocator is deleted + device_tensor = Ort::Value(); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc new file mode 100644 index 0000000000000..f1ef67e1f6ba4 --- /dev/null +++ b/onnxruntime/test/autoep/test_execution.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +// #include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/shared_lib/utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +namespace { +void RunModelWithPluginEp(Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); + + // Create input + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data(6, 2.0f); + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); +} +} // namespace + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* plugin_ep_device = example_ep.get(); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + + RunModelWithPluginEp(session_options); +} + +// Creates a session with the example plugin EP and runs a model with a single Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + + { + // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunModelWithPluginEp(session_options); + } +} + +// Generate an EPContext model with a plugin EP. +// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. 
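// Illustrative sketch (not part of this patch) of the session option config alternative mentioned above. It assumes
// the standard EPContext config keys from onnxruntime_session_options_config_keys.h ("ep.context_enable",
// "ep.context_file_path"); the helper name below is hypothetical. Creating the session then writes the compiled
// EPContext model as a side effect, which is what the OrtCompileApi-based test that follows does explicitly.
#include <string>
#include <unordered_map>
#include "core/session/onnxruntime_cxx_api.h"

static void GenEpContextModelViaSessionConfigs(Ort::Env& env, const OrtEpDevice* plugin_ep_device) {
  Ort::SessionOptions session_options;
  std::unordered_map<std::string, std::string> ep_options;
  session_options.AppendExecutionProvider_V2(env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options);

  // enable EPContext generation and choose where the compiled model is written
  session_options.AddConfigEntry("ep.context_enable", "1");
  session_options.AddConfigEntry("ep.context_file_path", "plugin_ep_mul_1_ctx.onnx");

  // the compiled model is generated as part of session creation
  Ort::Session session(env, ORT_TSTR("testdata/mul_1.onnx"), session_options);
}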
+TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { + RegisteredEpDeviceUniquePtr example_ep; + Utils::RegisterAndGetExampleEp(*ort_env, example_ep); + const OrtEpDevice* plugin_ep_device = example_ep.get(); + + { + const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); + const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); + std::filesystem::remove(output_model_file); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + + session_options.AppendExecutionProvider_V2(*ort_env, {Ort::ConstEpDevice(plugin_ep_device)}, ep_options); + + // Create model compilation options from the session options. + Ort::ModelCompilationOptions compile_options(*ort_env, session_options); + compile_options.SetInputModelPath(input_model_file); + compile_options.SetOutputModelPath(output_model_file); + + // Compile the model. + Ort::Status status = Ort::CompileModel(*ort_env, compile_options); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Make sure the compiled model was generated. + ASSERT_TRUE(std::filesystem::exists(output_model_file)); + } +} +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_registration.cc b/onnxruntime/test/autoep/test_registration.cc new file mode 100644 index 0000000000000..88c2e320990e1 --- /dev/null +++ b/onnxruntime/test/autoep/test_registration.cc @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// registration/selection is only supported on windows as there's no device discovery on other platforms +#ifdef _WIN32 + +#include +#include +#include + +#include "core/session/onnxruntime_cxx_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" + +#include "test/autoep/test_autoep_utils.h" +#include "test/util/include/api_asserts.h" +#include "test/util/include/asserts.h" + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { + const std::filesystem::path& library_path = Utils::example_ep_info.library_path; + const std::string& registration_name = Utils::example_ep_info.registration_name; + + const OrtApi* c_api = &Ort::GetApi(); + // this should load the library and create OrtEpDevice + ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(*ort_env, registration_name.c_str(), + library_path.c_str())); + + const OrtEpDevice* const* ep_devices = nullptr; + size_t num_devices = 0; + + ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(*ort_env, &ep_devices, &num_devices)); + // should be one device for the example EP + auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, + [®istration_name, &c_api](const OrtEpDevice* device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. 
+ return c_api->EpDevice_EpName(device) == registration_name; + }); + ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; + + // and this should unload it + ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(*ort_env, + registration_name.c_str())); +} + +TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { + const std::filesystem::path& library_path = Utils::example_ep_info.library_path; + const std::string& registration_name = Utils::example_ep_info.registration_name; + + // this should load the library and create OrtEpDevice + ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); + + std::vector ep_devices = ort_env->GetEpDevices(); + + // should be one device for the example EP + auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), + [®istration_name](Ort::ConstEpDevice& device) { + // the example uses the registration name for the EP name + // but that is not a requirement and the two can differ. + return device.EpName() == registration_name; + }); + ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; + + // test all the C++ getters. expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc + ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); + + auto metadata = test_ep_device->EpMetadata(); + ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); + ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); + + auto options = test_ep_device->EpOptions(); + ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); + + // the CPU device info will vary by machine so check for the lowest common denominator values + Ort::ConstHardwareDevice device = test_ep_device->Device(); + ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); + ASSERT_GE(device.VendorId(), 0); + ASSERT_GE(device.DeviceId(), 0); + ASSERT_NE(device.Vendor(), nullptr); + Ort::ConstKeyValuePairs device_metadata = device.Metadata(); + std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); + ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows + + // and this should unload it without throwing + ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); +} + +} // namespace test +} // namespace onnxruntime + +#endif // _WIN32 diff --git a/onnxruntime/test/autoep/test_autoep_selection.cc b/onnxruntime/test/autoep/test_selection.cc similarity index 67% rename from onnxruntime/test/autoep/test_autoep_selection.cc rename to onnxruntime/test/autoep/test_selection.cc index 01dece34e50b0..72f39be917f90 100644 --- a/onnxruntime/test/autoep/test_autoep_selection.cc +++ b/onnxruntime/test/autoep/test_selection.cc @@ -5,18 +5,17 @@ #ifdef _WIN32 #include -#include +// #include +#include #include -#include "core/common/common.h" -#include "core/framework/provider_options.h" #include "core/graph/constants.h" #include "core/session/abi_key_value_pairs.h" #include "core/session/abi_session_options_impl.h" #include "core/session/onnxruntime_cxx_api.h" -#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "test_allocator.h" +#include "test/autoep/test_autoep_utils.h" #include "test/shared_lib/utils.h" #include "test/util/include/api_asserts.h" #include "test/util/include/asserts.h" @@ -501,212 +500,6 @@ TEST(AutoEpSelection, PolicyDelegateReturnsError) { Ort::Exception); } -namespace { -struct 
ExamplePluginInfo { - const std::filesystem::path library_path = -#if _WIN32 - "example_plugin_ep.dll"; -#else - "libexample_plugin_ep.so"; -#endif - const std::string registration_name = "example_ep"; -}; - -static const ExamplePluginInfo example_plugin_info; -} // namespace - -TEST(OrtEpLibrary, LoadUnloadPluginLibrary) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - OrtEnv* c_api_env = *ort_env; - const OrtApi* c_api = &Ort::GetApi(); - // this should load the library and create OrtEpDevice - ASSERT_ORTSTATUS_OK(Ort::GetApi().RegisterExecutionProviderLibrary(c_api_env, registration_name.c_str(), - library_path.c_str())); - - const OrtEpDevice* const* ep_devices = nullptr; - size_t num_devices = 0; - - ASSERT_ORTSTATUS_OK(Ort::GetApi().GetEpDevices(c_api_env, &ep_devices, &num_devices)); - // should be one device for the example EP - auto num_test_ep_devices = std::count_if(ep_devices, ep_devices + num_devices, - [®istration_name, &c_api](const OrtEpDevice* device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return c_api->EpDevice_EpName(device) == registration_name; - }); - ASSERT_EQ(num_test_ep_devices, 1) << "Expected an OrtEpDevice to have been created by the test library."; - - // and this should unload it - ASSERT_ORTSTATUS_OK(Ort::GetApi().UnregisterExecutionProviderLibrary(c_api_env, - registration_name.c_str())); -} - -TEST(OrtEpLibrary, LoadUnloadPluginLibraryCxxApi) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - // this should load the library and create OrtEpDevice - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - std::vector ep_devices = ort_env->GetEpDevices(); - - // should be one device for the example EP - auto test_ep_device = std::find_if(ep_devices.begin(), ep_devices.end(), - [®istration_name](Ort::ConstEpDevice& device) { - // the example uses the registration name for the EP name - // but that is not a requirement and the two can differ. - return device.EpName() == registration_name; - }); - ASSERT_NE(test_ep_device, ep_devices.end()) << "Expected an OrtEpDevice to have been created by the test library."; - - // test all the C++ getters. 
expected values are from \onnxruntime\test\autoep\library\example_plugin_ep.cc - ASSERT_STREQ(test_ep_device->EpVendor(), "Contoso"); - - auto metadata = test_ep_device->EpMetadata(); - ASSERT_STREQ(metadata.GetValue(kOrtEpDevice_EpMetadataKey_Version), "0.1.0"); - ASSERT_STREQ(metadata.GetValue("supported_devices"), "CrackGriffin 7+"); - - auto options = test_ep_device->EpOptions(); - ASSERT_STREQ(options.GetValue("run_really_fast"), "true"); - - // the CPU device info will vary by machine so check for the lowest common denominator values - Ort::ConstHardwareDevice device = test_ep_device->Device(); - ASSERT_EQ(device.Type(), OrtHardwareDeviceType_CPU); - ASSERT_GE(device.VendorId(), 0); - ASSERT_GE(device.DeviceId(), 0); - ASSERT_NE(device.Vendor(), nullptr); - Ort::ConstKeyValuePairs device_metadata = device.Metadata(); - std::unordered_map metadata_entries = device_metadata.GetKeyValuePairs(); - ASSERT_GT(metadata_entries.size(), 0); // should have at least SPDRP_HARDWAREID on Windows - - // and this should unload it without throwing - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -static void RunModelWithPluginEp(Ort::SessionOptions& session_options) { - Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), session_options); - - // Create input - Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); - std::vector shape = {3, 2}; - std::vector input0_data(6, 2.0f); - std::vector ort_inputs; - std::vector ort_input_names; - - ort_inputs.emplace_back(Ort::Value::CreateTensor( - memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); - ort_input_names.push_back("X"); - - // Run session and get outputs - std::array output_names{"Y"}; - std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), - ort_inputs.size(), output_names.data(), output_names.size()); - - // Check expected output values - Ort::Value& ort_output = ort_outputs[0]; - const float* output_data = ort_output.GetTensorData(); - gsl::span output_span(output_data, 6); - EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); -} - -// Creates a session with the example plugin EP and runs a model with a single Mul node. -// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. -TEST(OrtEpLibrary, PluginEp_AppendV2_MulInference) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - std::vector ep_devices = ort_env->GetEpDevices(); - - // Find the OrtEpDevice associated with our example plugin EP. - Ort::ConstEpDevice plugin_ep_device; - for (Ort::ConstEpDevice& device : ep_devices) { - if (std::string(device.EpName()) == registration_name) { - plugin_ep_device = device; - break; - } - } - ASSERT_NE(plugin_ep_device, nullptr); - - // Create session with example plugin EP - Ort::SessionOptions session_options; - std::unordered_map ep_options; - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - - RunModelWithPluginEp(session_options); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -// Creates a session with the example plugin EP and runs a model with a single Mul node. -// Uses the PREFER_CPU policy to append the example plugin EP to the session. 
-TEST(OrtEpLibrary, PluginEp_PreferCpu_MulInference) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. - Ort::SessionOptions session_options; - session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); - RunModelWithPluginEp(session_options); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} - -// Generate an EPContext model with a plugin EP. -// This test uses the OrtCompileApi but could also be done by setting the appropriate session option configs. -TEST(OrtEpLibrary, PluginEp_GenEpContextModel) { - const std::filesystem::path& library_path = example_plugin_info.library_path; - const std::string& registration_name = example_plugin_info.registration_name; - - ort_env->RegisterExecutionProviderLibrary(registration_name.c_str(), library_path.c_str()); - - { - const ORTCHAR_T* input_model_file = ORT_TSTR("testdata/mul_1.onnx"); - const ORTCHAR_T* output_model_file = ORT_TSTR("plugin_ep_mul_1_ctx.onnx"); - std::filesystem::remove(output_model_file); - - std::vector ep_devices = ort_env->GetEpDevices(); - - // Find the OrtEpDevice associated with our example plugin EP. - Ort::ConstEpDevice plugin_ep_device; - for (Ort::ConstEpDevice& device : ep_devices) { - if (std::string(device.EpName()) == registration_name) { - plugin_ep_device = device; - break; - } - } - ASSERT_NE(plugin_ep_device, nullptr); - - // Create session with example plugin EP - Ort::SessionOptions session_options; - std::unordered_map ep_options; - - session_options.AppendExecutionProvider_V2(*ort_env, std::vector{plugin_ep_device}, ep_options); - - // Create model compilation options from the session options. - Ort::ModelCompilationOptions compile_options(*ort_env, session_options); - compile_options.SetInputModelPath(input_model_file); - compile_options.SetOutputModelPath(output_model_file); - - // Compile the model. - Ort::Status status = Ort::CompileModel(*ort_env, compile_options); - ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); - - // Make sure the compiled model was generated. 
- ASSERT_TRUE(std::filesystem::exists(output_model_file)); - } - - ort_env->UnregisterExecutionProviderLibrary(registration_name.c_str()); -} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 9bc50ce88ef16..2335db69ff571 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -12,6 +12,7 @@ #include "core/common/common.h" #include "core/common/optional.h" #include "core/common/type_utils.h" +#include "core/framework/float16.h" #include "core/framework/int4.h" #include "test/util/include/test_random_seed.h" diff --git a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc old mode 100644 new mode 100755 index bc4f72b09ef09..334be3e03b483 --- a/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc +++ b/onnxruntime/test/contrib_ops/gather_block_quantized_op_test.cc @@ -100,7 +100,8 @@ void RunGatherBlockQuantized(const std::vector& data, const int64_t bits, const std::vector& output, const std::vector& output_shape, - OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess) { + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + bool touch_on_device_data = false) { CheckDataAndShape(data, data_shape, "data in RunGatherBlockQuantized"); CheckDataAndShape(indices, indices_shape, "indices in RunGatherBlockQuantized"); CheckDataAndShape(scales, scales_shape, "scales in RunGatherBlockQuantized"); @@ -126,9 +127,12 @@ void RunGatherBlockQuantized(const std::vector& data, test.AddOutput("output", output_shape, output); - std::vector> eps; - eps.push_back(DefaultCpuExecutionProvider()); - test.Run(expect_result, "", {}, nullptr, &eps); + if (touch_on_device_data) { + // test would need to see data on device + test.Run(expect_result, "", {kWebGpuExecutionProvider}, nullptr); + } else { + test.Run(expect_result, ""); + } }; run_test(false); @@ -181,7 +185,8 @@ void RunUnpackedData( const int64_t bits, const std::vector& output, const std::vector& output_shape, - bool expect_success) { + bool expect_success, + bool touch_on_device_data = false) { CheckDataAndShape(unpacked_data, unpacked_data_shape, "unpacked_data"); CheckDataAndShape(indices, indices_shape, "indices"); CheckDataAndShape(scales, scales_shape, "scales"); @@ -214,7 +219,8 @@ void RunUnpackedData( bits, ToType(output), output_shape, - expect_result); + expect_result, + touch_on_device_data); return; } @@ -239,7 +245,8 @@ void RunUnpackedData( bits, ToType(output), output_shape, - expect_result); + expect_result, + touch_on_device_data); } template @@ -400,7 +407,7 @@ void Test_InvalidIndices_WithZeroPoints() { constexpr int64_t block_size = 16; constexpr int64_t bits = 4; RunUnpackedData(data, data_shape, indices, indices_shape, scales, scales_shape, zero_points, - gather_axis, quantize_axis, block_size, bits, output, output_shape, false); + gather_axis, quantize_axis, block_size, bits, output, output_shape, false, true); } TEST(GatherBlockQuantizedOpTest, InvalidIndices) { diff --git a/onnxruntime/test/optimizer/free_dimension_override_test.cc b/onnxruntime/test/optimizer/free_dimension_override_test.cc index ce778cddb45a3..08f7ebf1c42fc 100644 --- a/onnxruntime/test/optimizer/free_dimension_override_test.cc +++ b/onnxruntime/test/optimizer/free_dimension_override_test.cc @@ -18,7 +18,7 @@ using namespace ONNX_NAMESPACE; namespace onnxruntime { 
namespace test { -void TestFreeDimensions(FreeDimensionOverrideType overrideType) { +void TestFreeDimensions(FreeDimensionOverrideType overrideType, TransformerLevel level) { auto model_uri = ORT_TSTR("testdata/abs_free_dimensions.onnx"); std::shared_ptr model; @@ -43,9 +43,9 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) { auto graph_transformer = std::make_unique(overrides); onnxruntime::GraphTransformerManager graph_transformation_mgr(5); - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(graph_transformer), level)); - ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, level, DefaultLoggingManager().DefaultLogger())); // Verify that the shape of the input graph has the correct values @@ -73,8 +73,10 @@ void TestFreeDimensions(FreeDimensionOverrideType overrideType) { } TEST(FreeDimensionOverrideDenotationTransformerTest, Test) { - TestFreeDimensions(FreeDimensionOverrideType::Denotation); - TestFreeDimensions(FreeDimensionOverrideType::Name); + TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Level1); + TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Level1); + TestFreeDimensions(FreeDimensionOverrideType::Denotation, TransformerLevel::Default); + TestFreeDimensions(FreeDimensionOverrideType::Name, TransformerLevel::Default); } } // namespace test diff --git a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc index d8d7cd58d50e2..a59ce60d65136 100644 --- a/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc +++ b/onnxruntime/test/optimizer/runtime_optimization/graph_runtime_optimization_test.cc @@ -280,6 +280,50 @@ void CheckNhwcTransformerIsApplied(const PathString& ort_model_path, graph_op_counts_checker, graph_checker)); }; + +#if !defined(ORT_MINIMAL_BUILD) +// if level 0 optimization is enabled the free dimension override should be enabled. 
+void CheckFreeDimensionOverrideIsApplied(const PathString& model_path, + TransformerLevel level, + FreeDimensionOverrideType overrideType) { + SessionOptions so{}; + so.graph_optimization_level = level; + if (overrideType == FreeDimensionOverrideType::Denotation) { + so.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{"DATA_BATCH", overrideType, 1}); + so.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{"DATA_CHANNEL", overrideType, 42}); + } else { + so.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{"Dim1", overrideType, 1}); + so.free_dimension_overrides.push_back( + onnxruntime::FreeDimensionOverride{"Dim2", overrideType, 42}); + } + + GraphCheckerFn graph_checker = [](const Graph& graph) { + // Verify that the shape of the input graph has the correct values + + const auto& graph_inputs = graph.GetInputs(); + ASSERT_TRUE(graph_inputs.size() == 1); // This model only has a single input ('x') + + const auto* input_shape = graph_inputs[0]->Shape(); + ASSERT_TRUE(input_shape->dim_size() == 3); // Model takes a 3D tensor as input; two of those dimensions are (were) free dimensions + + ASSERT_TRUE(input_shape->dim(0).denotation() == "DATA_BATCH"); + ASSERT_TRUE(input_shape->dim(0).has_dim_value()); + ASSERT_TRUE(input_shape->dim(0).dim_value() == 1); + + ASSERT_TRUE(input_shape->dim(1).denotation() == "DATA_CHANNEL"); + ASSERT_TRUE(input_shape->dim(1).has_dim_value()); + ASSERT_TRUE(input_shape->dim(1).dim_value() == 42); + }; + + ASSERT_NO_FATAL_FAILURE(LoadAndInitializeSession( + so, model_path, + nullptr, + graph_checker)); +}; +#endif // !defined(ORT_MINIMAL_BUILD) } // namespace TEST(GraphRuntimeOptimizationTest, QDQConv) { @@ -374,8 +418,14 @@ TEST(GraphRuntimeOptimizationTest, TestNhwcTransformerDirectlyUpdatesQLinearConv {"com.microsoft.QLinearConv", n}})); }); } - #if !defined(ORT_MINIMAL_BUILD) +TEST(GraphRuntimeOptimizationTest, TestFreeDimensionOverride) { + CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Default, FreeDimensionOverrideType::Denotation); + CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Default, FreeDimensionOverrideType::Name); + CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Level1, FreeDimensionOverrideType::Denotation); + CheckFreeDimensionOverrideIsApplied(ORT_TSTR("testdata/abs_free_dimensions.onnx"), TransformerLevel::Level1, FreeDimensionOverrideType::Name); +} + TEST(GraphRuntimeOptimizationTest, TestOnlyApplyMinimalBuildOptimizations) { // This test assumes that AttentionFusion is not included in the minimal build optimizations. // Update it if that changes. 
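For reference, the user-facing behavior that TestFreeDimensionOverride pins down can be sketched with the public Python API (an illustrative sketch only, reusing the abs_free_dimensions.onnx test model above): free dimension overrides must take effect even when graph optimizations are disabled.

import onnxruntime as ort

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL  # level 0, no graph optimizations
# Overrides can be keyed by denotation (DATA_BATCH / DATA_CHANNEL) or by dimension name (Dim1 / Dim2),
# matching the two FreeDimensionOverrideType variants exercised above.
so.add_free_dimension_override_by_denotation("DATA_BATCH", 1)
so.add_free_dimension_override_by_denotation("DATA_CHANNEL", 42)

sess = ort.InferenceSession("testdata/abs_free_dimensions.onnx", so, providers=["CPUExecutionProvider"])
print(sess.get_inputs()[0].shape)  # the two formerly free dimensions should now report 1 and 42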
diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 4718a38ce4e1c..8858ae75fb39a 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -322,7 +322,7 @@ TEST(NvExecutionProviderTest, ContextEmbedAndReloadDataDynamic) { std::string graph_name = "test"; std::vector dims = {1, -1, -1}; - CreateBaseModel(model_name, graph_name, dims, true); + CreateBaseModel(model_name, graph_name, dims); auto env = Ort::Env(); auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; @@ -371,7 +371,7 @@ TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { std::string graph_name = "test" + dtype_name; std::vector dims = {1, -1, -1}; - CreateBaseModel(model_name, graph_name, dims, true); + CreateBaseModel(model_name, graph_name, dims); auto env = Ort::Env(); auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; @@ -389,5 +389,47 @@ TYPED_TEST(NvExecutionProviderTest, IOTypeTests) { } } +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + +#if defined(WIN32) +// Tests autoEP feature to automatically select an EP that supports the GPU. +// Currently only works on Windows. +TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { + PathString model_name = ORT_TSTR("nv_execution_provider_data_dyn_test.onnx"); + std::string graph_name = "test"; + std::vector dims = {1, -1, -1}; + + CreateBaseModel(model_name, graph_name, dims, true); + + auto env = Ort::Env(); + auto logging_level = OrtLoggingLevel::ORT_LOGGING_LEVEL_WARNING; + env.UpdateEnvWithCustomLogLevel(logging_level); + + { + env.RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); + + Ort::SessionOptions so; + so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); + Ort::Session session_object(env, model_name.c_str(), so); + EXPECT_TRUE(SessionHasEp(session_object, kNvTensorRTRTXExecutionProvider)); + } + + env.UnregisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider); +} +#endif // defined(WIN32) + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc new file mode 100644 index 0000000000000..77cafc4b08389 --- /dev/null +++ b/onnxruntime/test/providers/qnn/optimizer/transpose_optimizer_test.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
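+ +// Exercises the EP-aware TransposeOptimizer: a Transpose -> Reshape -> Transpose chain whose nodes are assigned to the +// QNN EP is expected to be folded away (both Transpose nodes removed, the Reshape kept) when the Reshape preserves the +// transposed layout, and to be left unmodified otherwise.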
+ +#if !defined(ORT_MINIMAL_BUILD) + +#include <memory> +#include <vector> + +#include "gtest/gtest.h" + +#include "core/graph/constants.h" +#include "core/optimizer/transpose_optimizer.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/test_environment.h" +#include "test/util/include/asserts.h" + +namespace onnxruntime { +namespace test { + +static void TestTransposeReshapeTranspose(const std::vector<int64_t>& input_shape, + const std::vector<int64_t>& transpose1_perm, + const std::vector<int64_t>& reshape_shape, + const std::vector<int64_t>& transpose2_perm, + const bool expected_optimized = true) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input_arg = builder.MakeInput<float>(input_shape, 0.0f, 1.0f); + auto* reshape_shape_value = builder.MakeInitializer({int64_t(reshape_shape.size())}, reshape_shape); + + auto* transpose1_out = builder.MakeIntermediate(); + auto* reshape_out = builder.MakeIntermediate(); + auto* transpose2_out = builder.MakeOutput(); + + auto& transpose1 = builder.AddNode("Transpose", {input_arg}, {transpose1_out}); + transpose1.AddAttribute("perm", transpose1_perm); + transpose1.SetExecutionProviderType(kQnnExecutionProvider); + + auto& reshape = builder.AddNode("Reshape", {transpose1_out, reshape_shape_value}, {reshape_out}); + reshape.SetExecutionProviderType(kQnnExecutionProvider); + + auto& transpose2 = builder.AddNode("Transpose", {reshape_out}, {transpose2_out}); + transpose2.AddAttribute("perm", transpose2_perm); + transpose2.SetExecutionProviderType(kQnnExecutionProvider); + }; + + auto& logger = DefaultLoggingManager().DefaultLogger(); + Model model("TransformerTester", false, logger); + Graph& graph = model.MainGraph(); + ModelTestBuilder helper(graph); + build_test_case(helper); + helper.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + + std::unique_ptr<GraphTransformer> optimizer = std::make_unique<TransposeOptimizer>(CPUAllocator::DefaultInstance(), + kQnnExecutionProvider); + bool modified = false; + ASSERT_STATUS_OK(optimizer->Apply(graph, modified, logger)); + ASSERT_EQ(modified, expected_optimized); + + std::map<std::string, int> op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], expected_optimized ?
0 : 2); + ASSERT_EQ(op_to_count["Reshape"], 1); +} + +TEST(QnnTransposeOptimizerTests, TransposeReshapeTranspose) { + TestTransposeReshapeTranspose({1, 3, 32}, {0, 2, 1}, {1, 1, 32, 3}, {0, 3, 1, 2}); + TestTransposeReshapeTranspose({1, 32, 3}, {0, 2, 1}, {1, 3, 1, 32}, {0, 2, 3, 1}); + TestTransposeReshapeTranspose({1, 3, 32, 32}, {0, 2, 3, 1}, {1, 32 * 32, 3}, {0, 2, 1}); + TestTransposeReshapeTranspose({1, 3, 32, 32}, {0, 2, 3, 1}, {1, 32 * 32, 1, 3}, {0, 3, 1, 2}); + TestTransposeReshapeTranspose({1, 32, 32, 3}, {0, 3, 1, 2}, {1, 3, 32 * 32}, {0, 2, 1}); + TestTransposeReshapeTranspose({1, 32, 32, 3}, {0, 3, 1, 2}, {1, 3, 32 * 32, 1}, {0, 2, 3, 1}); + + TestTransposeReshapeTranspose({1, 3, 32}, {0, 2, 1}, {1, 8, 2, 6}, {0, 3, 1, 2}, false); + TestTransposeReshapeTranspose({1, 3, 32, 32}, {0, 2, 3, 1}, {1, 32, 16, 6}, {0, 3, 1, 2}, false); + TestTransposeReshapeTranspose({1, 32, 32, 3}, {0, 3, 1, 2}, {1, 6, 16, 32}, {0, 2, 3, 1}, false); +} + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py new file mode 100644 index 0000000000000..3163bb33a3a82 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_gqa.py @@ -0,0 +1,1167 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# ------------------------------------------------------------------------- +import math +import os +import platform +import random +import unittest +from dataclasses import dataclass + +import numpy +import torch +from einops import rearrange, repeat +from onnx import TensorProto, helper +from parameterized import parameterized + +from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers + +# Set seed for reproducibility +torch.manual_seed(0) +random.seed(69) + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + +# ################################################################################################# +# Configuration and Helper Classes +# ################################################################################################# + + +@dataclass +class GQAConfig: + batch_size: int + q_sequence_length: int + kv_sequence_length: int + num_heads: int + kv_num_heads: int + head_size: int + past_kv_sequence_length: int = 0 + buffer_sequence_length: int = 0 + # Test-specific parameters + local_window_size: int = -1 + rotary: bool = False + rotary_interleaved: bool = False + packed: bool = False + softcap: float = 0.0 + use_smooth_softmax: bool = False + # CPU-only parameters + has_position_ids: bool = False + has_attention_bias: bool = False + has_head_sink: bool = False + + +# ################################################################################################# +# Rotary Embedding Implementations (CPU and CUDA) +# ################################################################################################# + + +# PyTorch implementation for CPU and fallback +class LlamaMSRotaryEmbedding(torch.nn.Module): + def __init__(self): + super().__init__() + + def rotate_tensor(self, x, cos, sin, pos, interleaved): + rot_dim = 2 * cos.shape[3] + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x_rot[:, :, :, 0:half] + x2 = x_rot[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + batch_size = x.shape[0] + + cos = cos.squeeze(0).squeeze(1) + sin = sin.squeeze(0).squeeze(1) + + if seq_len == 1: + pos_i = pos.long() + cos_x = cos[pos_i].unsqueeze(1) + sin_x = sin[pos_i].unsqueeze(1) + else: + cos_x_list = [] + sin_x_list = [] + for b in range(batch_size): + pos_b = pos[b] + cos_x_list.append(cos[pos_b : pos_b + seq_len]) + sin_x_list.append(sin[pos_b : pos_b + seq_len]) + cos_x = torch.stack(cos_x_list, dim=0) + sin_x = torch.stack(sin_x_list, dim=0) + + cos_x = cos_x.unsqueeze(2) + sin_x = sin_x.unsqueeze(2) + + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +# Triton-based implementation for CUDA +def rotary_embedding_cuda(*args, **kwargs): + from rotary_flash import apply_rotary_emb # noqa: PLC0415 + + return apply_rotary_emb(*args, **kwargs) + + +# Unified wrapper for rotary embeddings +def apply_rotary_embedding(x, cos, sin, pos, interleaved, device="cpu"): + """Applies rotary embedding, using Triton for CUDA if available, otherwise fallback to PyTorch.""" + 
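# The Triton-based rotary_flash helper used below requires CUDA and is only exercised on Linux (Triton is not + # available on Windows); in every other configuration the pure-PyTorch LlamaMSRotaryEmbedding fallback is used. +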
use_cuda_triton = device == "cuda" and platform.system() == "Linux" + if use_cuda_triton: + try: + return rotary_embedding_cuda(x, cos, sin, seqlen_offsets=pos, interleaved=interleaved) + except ImportError: + print("WARNING: Triton-based rotary embedding not found. Falling back to PyTorch version.") + + # PyTorch implementation for CPU or as a fallback for CUDA + rot = LlamaMSRotaryEmbedding().to(device) + # Unsqueeze to match the expected shape in the PyTorch version + cos_unsqueezed = cos.unsqueeze(0).unsqueeze(2) + sin_unsqueezed = sin.unsqueeze(0).unsqueeze(2) + return rot(x, cos_unsqueezed, sin_unsqueezed, pos, interleaved) + + +# ################################################################################################# +# ONNX Graph Creation +# ################################################################################################# + + +def create_group_query_attention_graph_prompt( + config: GQAConfig, + ort_type, + share_buffer=True, +): + assert not (config.has_head_sink and config.use_smooth_softmax) + past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 + present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + + nodes = [ + helper.make_node( + op_type="GroupQueryAttention", + inputs=[ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "past_key" if share_buffer else "", + "past_value" if share_buffer else "", + "seqlens_k", + "total_sequence_length", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", + ], + outputs=["output", "present_key", "present_value"], + name="GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=config.local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + smooth_softmax=1 if config.use_smooth_softmax else 0, + domain="com.microsoft", + ), + ] + + q_hidden_size = ( + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ) + graph_input = [ + helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + if not config.packed: + graph_input.extend( + [ + helper.make_tensor_value_info( + "key", + ort_type, + [config.batch_size, config.kv_sequence_length, config.kv_num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "value", + ort_type, + [config.batch_size, config.kv_sequence_length, config.kv_num_heads * config.head_size], + ), + ] + ) + if share_buffer: + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + v_shape = k_shape + graph_input.extend( + [ + helper.make_tensor_value_info("past_key", ort_type, k_shape), + helper.make_tensor_value_info("past_value", ort_type, v_shape), + ] + ) + if config.rotary: + rotary_dim = (math.floor(config.head_size / 16) * 16) // 2 + cache_seq_len = config.buffer_sequence_length if share_buffer else config.kv_sequence_length + graph_input.extend( + [ + 
helper.make_tensor_value_info("cos_cache", ort_type, [cache_seq_len, rotary_dim]), + helper.make_tensor_value_info("sin_cache", ort_type, [cache_seq_len, rotary_dim]), + ] + ) + if config.has_position_ids: + graph_input.append( + helper.make_tensor_value_info( + "position_ids", TensorProto.INT64, [config.batch_size, config.q_sequence_length] + ) + ) + if config.has_attention_bias: + graph_input.append( + helper.make_tensor_value_info( + "attention_bias", ort_type, [config.batch_size, 1, config.q_sequence_length, config.kv_sequence_length] + ) + ) + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + output_k_shape = [config.batch_size, config.kv_num_heads, present_kv_seqlen, config.head_size] + output_v_shape = output_k_shape + + graph_output = [ + helper.make_tensor_value_info( + "output", ort_type, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size] + ), + helper.make_tensor_value_info("present_key", ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", ort_type, output_v_shape), + ] + + graph = helper.make_graph(nodes, "GroupQueryAttention_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_group_query_attention_graph_past( + config: GQAConfig, + ort_type, + share_buffer=True, +): + assert not (config.has_head_sink and config.use_smooth_softmax) + + if share_buffer: + past_kv_seqlen = config.buffer_sequence_length + present_kv_seqlen = config.buffer_sequence_length + else: + past_kv_seqlen = config.past_kv_sequence_length + present_kv_seqlen = config.past_kv_sequence_length + config.kv_sequence_length + + nodes = [ + helper.make_node( + "GroupQueryAttention", + [ + "query", + "key" if not config.packed else "", + "value" if not config.packed else "", + "past_key", + "past_value", + "seqlens_k", + "total_sequence_length", + "cos_cache" if config.rotary else "", + "sin_cache" if config.rotary else "", + "position_ids" if config.has_position_ids else "", + "attention_bias" if config.has_attention_bias else "", + "head_sink" if config.has_head_sink else "", + ], + ["output", "present_key", "present_value"], + "GroupQueryAttention_0", + num_heads=config.num_heads, + kv_num_heads=config.kv_num_heads, + local_window_size=config.local_window_size, + do_rotary=config.rotary, + rotary_interleaved=config.rotary_interleaved, + softcap=config.softcap, + smooth_softmax=1 if config.use_smooth_softmax else 0, + domain="com.microsoft", + ), + ] + + q_hidden_size = ( + (config.num_heads * config.head_size) + if not config.packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) + ) + # Shape is (batch_size, kv_num_heads, sequence_length, head_size) + past_k_shape = [config.batch_size, config.kv_num_heads, past_kv_seqlen, config.head_size] + graph_input = [ + helper.make_tensor_value_info("query", ort_type, [config.batch_size, config.q_sequence_length, q_hidden_size]), + helper.make_tensor_value_info("past_key", ort_type, past_k_shape), + helper.make_tensor_value_info("past_value", ort_type, past_k_shape), + helper.make_tensor_value_info("seqlens_k", TensorProto.INT32, [config.batch_size]), + helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), + ] + + if not config.packed: + graph_input.extend( + [ + helper.make_tensor_value_info( + "key", + ort_type, + [config.batch_size, 
config.q_sequence_length, config.kv_num_heads * config.head_size], + ), + helper.make_tensor_value_info( + "value", + ort_type, + [config.batch_size, config.q_sequence_length, config.kv_num_heads * config.head_size], + ), + ] + ) + + if config.rotary: + rotary_dim = (math.floor(config.head_size / 16) * 16) // 2 + cache_len = config.buffer_sequence_length + graph_input.extend( + [ + helper.make_tensor_value_info("cos_cache", ort_type, [cache_len, rotary_dim]), + helper.make_tensor_value_info("sin_cache", ort_type, [cache_len, rotary_dim]), + ] + ) + + if config.has_position_ids: + graph_input.append( + helper.make_tensor_value_info( + "position_ids", TensorProto.INT64, [config.batch_size, config.q_sequence_length] + ) + ) + if config.has_attention_bias: + graph_input.append( + helper.make_tensor_value_info( + "attention_bias", ort_type, [config.batch_size, 1, config.q_sequence_length, present_kv_seqlen] + ) + ) + if config.has_head_sink: + graph_input.append(helper.make_tensor_value_info("head_sink", ort_type, [config.num_heads])) + + output_k_shape = [ + config.batch_size, + config.kv_num_heads, + present_kv_seqlen, + config.head_size, + ] + + graph_output = [ + helper.make_tensor_value_info( + "output", ort_type, [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size] + ), + helper.make_tensor_value_info("present_key", ort_type, output_k_shape), + helper.make_tensor_value_info("present_value", ort_type, output_k_shape), + ] + + graph = helper.make_graph(nodes, "GroupQueryAttention_Graph", graph_input, graph_output) + model = helper.make_model(graph) + return model.SerializeToString() + + +# ################################################################################################# +# ONNX Runtime Execution Functions +# ################################################################################################# + + +def gqa_prompt_func( + q, + k, + v, + config: GQAConfig, + new_k, + new_v, + cos, + sin, + seqlens_k, + position_ids, + attention_bias, + head_sink, + ep, + device, + share_buffer=True, + ort_type=TensorProto.FLOAT16, + numpy_type=numpy.float16, +): + onnx_model_str = create_group_query_attention_graph_prompt( + config=config, + ort_type=ort_type, + share_buffer=share_buffer, + ) + + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Common inputs + ort_inputs = { + "query": q.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), + } + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + 
io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + # CPU-specific inputs + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + if share_buffer: + past_k_ort = OrtValue.ortvalue_from_numpy(k.detach().cpu().numpy(), device, 0) + past_v_ort = OrtValue.ortvalue_from_numpy(v.detach().cpu().numpy(), device, 0) + io_binding.bind_input("past_key", device, 0, numpy_type, past_k_ort.shape(), past_k_ort.data_ptr()) + io_binding.bind_input("past_value", device, 0, numpy_type, past_v_ort.shape(), past_v_ort.data_ptr()) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", past_k_ort) + io_binding.bind_ortvalue_output("present_value", past_v_ort) + else: + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + return torch.tensor(ort_output), present_k, present_v + + +def gqa_past_func( + q, + k, + v, + config: GQAConfig, + new_k, + new_v, + cos, + sin, + seqlens_k, + position_ids, + attention_bias, + head_sink, + ep, + device, + share_buffer=True, + ort_type=TensorProto.FLOAT16, + numpy_type=numpy.float16, +): + onnx_model_str = create_group_query_attention_graph_past( + config=config, + ort_type=ort_type, + share_buffer=share_buffer, + ) + + q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.q_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.q_sequence_length, -1)) + + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=[ep]) + io_binding = ort_session.io_binding() + + # Common inputs + total_seq_len = ( + config.past_kv_sequence_length if share_buffer else config.past_kv_sequence_length + config.q_sequence_length + ) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), + "total_sequence_length": torch.tensor([total_seq_len], dtype=torch.int32).detach().cpu().numpy(), + } + io_binding.bind_cpu_input("query", ort_inputs["query"]) + io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) + io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) + + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) + + # CPU-specific inputs + if config.has_position_ids: + ort_inputs["position_ids"] = position_ids.detach().cpu().numpy() + 
io_binding.bind_cpu_input("position_ids", ort_inputs["position_ids"]) + if config.has_attention_bias: + ort_inputs["attention_bias"] = attention_bias.detach().cpu().numpy() + io_binding.bind_cpu_input("attention_bias", ort_inputs["attention_bias"]) + if config.has_head_sink: + ort_inputs["head_sink"] = head_sink.detach().cpu().numpy() + io_binding.bind_cpu_input("head_sink", ort_inputs["head_sink"]) + + # Binding past and present KV + if share_buffer: + past_k_ort = OrtValue.ortvalue_from_numpy(k.detach().cpu().numpy(), device, 0) + past_v_ort = OrtValue.ortvalue_from_numpy(v.detach().cpu().numpy(), device, 0) + io_binding.bind_input("past_key", device, 0, numpy_type, past_k_ort.shape(), past_k_ort.data_ptr()) + io_binding.bind_input("past_value", device, 0, numpy_type, past_v_ort.shape(), past_v_ort.data_ptr()) + io_binding.bind_output("output") + io_binding.bind_ortvalue_output("present_key", past_k_ort) + io_binding.bind_ortvalue_output("present_value", past_v_ort) + else: + ort_inputs["past_key"] = k.detach().cpu().numpy() + ort_inputs["past_value"] = v.detach().cpu().numpy() + io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) + io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) + io_binding.bind_output("output") + io_binding.bind_output("present_key") + io_binding.bind_output("present_value") + + ort_session.run_with_iobinding(io_binding) + ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() + return torch.tensor(ort_output), present_k, present_v + + +# ################################################################################################# +# Reference Attention Implementation +# ################################################################################################# + + +def construct_local_mask(seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + + +def smooth_softmax_ref(x, head_sink): + b, n, s, t = x.shape + if head_sink is not None: + sink = head_sink.reshape(1, n, 1, 1).expand(b, -1, s, -1) + else: + sink = torch.zeros(b, n, s, 1, dtype=x.dtype, device=x.device) + + y = torch.cat([x, sink], dim=-1) + y = torch.softmax(y, dim=-1) + return y[..., :-1] + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attention_bias=None, + causal=False, + window_size=(-1, -1), + softcap=0.0, + use_smooth_softmax=False, + head_sink=None, +): + if causal: + window_size = (window_size[0], 0) + + dtype_og = q.dtype + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + + # Repeat K/V heads for Grouped-Query Attention + if k.shape[2] != q.shape[2]: + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + if v.shape[2] != q.shape[2]: + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + + scores = torch.einsum("bthd,bshd->bhts", q, k) / 
math.sqrt(q.shape[-1]) + + if softcap > 0: + scores = (scores / softcap).tanh() * softcap + + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + + local_mask = None + if window_size[0] >= 0 or window_size[1] >= 0: + local_mask = construct_local_mask( + seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, q.device + ) + scores.masked_fill_(local_mask, float("-inf")) + + # Add custom attention bias if provided (for CPU tests) + if attention_bias is not None: + # The bias should only be applied to the relevant part of the scores matrix, + # matching the sequence length of the bias tensor. + scores[..., : attention_bias.shape[-1]] += attention_bias + + if use_smooth_softmax or (head_sink is not None): + # Note that the sink directly joins softmax. No scaling and softcap is needed! + attention = smooth_softmax_ref(scores, head_sink) + else: + attention = torch.softmax(scores, dim=-1) + + # Fill NaNs with 0 + if local_mask is not None: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + + output = torch.einsum("bhts,bshd->bthd", attention, v) + + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +# ################################################################################################# +# Parity Check (Core Test Logic) +# ################################################################################################# + + +def parity_check_gqa_prompt( + config: GQAConfig, + ep, + device, + torch_type, + numpy_type, + ort_type, + causal, + rtol, + atol, +): + # Q/K/V have normal distribution with mean = 0 and standard deviation = 0.02. + # If we use standard deviation = 1, numerical stability issues may occur. 
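+ # (With float16 inputs and outputs, values drawn with standard deviation 1 can overflow or lose precision during + # accumulation, which would make the 5e-3 parity tolerances used by these tests unreliable.)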
+ std = 0.02 + + # --- Test Data Generation --- + q = ( + torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + + # k and v are the cache buffers, created in BNSH format + k = ( + torch.randn( + config.batch_size, + config.kv_num_heads, + config.buffer_sequence_length, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + v = torch.randn_like(k) + + new_k = ( + torch.randn( + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + * std + ) + new_v = torch.randn_like(new_k) * std + + head_sink = torch.rand(config.num_heads, dtype=torch_type, device=device) if config.has_head_sink else None + + window_size = (-1, -1) + if config.local_window_size > 0: + window_size = (config.local_window_size, 0) + elif causal: + window_size = (-1, 0) + + # --- PyTorch Reference Path --- + # Transpose BNSH cache to BSNH format for reference implementation + k_cache_ref = k.clone().transpose(1, 2) + v_cache_ref = v.clone().transpose(1, 2) + + cache_seqlens = torch.full((config.batch_size,), config.kv_sequence_length, device=device, dtype=torch.int32) + rotary_seqlens = torch.zeros(config.batch_size, device=device, dtype=torch.long) + + cos, sin, q_ro, k_ro = None, None, q, new_k + if config.rotary: + rotary_dim = math.floor(config.head_size / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch_type) + sin = torch.sin(angle).to(dtype=torch_type) + q_ro = apply_rotary_embedding(q.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, rotary_seqlens, config.rotary_interleaved, device) + + position_ids = None + attention_bias = None + if ep == "CPUExecutionProvider": + if config.has_position_ids: + position_ids = ( + torch.arange(config.q_sequence_length, device=device).unsqueeze(0).expand(config.batch_size, -1) + ) + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, + 1, + config.q_sequence_length, + config.kv_sequence_length, + device=device, + dtype=torch_type, + ) + + arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") + kv_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = arange < kv_seqlens_expanded + + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=torch_type) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...").to(dtype=torch_type) + key_padding_mask = arange < kv_seqlens_expanded + + out_ref, _ = attention_ref( + q=q_ro, + k=k_cache_ref, + v=v_cache_ref, + query_padding_mask=None, + key_padding_mask=key_padding_mask, + attention_bias=attention_bias, + causal=True, + window_size=window_size, + softcap=config.softcap, + use_smooth_softmax=config.use_smooth_softmax, + head_sink=head_sink, + ) + out_ref_np = out_ref.detach().cpu().numpy() + + # Transpose reference cache back to BNSH for comparison + k_cache_ref_np = k_cache_ref.transpose(1, 2).detach().cpu().numpy() + v_cache_ref_np = v_cache_ref.transpose(1, 2).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + q_ort, k_ort, v_ort, new_k_ort, new_v_ort = q, k, v, new_k, new_v + if config.packed: + q_ort = torch.cat([q, new_k, new_v], dim=2) + new_k_ort, new_v_ort = None, None + + # seqlens_k for GQA op is past_seq_len + seq_len - 1 + ort_seqlens = cache_seqlens - 1 + out, present_k, present_v = gqa_prompt_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens, + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=True, + ort_type=ort_type, + numpy_type=numpy_type, + ) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out_np = out.detach().cpu().numpy() + + # --- Comparison --- + numpy.testing.assert_allclose(present_k, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v, v_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +def parity_check_gqa_past( + config: GQAConfig, + ep, + device, + torch_type, + numpy_type, + ort_type, + causal, + rtol, + atol, +): + # --- Test Data Generation --- + q = torch.randn( + config.batch_size, + config.q_sequence_length, + config.num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + # k and v are the cache buffers, created in BNSH format + k = torch.randn( + config.batch_size, + config.kv_num_heads, + config.buffer_sequence_length, + config.head_size, + device=device, + dtype=torch_type, + ) + v = torch.randn_like(k) + new_k = torch.randn( + config.batch_size, + config.q_sequence_length, + config.kv_num_heads, + config.head_size, + device=device, + dtype=torch_type, + ) + new_v = torch.randn_like(new_k) + + head_sink = torch.rand(config.num_heads, dtype=torch_type, device=device) if config.has_head_sink else None + window_size = (-1, -1) + if config.local_window_size > 0: + window_size = (config.local_window_size, 0) + elif causal: + window_size = (-1, 0) + + # --- PyTorch Reference Path --- + # Transpose BNSH cache to BSNH format for reference implementation + k_cache_ref = k.clone().transpose(1, 2) + v_cache_ref = v.clone().transpose(1, 2) + + cache_seqlens = torch.randint( + 0, + config.past_kv_sequence_length - config.q_sequence_length + 1, + (config.batch_size,), + device=device, + dtype=torch.long, + ) + + cos, sin, q_ro, k_ro = None, None, q, new_k + if config.rotary: + rotary_dim = math.floor(config.head_size / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device=device) * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch_type) + sin = torch.sin(angle).to(dtype=torch_type) + q_ro = apply_rotary_embedding(q.clone(), cos, sin, cache_seqlens, config.rotary_interleaved, device) + k_ro = apply_rotary_embedding(new_k.clone(), cos, sin, 
cache_seqlens, config.rotary_interleaved, device) + + position_ids = None + attention_bias = None + total_seq_len = config.past_kv_sequence_length + if ep == "CPUExecutionProvider": + if config.has_position_ids: + position_ids = (cache_seqlens.unsqueeze(1) + torch.arange(config.q_sequence_length, device=device)).long() + if config.has_attention_bias: + attention_bias = torch.zeros( + config.batch_size, 1, config.q_sequence_length, total_seq_len, device=device, dtype=torch_type + ) + for b in range(config.batch_size): + end_pos = cache_seqlens[b] + config.q_sequence_length + attention_bias[b, :, :, end_pos:] = float("-inf") + + arange = rearrange(torch.arange(config.buffer_sequence_length, device=device), "s -> 1 s") + cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") + update_mask = torch.logical_and( + cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.q_sequence_length + ) + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...").to(dtype=torch_type) + v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...").to(dtype=torch_type) + key_padding_mask = arange < cache_seqlens_expanded + config.q_sequence_length + + out_ref, _ = attention_ref( + q=q_ro, + k=k_cache_ref, + v=v_cache_ref, + query_padding_mask=None, + key_padding_mask=key_padding_mask, + attention_bias=attention_bias, + causal=True, + window_size=window_size, + softcap=config.softcap, + use_smooth_softmax=config.use_smooth_softmax, + head_sink=head_sink, + ) + out_ref_np = out_ref.detach().cpu().numpy() + + # Transpose reference cache back to BNSH for comparison + k_cache_ref_np = k_cache_ref.transpose(1, 2).detach().cpu().numpy() + v_cache_ref_np = v_cache_ref.transpose(1, 2).detach().cpu().numpy() + + # --- ONNX Runtime Path --- + q_ort, k_ort, v_ort, new_k_ort, new_v_ort = q, k, v, new_k, new_v + if config.packed: + q_ort = torch.cat([q, new_k, new_v], dim=2) + new_k_ort, new_v_ort = None, None + + ort_seqlens = cache_seqlens + config.q_sequence_length - 1 + out, present_k, present_v = gqa_past_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens.int(), + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=True, + ort_type=ort_type, + numpy_type=numpy_type, + ) + out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) + out_np = out.detach().cpu().numpy() + + numpy.testing.assert_allclose(present_k, k_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(present_v, v_cache_ref_np, rtol=rtol, atol=atol) + numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) + + +# ################################################################################################# +# Test Case Generators +# ################################################################################################# + + +def get_cuda_rotary_options(): + return [(False, False)] if pipeline_mode else [(True, False), (True, True), (False, False)] + + +def get_cpu_rotary_options(): + return [(False, False), (True, False), (True, True)] + + +def get_softmax_options(allow_head_sink: bool = True): + head_sink_option = (False, True) if allow_head_sink else (False, False) + return [(False, False), head_sink_option] if pipeline_mode else [(False, False), (False, True), (True, False)] + + +def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): + batches = [3] if 
pipeline_mode else [1, 3, 5] + seqs = [(35, 35)] if pipeline_mode else [(35, 35), (127, 127), (240, 240), (2000, 2000)] + num_h = [(6, 3)] if pipeline_mode else [(6, 3), (9, 9), (32, 8)] + h_sizes = [32] if pipeline_mode else [32, 64, 128, 256] + smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) + + for b in batches: + for sq, skv in seqs: + for n, n2 in num_h: + for h in h_sizes: + for lws in [-1, random.randint(1, skv)]: + for rotary, rotary_interleaved in get_cuda_rotary_options(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + for use_smooth_softmax, has_head_sink in smmoth_softmax__head_sink: + if softcap > 0 and (use_smooth_softmax or has_head_sink): + continue + config = GQAConfig( + batch_size=b, + q_sequence_length=sq, + kv_sequence_length=skv, + past_kv_sequence_length=0, + buffer_sequence_length=sq + skv + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + local_window_size=lws, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + has_head_sink=has_head_sink, + ) + name = f"b{b}_sq{sq}_skv{skv}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + yield name, config + + +def gqa_cuda_past_test_cases(allow_head_sink: bool = True): + batches = [5] if pipeline_mode else [1, 3, 5] + # s: new sequence length, s2: past sequence length + seqs = [(1, 1024)] if pipeline_mode else [(1, 128), (1, 1024), (1, 2048), (1, 5000)] + num_h = [(32, 8)] if pipeline_mode else [(6, 3), (9, 9), (32, 8)] + h_sizes = [256] if pipeline_mode else [64, 128, 256] + smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) + + for b in batches: + for s, s2 in seqs: + for n, n2 in num_h: + for h in h_sizes: + for lws in [-1, random.randint(1, s2)]: + for rotary, rotary_interleaved in get_cuda_rotary_options(): + for packed in [False, True]: + for softcap in [0.0, 50.0]: + if rotary and h % 16 > 0: + continue + for use_smooth_softmax, has_head_sink in smmoth_softmax__head_sink: + config = GQAConfig( + batch_size=b, + q_sequence_length=s, + kv_sequence_length=s, + past_kv_sequence_length=s2, + buffer_sequence_length=s + s2 + 8, + num_heads=n, + kv_num_heads=n2, + head_size=h, + local_window_size=lws, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + softcap=softcap, + use_smooth_softmax=use_smooth_softmax, + has_head_sink=has_head_sink, + ) + name = f"b{b}_s{s}_{s2}_nh{n}_{n2}_h{h}_w{lws}_rot{rotary}{rotary_interleaved}_pkd{packed}_sc{softcap}_sm{use_smooth_softmax}_{has_head_sink}" + yield name, config + + +# ################################################################################################# +# Unit Test Classes +# ################################################################################################# + + +def has_cuda_provider(): + return "CUDAExecutionProvider" in get_available_providers() + + +def has_flash_attention(): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +def has_memory_efficient(): + if not has_cuda_provider() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 5 + + +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestFlashGQA(unittest.TestCase): + @parameterized.expand(gqa_cuda_prompt_test_cases()) 
+ def test_gqa_prompt_flash_attention(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + @parameterized.expand(gqa_cuda_past_test_cases()) + def test_gqa_past_flash_attention(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + +@unittest.skipIf(not has_memory_efficient(), "Memory Efficient Attention is not available, skipping tests.") +class TestMemoryEfficientGQA(unittest.TestCase): + @parameterized.expand(gqa_cuda_prompt_test_cases(allow_head_sink=False)) + def test_gqa_prompt_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + @parameterized.expand(gqa_cuda_past_test_cases(allow_head_sink=False)) + def test_gqa_past_memory_efficient(self, name, config): + os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" + parity_check_gqa_past( + config=config, + ep="CUDAExecutionProvider", + device="cuda", + torch_type=torch.float16, + numpy_type=numpy.float16, + ort_type=TensorProto.FLOAT16, + causal=True, + rtol=5e-3, + atol=5e-3, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_cuda.py b/onnxruntime/test/python/transformers/test_gqa_cuda.py deleted file mode 100644 index 79976a92e54bf..0000000000000 --- a/onnxruntime/test/python/transformers/test_gqa_cuda.py +++ /dev/null @@ -1,2046 +0,0 @@ -# -------------------------------------------------------------------------- -# Copyright 2020 The HuggingFace Inc. team -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -# -------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. See License.txt in the project root for -# license information. -# ------------------------------------------------------------------------- -import math -import os -import platform -import random -import unittest - -import numpy -import torch -from einops import rearrange, repeat -from onnx import TensorProto, helper -from packaging import version -from parameterized import parameterized -from test_gqa_cpu import smooth_softmax_ref - -from onnxruntime import InferenceSession, OrtValue, SessionOptions, get_available_providers - -torch.manual_seed(0) - -pipeline_mode = True # Reduces number of tests so pipeline doesn't time out - - -class Formats: - BSNH = 0 - BNSH = 1 - - -class Config: - batch_size = 0 - sequence_length = 0 - kv_sequence_length = 0 # this is past sequence length when there is past state. 
- num_heads = 0 - kv_num_heads = 0 - head_size = 0 - ep = "CUDAExecutionProvider" - - def __init__(self, batch_size, sequence_length, kv_sequence_length, num_heads, kv_num_heads, head_size): - self.batch_size = batch_size - self.sequence_length = sequence_length - self.kv_sequence_length = kv_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - - def __repr__(self): - short_ep = self.ep[: -len("ExecutionProvider")].lower() - return ( - f"Config(batch_size={self.batch_size}, sequence_length={self.sequence_length}, " - f"kv_sequence_length={self.kv_sequence_length}, " - f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" - ) - - -class PromptConfig: - batch_size = 0 - q_sequence_length = 0 - kv_sequence_length = 0 - buffer_sequence_length = 0 - num_heads = 0 - kv_num_heads = 0 - head_size = 0 - ep = "CUDAExecutionProvider" - - def __init__( - self, - batch_size, - q_sequence_length, - kv_sequence_length, - buffer_sequence_length, - num_heads, - kv_num_heads, - head_size, - ): - self.batch_size = batch_size - self.q_sequence_length = q_sequence_length - self.kv_sequence_length = kv_sequence_length - self.buffer_sequence_length = buffer_sequence_length - self.num_heads = num_heads - self.kv_num_heads = kv_num_heads - self.head_size = head_size - - def __repr__(self): - short_ep = self.ep[: -len("ExecutionProvider")].lower() - return ( - f"PromptConfig(batch_size={self.batch_size}, q_sequence_length={self.q_sequence_length}, " - f"kv_sequence_length={self.kv_sequence_length}, buffer_sequence_length={self.buffer_sequence_length}, " - f"num_heads={self.num_heads}, kv_num_heads={self.kv_num_heads}, head_size={self.head_size}, ep={short_ep})" - ) - - -def create_group_query_attention_graph_prompt( - config, - past_kv_format=Formats.BSNH, - share_buffer=True, - local_window_size=-1, - rotary=False, - rotary_interleaved=False, - packed=False, - interactive=False, - softcap=0.0, - use_smooth_softmax=False, -): - past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 - present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length - nodes = [ - helper.make_node( - "GroupQueryAttention", - [ - "query", - "key" if not packed else "", - "value" if not packed else "", - "past_key" if share_buffer else "", - "past_value" if share_buffer else "", - "seqlens_k", - "total_sequence_length", - "cos_cache" if rotary else "", - "sin_cache" if rotary else "", - ], - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - num_heads=config.num_heads, - kv_num_heads=config.kv_num_heads, - local_window_size=local_window_size, - do_rotary=rotary, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - smooth_softmax=1 if use_smooth_softmax else 0, - # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, - domain="com.microsoft", - ), - ] - - graph_input = [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.q_sequence_length, - ( - (config.num_heads * config.head_size) - if not packed - else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) - ), - ], - ), - helper.make_tensor_value_info( - "seqlens_k", - TensorProto.INT32, - [config.batch_size], - ), - helper.make_tensor_value_info( - "total_sequence_length", - TensorProto.INT32, - [1], - ), - ] - if not packed: - graph_input += [ - 
helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - ] - if share_buffer: - graph_input += [ - helper.make_tensor_value_info( - "past_key", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_value", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - ] - if rotary: - graph_input += [ - helper.make_tensor_value_info( - "cos_cache", - TensorProto.FLOAT16, - [ - config.buffer_sequence_length if share_buffer else config.kv_sequence_length, - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - helper.make_tensor_value_info( - "sin_cache", - TensorProto.FLOAT16, - [ - config.buffer_sequence_length if share_buffer else config.kv_sequence_length, - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - ] - - graph_output = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.q_sequence_length, config.num_heads * config.head_size], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - ], - ), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def create_group_query_attention_graph_past( - config, - past_kv_format=Formats.BSNH, - share_buffer=True, - local_window_size=-1, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, -): - past_kv_seqlen = config.kv_sequence_length - present_kv_seqlen = ( - config.kv_sequence_length if share_buffer else config.kv_sequence_length + config.sequence_length - ) - nodes = [ - helper.make_node( - "GroupQueryAttention", - [ - "query", - "key" if not packed else "", - "value" if not packed else "", - "past_key", - "past_value", - 
"seqlens_k", - "total_sequence_length", - "cos_cache" if rotary else "", - "sin_cache" if rotary else "", - ], - ["output", "present_key", "present_value"], - "GroupQueryAttention_0", - num_heads=config.num_heads, - kv_num_heads=config.kv_num_heads, - local_window_size=local_window_size, - do_rotary=rotary, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - smooth_softmax=1 if use_smooth_softmax else 0, - # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, - # kv_share_buffer=1 if share_buffer else 0, - domain="com.microsoft", - ), - ] - - graph_input = [ - helper.make_tensor_value_info( - "query", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - ( - (config.num_heads * config.head_size) - if not packed - else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size) - ), - ], - ), - helper.make_tensor_value_info( - "past_key", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "past_value", - TensorProto.FLOAT16, - [ - config.batch_size, - past_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else past_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "seqlens_k", - TensorProto.INT32, - [config.batch_size], - ), - helper.make_tensor_value_info( - "total_sequence_length", - TensorProto.INT32, - [1], - ), - ] - if not packed: - graph_input += [ - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - ] - if rotary: - graph_input += [ - helper.make_tensor_value_info( - "cos_cache", - TensorProto.FLOAT16, - [ - config.kv_sequence_length + (0 if share_buffer else config.sequence_length), - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - helper.make_tensor_value_info( - "sin_cache", - TensorProto.FLOAT16, - [ - config.kv_sequence_length + (0 if share_buffer else config.sequence_length), - (math.floor(config.head_size / 16) * 16) // 2, - ], - ), - ] - - graph_output = [ - helper.make_tensor_value_info( - "output", - TensorProto.FLOAT16, - [config.batch_size, config.sequence_length, config.num_heads * config.head_size], - ), - helper.make_tensor_value_info( - "present_key", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - helper.make_tensor_value_info( - "present_value", - TensorProto.FLOAT16, - [ - config.batch_size, - present_kv_seqlen if past_kv_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_kv_format == Formats.BSNH else present_kv_seqlen, - config.head_size, - ], - ), - ] - - graph = helper.make_graph( - nodes, - "GroupQueryAttention_Graph", - graph_input, - graph_output, - ) - - model = helper.make_model(graph) - return model.SerializeToString() - - -def rotary_options_for_current_os(): - # Reference implementation of rotary uses triton, which is not available in Windows. 
- # So we only test rotary in Linux right now. - return [(False, False)] if platform.system() != "Linux" else [(True, False), (True, True), (False, False)] - - -def gqa_prompt_func( - q, - k, - v, - config, - new_k, - new_v, - cos=None, - sin=None, - seqlens_k=None, - window_size=-1, - past_kv_format=Formats.BSNH, - share_buffer=True, - rotary_interleaved=False, - softcap=0.0, - use_smooth_softmax=False, -): - onnx_model_str = create_group_query_attention_graph_prompt( - config, - past_kv_format, - share_buffer, - local_window_size=window_size, - rotary=cos is not None, - rotary_interleaved=rotary_interleaved, - packed=new_k is None, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) - past_k = k.clone() if share_buffer else None - past_v = v.clone() if share_buffer else None - if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) - if share_buffer: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), - "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_input( - "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() - ) - io_binding.bind_input( - "past_value", - "cuda", - 0, - numpy.float16, - ort_inputs["past_value"].shape(), - ort_inputs["past_value"].data_ptr(), - ) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) - io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - else: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() 
- ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_output("present_key") - io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - - -def gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos=None, - sin=None, - seqlens_k=None, - past_kv_format=Formats.BSNH, - share_buffer=True, - window_size=-1, - rotary_interleaved=False, - softcap=0.0, - use_smooth_softmax=False, -): - onnx_model_str = create_group_query_attention_graph_past( - config, - past_kv_format, - share_buffer, - local_window_size=window_size, - rotary=cos is not None, - rotary_interleaved=rotary_interleaved, - packed=new_k is None, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) - past_k = k.clone() - past_v = v.clone() - if new_k is not None: - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) - if share_buffer: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), - "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor([config.kv_sequence_length], dtype=torch.int32) - .detach() - .cpu() - .numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_input( - "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() - ) - io_binding.bind_input( - "past_value", - "cuda", - 0, - numpy.float16, - ort_inputs["past_value"].shape(), - ort_inputs["past_value"].data_ptr(), - ) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_ortvalue_output("present_key", ort_inputs["past_key"]) - 
io_binding.bind_ortvalue_output("present_value", ort_inputs["past_value"]) - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - else: - ort_inputs = { - "query": q.detach().cpu().numpy(), - "past_key": past_k.detach().cpu().numpy(), - "past_value": past_v.detach().cpu().numpy(), - "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), - "total_sequence_length": torch.tensor( - [config.kv_sequence_length + config.sequence_length], dtype=torch.int32 - ) - .detach() - .cpu() - .numpy(), - } - sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=[config.ep]) - io_binding = ort_session.io_binding() - if new_k is not None: - ort_inputs["key"] = new_k.detach().cpu().numpy() - ort_inputs["value"] = new_v.detach().cpu().numpy() - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) - if cos is not None: - ort_inputs["cos_cache"] = cos.detach().cpu().numpy() - ort_inputs["sin_cache"] = sin.detach().cpu().numpy() - io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) - io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) - io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) - io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) - io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) - io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) - io_binding.bind_output("output") - io_binding.bind_output("present_key") - io_binding.bind_output("present_value") - ort_session.run_with_iobinding(io_binding) - ort_output, present_k, present_v = io_binding.copy_outputs_to_cpu() - ort_output = numpy.array(ort_output) - output = torch.tensor(ort_output) - return output, present_k, present_v - - -def construct_local_mask( - seqlen_q, - seqlen_k, - window_size=(-1, -1), # -1 means infinite window size - query_padding_mask=None, - key_padding_mask=None, - device=None, -): - row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") - col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) - sk = seqlen_k if key_padding_mask is None else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") - sq = seqlen_q if query_padding_mask is None else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") - if window_size[0] < 0: - return col_idx > row_idx + sk - sq + window_size[1] - else: - sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return torch.logical_or( - col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), - col_idx < row_idx + sk - sq - window_size[0], - ) - - -def attention_ref( - q, - k, - v, - query_padding_mask=None, - key_padding_mask=None, - dropout_p=0.0, - dropout_mask=None, - causal=False, - window_size=(-1, -1), # -1 means infinite window size - softcap=0.0, - upcast=True, - reorder_ops=False, - use_smooth_softmax=False, -): - """ - Arguments: - q: (batch_size, seqlen_q, nheads, head_dim) - k: (batch_size, seqlen_k, nheads_k, head_dim) - v: (batch_size, seqlen_k, nheads_k, head_dim) - query_padding_mask: (batch_size, seqlen_q) - key_padding_mask: (batch_size, seqlen_k) - dropout_p: float - dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) - causal: whether to apply causal masking - 
window_size: (int, int), left and right window size - upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast - output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) - without changing the math. This is to estimate the numerical error from operation - reordering. - Output: - output: (batch_size, seqlen_q, nheads, head_dim) - attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout - """ - if causal: - window_size = (window_size[0], 0) - dtype_og = q.dtype - if upcast: - q, k, v = q.float(), k.float(), v.float() - seqlen_q, seqlen_k = q.shape[1], k.shape[1] - k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) - v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) - d = q.shape[-1] - if not reorder_ops: - scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) - else: - scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) - if softcap > 0: - scores = scores / softcap - scores = scores.tanh() - scores = scores * softcap - if key_padding_mask is not None: - scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) - if window_size[0] >= 0 or window_size[1] >= 0: - local_mask = construct_local_mask( - seqlen_q, - seqlen_k, - window_size, - query_padding_mask, - key_padding_mask, - q.device, - ) - scores.masked_fill_(local_mask, float("-inf")) - - if use_smooth_softmax: - head_sink = None - attention = smooth_softmax_ref(scores, head_sink) - else: - attention = torch.softmax(scores, dim=-1) - - # Some rows might be completely masked out so we fill them with zero instead of NaN - if window_size[0] >= 0 or window_size[1] >= 0: - attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) - # We want to mask here so that the attention matrix doesn't have any NaNs - # Otherwise we'll get NaN in dV - if query_padding_mask is not None: - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) - dropout_scaling = 1.0 / (1 - dropout_p) - if dropout_mask is not None: - attention_drop = attention.masked_fill(~dropout_mask, 0.0) - else: - attention_drop = attention - output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) - if query_padding_mask is not None: - output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) - - -def rotary_embedding(*args, **kwargs): - # Use local import since triton is not available in Windows. 
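# The rotary_flash import below provides the triton-based reference used on Linux. For orientation, a
# minimal torch sketch of what the non-interleaved case is expected to compute, assuming x has shape
# (batch, seq, heads, head_dim), cos/sin have shape (seq, rotary_dim // 2), and seqlen_offsets is zero
# (the offset handling of the real implementation is omitted here):
def _rotary_reference_sketch(x, cos, sin):
    import torch

    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    cos = cos[: x.shape[1]][None, :, None, :]  # broadcast over batch and heads
    sin = sin[: x.shape[1]][None, :, None, :]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)  # dims past rotary_dim pass through unchanged
# The interleaved=True variant pairs (x[..., 0::2], x[..., 1::2]) instead of the two half-blocks.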
- from rotary_flash import apply_rotary_emb # noqa: PLC0415 - - return apply_rotary_emb(*args, **kwargs) - - -def parity_check_gqa_prompt( - config: PromptConfig, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.q_sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.buffer_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.buffer_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - # cache_seqlens = torch.randint( - # 0, - # config.kv_sequence_length, - # (config.batch_size,), - # dtype=torch.int32, - # device="cuda", - # ) - # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length - rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=rotary_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.q_sequence_length, - ) - # q_ro = q - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - 
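# The cache update below hinges on a broadcast trick: `arange < kv_seqlens_expanded` compares a
# (1, buffer_len) position index against a (batch, 1) column of valid lengths, producing a
# (batch, buffer_len) mask that selects the first kv_seqlens[b] slots of each row of the BSNH-layout
# reference cache. Boolean indexing flattens those (batch, seq) positions, which is why the new K/V are
# rearranged to "(b s) ..." before the assignment. A tiny self-contained example of the same mechanism:
import torch

_lens = torch.tensor([2, 3])                         # valid tokens per batch row
_pos = torch.arange(4)[None, :]                      # (1, buffer_len)
_mask = _pos < _lens[:, None]                        # [[T, T, F, F], [T, T, T, F]]
_cache = torch.zeros(2, 4)
_cache[_mask] = torch.arange(5.0)                    # five selected slots filled row by row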
kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") - update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - True, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_prompt_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - True, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def parity_check_gqa_prompt_no_buff( - config: PromptConfig, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.q_sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = new_k.clone() - v_cache_ref = new_v.clone() - # if past_format == Formats.BNSH: - # k_cache_ref = k_cache_ref.transpose(1, 2) - # v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = 
torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) - # cache_seqlens = torch.randint( - # 0, - # config.kv_sequence_length, - # (config.batch_size,), - # dtype=torch.int32, - # device="cuda", - # ) - # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length - rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=rotary_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.q_sequence_length, - ) - # q_ro = q - k_ro = rotary_embedding(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, k_cache_ref - k_cache_ref = k_ro - - brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - new_mask = brange < cache_seqlens_expanded - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - new_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_prompt_func( - packed_qkv, - None, - None, - config, - None, - None, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - False, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_prompt_func( - q, - None, - None, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - left_window_size, - past_format, - False, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}, use_smooth_softmax={use_smooth_softmax}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def 
parity_check_gqa_past( - config: Config, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) - left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - cache_seqlens = torch.randint( - 0, - config.kv_sequence_length - config.sequence_length + 1, - (config.batch_size,), - dtype=torch.int32, - device="cuda", - ) - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.sequence_length, - ) - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - cache_seqlens += config.sequence_length - 1 - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - past_format, - True, - left_window_size, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - past_format, - True, - left_window_size, - rotary_interleaved, - softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - # Make sure past-present buffer updating correctly - numpy.testing.assert_allclose( - present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose( - present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg - ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def parity_check_gqa_past_no_buff( - config: Config, - causal=True, - local=False, - past_format=Formats.BNSH, - rotary=False, - rotary_interleaved=False, - packed=False, - softcap=0.0, - use_smooth_softmax=False, - rtol=1e-3, - atol=1e-3, -): - torch.manual_seed(69) - q = torch.randn( - config.batch_size, - config.sequence_length, - config.num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - k = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - v = torch.randn( - config.batch_size, - config.kv_sequence_length if past_format == Formats.BSNH else config.kv_num_heads, - config.kv_num_heads if past_format == Formats.BSNH else config.kv_sequence_length, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_k = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - new_v = torch.randn( - config.batch_size, - config.sequence_length, - config.kv_num_heads, - config.head_size, - device="cuda", - dtype=torch.float16, - requires_grad=False, - ) - - window_size = (-1, -1) 
- left_window_size = -1 - if local: - left_window_size = random.randint(0, config.kv_sequence_length) - window_size = (left_window_size, 0) - elif causal: - left_window_size = -1 - window_size = (-1, 0) - - # Pytorch to compare - k_cache_ref = k.clone() - v_cache_ref = v.clone() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - k_cache_ref = torch.cat((k_cache_ref, new_k), 1) - v_cache_ref = torch.cat((v_cache_ref, new_v), 1) - cache_seqlens = torch.randint( - 0, - config.kv_sequence_length, - (config.batch_size,), - dtype=torch.int32, - device="cuda", - ) - cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length - - if rotary: - rotary_fraction = 1.0 - rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 - angle = ( - torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi - ) - cos = torch.cos(angle).to(dtype=torch.float16) - sin = torch.sin(angle).to(dtype=torch.float16) - if causal or local: - q_ro = rotary_embedding(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - q_ro = rearrange( - rotary_embedding( - rearrange(q, "b s h d -> b 1 (s h) d"), - cos, - sin, - seqlen_offsets=cache_seqlens, - interleaved=rotary_interleaved, - ), - "b 1 (s h) d -> b s h d", - s=config.sequence_length, - ) - k_ro = rotary_embedding(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) - else: - cos, sin = None, None - q_ro, k_ro = q, new_k - - arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") - update_mask = torch.logical_and( - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length - ) - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") - v_cache_ref[update_mask] = rearrange(new_v, "b s ... 
-> (b s) ...") - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) - key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length - out_ref, _ = attention_ref( - q_ro, - k_cache_rep, - v_cache_rep, - None, - key_padding_mask, - 0.0, - None, - causal=True, - window_size=window_size, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out_ref = out_ref.detach().cpu().numpy() - if past_format == Formats.BNSH: - k_cache_ref = k_cache_ref.transpose(1, 2) - v_cache_ref = v_cache_ref.transpose(1, 2) - - cache_seqlens += config.sequence_length - 1 - - # Flash function - if packed: - packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) - out, present_k, present_v = gqa_past_func( - packed_qkv, - k, - v, - config, - None, - None, - cos, - sin, - cache_seqlens, - past_format, - False, - window_size=left_window_size, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - else: - out, present_k, present_v = gqa_past_func( - q, - k, - v, - config, - new_k, - new_v, - cos, - sin, - cache_seqlens, - past_format, - False, - window_size=left_window_size, - rotary_interleaved=rotary_interleaved, - softcap=softcap, - use_smooth_softmax=use_smooth_softmax, - ) - out = torch.squeeze(out, 0) - out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) - out = out.detach().cpu().numpy() - - err_msg = ( - f" with {config}, causal={causal}, local={local}, past_format={past_format}," - f" rotary={rotary}, rotary_interleaved={rotary_interleaved}, packed={packed}, softcap={softcap}" - ) - for b in range(config.batch_size): - numpy.testing.assert_allclose( - present_k[b, :, : (cache_seqlens + 1)[b]], - k_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - err_msg=err_msg, - ) - numpy.testing.assert_allclose( - present_v[b, :, : (cache_seqlens + 1)[b]], - v_cache_ref[b, :, : (cache_seqlens + 1)[b]].detach().cpu().numpy(), - rtol=rtol, - atol=atol, - equal_nan=True, - err_msg=err_msg, - ) - numpy.testing.assert_allclose(out, out_ref, rtol=rtol, atol=atol, equal_nan=True, err_msg=err_msg) - - -def has_flash_attention(): - if not torch.cuda.is_available(): - return False - if "CUDAExecutionProvider" not in get_available_providers(): - return False - major, _ = torch.cuda.get_device_capability() - return major >= 8 and ( - platform.system() == "Linux" - or (platform.system() == "Windows" and version.parse(torch.version.cuda) >= version.parse("12.0")) - ) - - -def has_memory_efficient(): - if not torch.cuda.is_available(): - return False - if "CUDAExecutionProvider" not in get_available_providers(): - return False - major, minor = torch.cuda.get_device_capability() - if major < 5 or (major == 5 and minor < 3): - return False - return True - - -def gqa_no_past_memory_efficient_test_cases(): - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (2000, 2000), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - torch.manual_seed(69) - - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for 
rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - if rotary and h % 16 > 0: - continue - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_no_past_flash_attention_test_cases(): - batches = [3] if pipeline_mode else [1, 3, 5] - seqs = ( - [ - (240, 240), - ] - if pipeline_mode - else [ - (127, 127), - (35, 35), - (2000, 2000), - (200, 200), - (240, 240), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [128] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - torch.manual_seed(69) - - for b in batches: - for sq, skv in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_past_memory_efficient_test_cases(): - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 1024)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_past_flash_attention_test_cases(): - batches = [5] if pipeline_mode else [1, 3, 5] - seqs = ( - [(1, 2048)] - if pipeline_mode - else [ - (1, 128), - (1, 339), - (1, 1024), - (1, 5000), - (1, 800), - (1, 256), - (1, 799), - (1, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(32, 8)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - for softcap in [0.0, 50.0]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}_{softcap}", - config, - local, - rotary, - rotary_interleaved, - packed, - softcap, - ) - - -def gqa_interactive_one_batch_flash_attention_test_cases(): - batches = [1] - seqs = ( - [(128, 2048)] - if pipeline_mode - else [ - (1, 128), - (32, 128), - (128, 2048), - (1235, 5000), - (40, 800), - (1, 256), - (2, 799), - (41, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # 
(128, 128), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for local in [False, True]: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{local}_{rotary}_{rotary_interleaved}_{packed}", - config, - local, - rotary, - rotary_interleaved, - packed, - ) - - -def gqa_interactive_one_batch_memory_efficient_attention_test_cases(): - batches = [1] - seqs = ( - [(32, 128)] - if pipeline_mode - else [ - (1, 128), - (32, 128), - (128, 2048), - (1235, 5000), - (40, 800), - (1, 256), - (2, 799), - (41, 2048), - # (1, 128 * 512), - # (16, 128 * 512), - # (128, 128), - ] - ) - num_h = [(9, 3)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] - h_sizes = [64] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] - random.seed(69) - - for b in batches: - for s, s2 in seqs: - for n, n2 in num_h: - for h in h_sizes: - for rotary, rotary_interleaved in rotary_options_for_current_os(): - for packed in [False, True]: - if rotary and h % 16 > 0: - continue - - config = Config(b, s, s2, n, n2, h) - yield ( - str(config) + f"{rotary}_{rotary_interleaved}_{packed}", - config, - rotary, - rotary_interleaved, - packed, - ) - - -@unittest.skipIf(not has_flash_attention(), reason="Flash Attention is not available, skipping tests.") -class TestFlashGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - print("------- FLASH ATTENTION (PROMPT CASE) --------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_prompt( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - print("------- FLASH ATTENTION (TOKEN GEN) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - - @parameterized.expand(gqa_interactive_one_batch_flash_attention_test_cases()) - def test_gqa_interactive_one_batch_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): - print("------- FLASH ATTENTION (INTERACTIVE) -------") - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - 
rotary_interleaved=rotary_interleaved, - packed=packed, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - - -@unittest.skipIf(not has_memory_efficient(), reason="Memory efficient FMHA is not available, skipping tests.") -class TestMemoryEfficientGQA(unittest.TestCase): - @parameterized.expand(gqa_no_past_memory_efficient_test_cases()) - def test_gqa_no_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("------- MEMORY EFFICIENT ATTENTION (PROMPT CASE) ---------") - - parity_check_gqa_prompt( - config, - local=local, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - parity_check_gqa_prompt_no_buff( - config, - local=local, - rtol=5e-3, - atol=5e-3, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - - @parameterized.expand(gqa_past_memory_efficient_test_cases()) - def test_gqa_past_memory_efficient(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=True, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rtol=1e-3, - atol=1e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - softcap=softcap, - use_smooth_softmax=False, - ) - - @parameterized.expand(gqa_interactive_one_batch_memory_efficient_attention_test_cases()) - def test_gqa_interactive_one_batch_memory_efficient_attention(self, _, config, rotary, rotary_interleaved, packed): - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" - print("-------- MEMORY EFFICIENT (INTERACTIVE) --------") - - parity_check_gqa_past( - config, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - parity_check_gqa_past_no_buff( - config, - past_format=Formats.BNSH, - rtol=5e-3, - atol=5e-3, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa_rocm.py b/onnxruntime/test/python/transformers/test_gqa_rocm.py deleted file mode 100644 index 29ae1b6e44a78..0000000000000 --- a/onnxruntime/test/python/transformers/test_gqa_rocm.py +++ /dev/null @@ -1,81 +0,0 @@ -import platform -import unittest - -import torch -from parameterized import parameterized -from test_gqa_cuda import ( - Formats, - gqa_no_past_flash_attention_test_cases, - gqa_past_flash_attention_test_cases, - parity_check_gqa_past, - parity_check_gqa_past_no_buff, - parity_check_gqa_prompt, - parity_check_gqa_prompt_no_buff, -) - -import onnxruntime - - -@unittest.skipIf( - (not torch.cuda.is_available()) - or (platform.system() != "Linux") - or ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()), - reason="ROCm is not available, skipping tests.", -) -class TestRocmGQA(unittest.TestCase): - 
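# This removed ROCm suite reused the CUDA test-case generators and parity checks unchanged; the only
# ROCm-specific step was retargeting the shared config before each check, since the gqa_* helpers build
# their InferenceSession with providers=[config.ep]. Roughly:
#
#     config.ep = "ROCMExecutionProvider"
#     parity_check_gqa_prompt(config, past_format=Formats.BNSH, rtol=0.001, atol=0.005)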
@parameterized.expand(gqa_no_past_flash_attention_test_cases()) - def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - config.ep = "ROCMExecutionProvider" - print("------- FLASH ATTENTION (PROMPT CASE) --------") - - parity_check_gqa_prompt( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - parity_check_gqa_prompt_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - @parameterized.expand(gqa_past_flash_attention_test_cases()) - def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed, softcap): - config.ep = "ROCMExecutionProvider" - print("------- FLASH ATTENTION (TOKEN GEN) -------") - - parity_check_gqa_past( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=Formats.BNSH, - rotary=rotary, - rotary_interleaved=rotary_interleaved, - packed=packed, - rtol=0.001, - atol=0.005, - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/onnxruntime/test/python/transformers/test_mha_flash_attn.py b/onnxruntime/test/python/transformers/test_mha_flash_attn.py index f87370e37d21a..a015ce6979f91 100644 --- a/onnxruntime/test/python/transformers/test_mha_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_mha_flash_attn.py @@ -12,7 +12,7 @@ from einops import rearrange, repeat from onnx import TensorProto, helper from parameterized import parameterized -from test_gqa_cuda import attention_ref, has_flash_attention +from test_gqa import attention_ref, has_flash_attention from onnxruntime import InferenceSession, SessionOptions @@ -303,24 +303,16 @@ def mha_func(q, k, v, config): def attention_qkvpacked_ref( qkv, key_padding_mask=None, - dropout_p=0.0, - dropout_mask=None, causal=False, - upcast=True, - reorder_ops=False, use_smooth_softmax=False, ): return attention_ref( qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2], - key_padding_mask, - key_padding_mask, - dropout_p, - dropout_mask, - upcast=upcast, + query_padding_mask=key_padding_mask, + key_padding_mask=key_padding_mask, causal=causal, - reorder_ops=reorder_ops, use_smooth_softmax=use_smooth_softmax, ) @@ -344,7 +336,7 @@ def parity_check_mha( ) out = out.detach().cpu().numpy() # Pytorch to compare - out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) + out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, causal=False) out_ref = out_ref.detach().cpu().numpy() else: q = torch.randn( diff --git a/onnxruntime/test/shared_lib/custom_op_utils.cc b/onnxruntime/test/shared_lib/custom_op_utils.cc index 432d78927a1ab..919847723789e 100644 --- a/onnxruntime/test/shared_lib/custom_op_utils.cc +++ b/onnxruntime/test/shared_lib/custom_op_utils.cc @@ -453,8 +453,9 @@ void StandaloneCustomKernel::InitGru() { float betas[1] = {2.f}; Ort::OpAttr activation_beta = Ort::OpAttr("activation_beta ", betas, 1, OrtOpAttrType::ORT_OP_ATTR_FLOATS); - const char* direction_string = "bidirectional"; - Ort::OpAttr direction = Ort::OpAttr("direction", direction_string, 1, OrtOpAttrType::ORT_OP_ATTR_STRING); + const std::string direction_string = "bidirectional"; + Ort::OpAttr direction = Ort::OpAttr("direction", 
direction_string.c_str(), static_cast(direction_string.length()), + OrtOpAttrType::ORT_OP_ATTR_STRING); int64_t linear_before_reset_value = 0; Ort::OpAttr linear_before_reset = Ort::OpAttr("linear_before_reset", &linear_before_reset_value, 1, diff --git a/onnxruntime/test/shared_lib/test_data_copy.cc b/onnxruntime/test/shared_lib/test_data_copy.cc new file mode 100644 index 0000000000000..c09dbda745b76 --- /dev/null +++ b/onnxruntime/test/shared_lib/test_data_copy.cc @@ -0,0 +1,203 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +#include "core/graph/constants.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/util/include/api_asserts.h" +#include "test/shared_lib/utils.h" + +#ifdef USE_CUDA +#include "core/providers/cuda/cuda_provider_options.h" +#include +#endif + +extern std::unique_ptr ort_env; + +namespace onnxruntime { +namespace test { + +using StreamUniquePtr = std::unique_ptr>; + +#ifdef USE_CUDA +// test copying input to CUDA using an OrtEpFactory based EP. +// tests GetSharedAllocator, CreateSyncStreamForEpDevice and CopyTensors APIs +TEST(PluginEpDataCopyTest, CopyInputsToCudaDevice) { + OrtEnv* env = *ort_env; + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + +#ifdef _WIN32 + std::string cuda_lib = "onnxruntime_providers_cuda.dll"; +#else + std::string cuda_lib = "onnxruntime_providers_cuda.so"; +#endif + + if (!std::filesystem::exists(cuda_lib)) { + GTEST_SKIP() << "CUDA library was not found"; + } + + // register the provider bridge based CUDA EP so allocator and data transfer is available + // not all the CIs have the provider library in the expected place so we allow for that + const char* ep_registration_name = "ORT CUDA"; + ASSERT_ORTSTATUS_OK(api->RegisterExecutionProviderLibrary(env, ep_registration_name, + ORT_TSTR("onnxruntime_providers_cuda"))); + + const OrtEpDevice* cuda_device = nullptr; + for (const auto& ep_device : ort_env->GetEpDevices()) { + std::string vendor{ep_device.EpVendor()}; + std::string name = {ep_device.EpName()}; + if (vendor == std::string("Microsoft") && name == kCudaExecutionProvider) { + cuda_device = ep_device; + break; + } + } + + if (!cuda_device) { // device running tests may not have an nvidia card + return; + } + + const auto run_test = [&](bool use_streams) { + Ort::SessionOptions options; + options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); + + // we pass in the CUDA cudaStream_t from the OrtSyncStream via provider options so need to create it upfront. + // in the future the stream should be an input to the Session Run. + OrtSyncStream* stream = nullptr; + StreamUniquePtr stream_ptr; + if (use_streams) { + ASSERT_ORTSTATUS_OK(api->CreateSyncStreamForEpDevice(cuda_device, /*options*/ nullptr, &stream)); + stream_ptr = StreamUniquePtr(stream, [api](OrtSyncStream* stream) { api->ReleaseSyncStream(stream); }); + + size_t stream_addr = reinterpret_cast(api->SyncStream_GetHandle(stream)); + options.AddConfigEntry("ep.cudaexecutionprovider.user_compute_stream", std::to_string(stream_addr).c_str()); + // we explicitly specify user_compute_stream, so why do we also need to set has_user_compute_stream? 
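// For comparison, the Python API surfaces the same pair of settings as CUDA EP provider options, where
// setting user_compute_stream alone is typically sufficient. A hedged Python sketch, using a torch CUDA
// stream as the external compute stream (the model path is a placeholder):
//
//   import torch
//   import onnxruntime as ort
//
//   s = torch.cuda.Stream()
//   providers = [("CUDAExecutionProvider", {"user_compute_stream": str(s.cuda_stream)})]
//   sess = ort.InferenceSession("model.onnx", providers=providers)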
+ options.AddConfigEntry("ep.cudaexecutionprovider.has_user_compute_stream", "1"); + } + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mnist.onnx"), options); + + size_t num_inputs = session.GetInputCount(); + + // find the input location so we know which inputs can be provided on device. + std::vector input_locations; + input_locations.resize(num_inputs, nullptr); + ASSERT_ORTSTATUS_OK(api->SessionGetMemoryInfoForInputs(session, input_locations.data(), num_inputs)); + + std::vector cpu_tensors; + + // info for device copy + std::vector src_tensor_ptrs; + std::vector dst_tensor_ptrs; + + // values we'll call Run with + std::vector input_tensors; + + ASSERT_EQ(num_inputs, 1); + + // create cpu based input data. + Ort::AllocatorWithDefaultOptions cpu_allocator; + std::vector shape{1, 1, 28, 28}; + std::vector input_data(28 * 28, 0.5f); + Ort::Value input_value = Ort::Value::CreateTensor(cpu_allocator.GetInfo(), + input_data.data(), input_data.size(), + shape.data(), shape.size()); + cpu_tensors.push_back(std::move(input_value)); + + for (size_t idx = 0; idx < num_inputs; ++idx) { + const OrtMemoryInfo* mem_info = input_locations[idx]; + OrtDeviceMemoryType mem_type; + OrtMemoryInfoDeviceType device_type; + ASSERT_ORTSTATUS_OK(api->MemoryInfoGetDeviceMemType(mem_info, &mem_type)); + api->MemoryInfoGetDeviceType(mem_info, &device_type); + + if (device_type == OrtMemoryInfoDeviceType_GPU && mem_type == OrtDeviceMemoryType_DEFAULT) { + // copy to device + OrtAllocator* allocator = nullptr; + ASSERT_ORTSTATUS_OK(api->GetSharedAllocator(env, mem_info, &allocator)); + + // allocate new on-device memory + auto src_shape = cpu_tensors[idx].GetTensorTypeAndShapeInfo().GetShape(); + Ort::Value device_value = Ort::Value::CreateTensor(allocator, src_shape.data(), src_shape.size()); + + /* if you have existing memory on device use one of these instead of CreateTensorAsOrtValue + CopyTensors + void* existing_data; + size_t data_length = 128 * sizeof(float); + api->CreateTensorWithDataAsOrtValue(input_locations[0], existing_data, data_length, shape, 2, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); + + // existing with ownership transfer. ORT will use the allocator to free the memory once it is no longer required + api->CreateTensorWithDataAndDeleterAsOrtValue(allocator, existing_data, data_length, shape, 2, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &value); + */ + + src_tensor_ptrs.push_back(cpu_tensors[idx]); + dst_tensor_ptrs.push_back(device_value); + input_tensors.push_back(std::move(device_value)); + } else { + // input is on CPU accessible memory. move to input_tensors + input_tensors.push_back(std::move(cpu_tensors[idx])); + } + } + + if (!src_tensor_ptrs.empty()) { + ASSERT_ORTSTATUS_OK(api->CopyTensors(env, src_tensor_ptrs.data(), dst_tensor_ptrs.data(), stream, + src_tensor_ptrs.size())); + + // Stream support is still a work in progress. + // + // CUDA EP can use a user provided stream via provider options, so we can pass in the cudaStream_t from the + // OrtSyncStream used in CopyTensors call that way. + // + // Alternatively you can manually sync the device via IoBinding. 
+ // Ort::IoBinding iobinding(session); + // iobinding.SynchronizeInputs(); // this doesn't actually require any bound inputs + } + + std::vector input_names = {"Input3"}; + std::vector output_names = {"Plus214_Output_0"}; + Ort::Value output; + + session.Run(Ort::RunOptions{}, input_names.data(), input_tensors.data(), input_tensors.size(), + output_names.data(), &output, 1); + + const float* results = nullptr; + ASSERT_ORTSTATUS_OK(api->GetTensorData(output, reinterpret_cast(&results))); + + // expected results from the CPU EP. can check/re-create by running with PREFER_CPU. + std::vector expected = { + -0.701670527f, + -0.583666623f, + 0.0480501056f, + 0.550699294f, + -1.25372827f, + 1.17879760f, + 0.838122189f, + -1.51267099f, + 0.902430952f, + 0.243748352f, + }; + + for (size_t i = 0; i < expected.size(); ++i) { + EXPECT_NEAR(expected[i], results[i], 1e-3) << "i=" << i; + } + }; + + run_test(/*use_streams*/ true); + run_test(/*use_streams*/ false); + + ASSERT_ORTSTATUS_OK(api->UnregisterExecutionProviderLibrary(env, ep_registration_name)); +} +#endif // USE_CUDA + +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa25e3f31166a..202aa61da0b80 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -60,7 +60,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 resources: repositories: diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index ab779e164b36e..74f7f782fe1b2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -19,7 +19,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: enableWebGpu displayName: Enable WebGPU test diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml index 110f83ff587c8..92e862bd79008 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar.yml @@ -53,7 +53,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: '2.36.0.250627' + default: '2.36.1.250708' - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 535784933a087..5b48a14e2afc3 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -47,7 +47,7 @@ parameters: - name: QnnSDKVersion displayName: QNN SDK Version type: string - default: 2.36.0.250627 + default: 2.36.1.250708 - name: is1ES displayName: Is 1ES pipeline diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 
9815e1ac94d24..5eef1ae8e8e93 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -34,7 +34,7 @@ stages: buildSettingsFile: "tools/ci_build/github/apple/default_training_ios_framework_build_settings.json" cPodName: onnxruntime-training-c objcPodName: onnxruntime-training-objc - timeoutInMinutes: 240 + timeoutInMinutes: 270 templateContext: outputs: - output: pipelineArtifact diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index 607b3e693b624..fe6c00f99323f 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -19,8 +19,8 @@ RUN dnf install -y --nodocs \ && dnf clean all \ && rm -rf /var/cache/dnf -ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2025.1.0 -ARG OPENVINO_PACKAGE_URL=https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.1/linux/openvino_toolkit_rhel8_2025.1.0.18503.6fec06580ab_x86_64.tgz +ENV INTEL_OPENVINO_DIR=/opt/intel/openvino_2025.2.0 +ARG OPENVINO_PACKAGE_URL=https://storage.openvinotoolkit.org/repositories/openvino/packages/2025.2/linux/openvino_toolkit_rhel8_2025.2.0.19140.c01cd93e24d_x86_64.tgz ARG TEMP_DIR=/tmp/openvino_installer RUN mkdir -p ${TEMP_DIR} && \ diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index 223f1f46f2e70..bb3d82f074786 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -6,7 +6,7 @@ numpy==1.21.6 ; python_version < '3.9' numpy==2.0.0 ; python_version >= '3.9' torch>=2.6.0 coloredlogs==15.0 -transformers==4.48.0 +transformers==4.52.1 parameterized>=0.8.1 sentencepiece psutil