diff --git a/.gitmodules b/.gitmodules index b5bff01d89850..a48c4062a90fe 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.8 + branch = 4.0.11 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b0941b4d0c922..a76be16572a03 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -562,6 +562,7 @@ else() check_cxx_compiler_flag(-Wcast-function-type HAS_CAST_FUNCTION_TYPE) check_cxx_compiler_flag(-Wcatch-value HAS_CATCH_VALUE) check_cxx_compiler_flag(-Wclass-memaccess HAS_CLASS_MEMACCESS) + check_cxx_compiler_flag(-Wcharacter-conversion HAS_CHARACTER_CONVERSION) check_cxx_compiler_flag(-Wdangling-reference HAS_DANGLING_REFERENCE) check_cxx_compiler_flag(-Wdeprecated-anon-enum-enum-conversion HAS_DEPRECATED_ANON_ENUM_ENUM_CONVERSION) check_cxx_compiler_flag(-Wdeprecated-builtins HAS_DEPRECATED_BUILTINS) diff --git a/cmake/deps.txt b/cmake/deps.txt index 01e5c809640f9..ed1de06f33dcb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -27,8 +27,8 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 -#xnnpack 2024.09.04 -googlexnnpack;https://github.com/google/XNNPACK/archive/fe98e0b93565382648129271381c14d6205255e3.zip;14f61dcf17cec2cde34ba2dcf61d6f24bf6059f3 +#xnnpack 2025.06.22 +googlexnnpack;https://github.com/google/XNNPACK/archive/3cf85e705098622d59056dcb8f5f963ea7bb0a00.zip;6f6bbba627241f89463ca845febaf063982b34fe json;https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip;5e88795165cc8590138d1f47ce94ee567b85b4d6 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 @@ -45,9 +45,9 @@ protoc_linux_x86;https://github.com/protocolbuffers/protobuf/releases/download/v protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-linux-aarch_64.zip;df9d45470b0b8cf939dd2f0ec6b88e9cafc4d617 protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 -pthreadpool;https://github.com/google/pthreadpool/archive/4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0.zip;bd4ea65c8292801e9555b527a0ecbb2e0092c917 +pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9 -pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6.zip;85bf8a60dae026b99b6ccd78606c85ed83bfb2cd +pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/de0ce7c7251372892e53ce9bc891750d2c9a4fd8.zip;c45b8d3619b9bccbd26dc5f657959aee38b18b7a re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381 diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 419021fa04042..d49219d03a41c 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 419021fa040428bc69ef1559b325addb8e10211f +Subproject commit d49219d03a41cd12f95a33ba84273c20d41fd350 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index f76ad642447ba..84e3132b1deec 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -367,23 +367,23 @@ if (CPUINFO_SUPPORTED) set(CPUINFO_BUILD_MOCK_TESTS OFF CACHE INTERNAL "") set(CPUINFO_BUILD_BENCHMARKS OFF CACHE INTERNAL "") if (onnxruntime_target_platform STREQUAL "ARM64EC" OR onnxruntime_target_platform STREQUAL "ARM64") - message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") - onnxruntime_fetchcontent_declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - EXCLUDE_FROM_ALL - PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + message(STATUS "Applying a patch for Windows ARM64/ARM64EC in cpuinfo") + onnxruntime_fetchcontent_declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + EXCLUDE_FROM_ALL + PATCH_COMMAND ${Patch_EXECUTABLE} -p1 < ${PROJECT_SOURCE_DIR}/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch + FIND_PACKAGE_ARGS NAMES cpuinfo + ) else() - onnxruntime_fetchcontent_declare( - pytorch_cpuinfo - URL ${DEP_URL_pytorch_cpuinfo} - URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} - EXCLUDE_FROM_ALL - FIND_PACKAGE_ARGS NAMES cpuinfo - ) + onnxruntime_fetchcontent_declare( + pytorch_cpuinfo + URL ${DEP_URL_pytorch_cpuinfo} + URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo} + EXCLUDE_FROM_ALL + FIND_PACKAGE_ARGS NAMES cpuinfo + ) endif() set(ONNXRUNTIME_CPUINFO_PROJ pytorch_cpuinfo) onnxruntime_fetchcontent_makeavailable(${ONNXRUNTIME_CPUINFO_PROJ}) @@ -570,7 +570,7 @@ if (onnxruntime_USE_XNNPACK) ENDIF() ADD_LIBRARY(xnnpack STATIC IMPORTED) find_library(xnnpack_LIBRARY NAMES XNNPACK) - find_library(microkernels_prod_LIBRARY NAMES microkernels-prod) + find_library(microkernels_prod_LIBRARY NAMES xnnpack-microkernels-prod) find_package(unofficial-pthreadpool CONFIG REQUIRED) target_include_directories(xnnpack INTERFACE "${XNNPACK_HDR}") diff --git a/cmake/external/xnnpack.cmake b/cmake/external/xnnpack.cmake index d0ab770053be1..c994e7e15aac4 100644 --- a/cmake/external/xnnpack.cmake +++ b/cmake/external/xnnpack.cmake @@ -90,7 +90,7 @@ onnxruntime_fetchcontent_makeavailable(googlexnnpack) set(XNNPACK_DIR ${googlexnnpack_SOURCE_DIR}) set(XNNPACK_INCLUDE_DIR ${XNNPACK_DIR}/include) -set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK microkernels-prod pthreadpool) +set(onnxruntime_EXTERNAL_LIBRARIES_XNNPACK XNNPACK xnnpack-microkernels-prod pthreadpool) if(ORT_TARGET_PROCESSOR MATCHES "^arm64.*" AND NOT CMAKE_C_COMPILER_ID STREQUAL "MSVC") list(APPEND onnxruntime_EXTERNAL_LIBRARIES_XNNPACK kleidiai) endif() diff --git a/cmake/onnxruntime_session.cmake b/cmake/onnxruntime_session.cmake index 3ec3c6ee1d5ae..f81a7a9726b76 100644 --- a/cmake/onnxruntime_session.cmake +++ b/cmake/onnxruntime_session.cmake @@ -5,6 +5,8 @@ file(GLOB onnxruntime_session_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_INCLUDE_DIR}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.h" "${ONNXRUNTIME_ROOT}/core/session/*.cc" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.h" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.cc" ) if (onnxruntime_ENABLE_TRAINING_APIS) @@ -22,7 +24,7 @@ endif() # which is not enabled for any minimal builds. if (onnxruntime_MINIMAL_BUILD) file(GLOB autoep_srcs - "${ONNXRUNTIME_ROOT}/core/session/ep_*.*" + "${ONNXRUNTIME_ROOT}/core/session/plugin_ep/*.*" ) set(onnxruntime_session_src_exclude diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 96e513c8a7bc9..c3bebba3bab54 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -120,6 +120,9 @@ function(AddTest) if (${HAS_NOERROR}) target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=uninitialized>") endif() + if (${HAS_CHARACTER_CONVERSION}) + target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() set(TEST_ARGS ${_UT_TEST_ARGS}) @@ -787,6 +790,9 @@ if(MSVC) "$<$>:/wd6326>") else() target_include_directories(onnxruntime_test_utils PRIVATE ${CMAKE_CURRENT_BINARY_DIR} ${ONNXRUNTIME_ROOT}) + if (HAS_CHARACTER_CONVERSION) + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Wno-error=character-conversion>") + endif() endif() if (onnxruntime_USE_NCCL) target_include_directories(onnxruntime_test_utils PRIVATE ${NCCL_INCLUDE_DIRS}) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index ffe866164a411..e2d04843d858e 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -175,9 +175,9 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS "-sDISABLE_EXCEPTION_CATCHING=0") + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_CATCHING=0") endif() target_link_libraries(onnxruntime_webassembly PRIVATE @@ -241,11 +241,10 @@ else() "SHELL:-s FILESYSTEM=0" "SHELL:-s INCOMING_MODULE_JS_API=[locateFile,instantiateWasm,wasmBinary]" "SHELL:-s WASM_BIGINT=1" - ${WASM_API_EXCEPTION_CATCHING} --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) - + if (onnxruntime_USE_JSEP) # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU # This flag allows async functions to be called from sync functions, in the cost of binary size and @@ -256,7 +255,7 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" ) list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js") - + endif() if (onnxruntime_USE_WEBGPU) diff --git a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch b/cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch similarity index 53% rename from cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch rename to cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch index 7785621965b00..23ceeb8f758cc 100644 --- a/cmake/patches/cpuinfo/9bb12d342fd9479679d505d93a478a6f9cd50a47.patch +++ b/cmake/patches/cpuinfo/patch_cpuinfo_h_for_arm64ec.patch @@ -1,5 +1,5 @@ diff --git a/include/cpuinfo.h b/include/cpuinfo.h -index 6eb4b8c..4346a5a 100644 +index f1d35d4..9e454d2 100644 --- a/include/cpuinfo.h +++ b/include/cpuinfo.h @@ -18,7 +18,7 @@ @@ -20,16 +20,3 @@ index 6eb4b8c..4346a5a 100644 #define CPUINFO_ARCH_ARM64 1 #endif -diff --git a/src/arm/windows/init.c b/src/arm/windows/init.c -index de2f6cc..c3a7835 100644 ---- a/src/arm/windows/init.c -+++ b/src/arm/windows/init.c -@@ -175,7 +175,7 @@ static struct woa_chip_info* get_system_info_from_registry(void) { - if (chip_info == NULL) { - /* No match was found, so print a warning and assign the unknown - * case. */ -- cpuinfo_log_error( -+ cpuinfo_log_debug( - "Unknown chip model name '%ls'.\nPlease add new Windows on Arm SoC/chip support to arm/windows/init.c!", - text_buffer); - } else { diff --git a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch index c9cb4bcad9e20..ea0bb61274f84 100644 --- a/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch +++ b/cmake/patches/xnnpack/AddEmscriptenAndIosSupport.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..1e3cb8178 100644 +index 94bcad92e3..be7dfe95fd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -337,7 +337,7 @@ ENDIF() +@@ -360,7 +360,7 @@ ENDIF() # ---[ Build flags IF(NOT CMAKE_SYSTEM_NAME) MESSAGE(FATAL_ERROR "CMAKE_SYSTEM_NAME not defined") @@ -11,21 +11,30 @@ index f0b3410ae..1e3cb8178 100644 MESSAGE(FATAL_ERROR "Unrecognized CMAKE_SYSTEM_NAME value \"${CMAKE_SYSTEM_NAME}\"") ENDIF() IF(CMAKE_SYSTEM_NAME MATCHES "Windows") -@@ -848,7 +848,12 @@ IF(XNNPACK_BUILD_LIBRARY) - TARGET_LINK_LIBRARIES(operator-utils PRIVATE xnnpack-base logging) - TARGET_LINK_LIBRARIES(reference-ukernels PRIVATE xnnpack-base) - TARGET_LINK_LIBRARIES(subgraph PRIVATE xnnpack-base allocator logging memory mutex operators operator-run datatype) -- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels) +@@ -903,10 +903,18 @@ IF(XNNPACK_BUILD_LIBRARY) + TARGET_LINK_LIBRARIES(xnnpack-operator-utils PRIVATE xnnpack-base xnnpack-logging) + TARGET_LINK_LIBRARIES(xnnpack-reference-ukernels PRIVATE xnnpack-base xnnpack-datatype) + TARGET_LINK_LIBRARIES(xnnpack-subgraph PRIVATE xnnpack-base xnnpack-allocator xnnpack-logging xnnpack-memory xnnpack-mutex xnnpack-operators xnnpack-operator-run xnnpack-datatype) +- TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache +- xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init +- xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing +- xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + IF(CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # omit microkernels-prod as the list is manually created by ORT in cmake/external/xnnpack.cmake -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ELSE() -+ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base allocator cache hardware-config indirection memory microkernel-utils microparams-init mutex normalization operators operator-run operator-utils packing microkernels-prod subgraph datatype reference-ukernels) ++ TARGET_LINK_LIBRARIES(XNNPACK PRIVATE xnnpack-base xnnpack-allocator xnnpack-cache ++ xnnpack-hardware-config xnnpack-indirection xnnpack-memory xnnpack-microkernel-utils xnnpack-microparams-init ++ xnnpack-mutex xnnpack-normalization xnnpack-operators xnnpack-operator-run xnnpack-operator-utils xnnpack-pack-lh xnnpack-packing ++ xnnpack-microkernels-prod xnnpack-subgraph xnnpack-datatype xnnpack-reference-ukernels) + ENDIF() - TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool logging) + TARGET_LINK_LIBRARIES(XNNPACK PUBLIC pthreadpool xnnpack-logging) SET_TARGET_PROPERTIES(XNNPACK PROPERTIES C_EXTENSIONS YES) ENDIF() -@@ -857,7 +862,8 @@ IF(NOT MSVC) +@@ -915,7 +923,8 @@ IF(NOT MSVC) ENDIF() IF(XNNPACK_TARGET_PROCESSOR STREQUAL "arm") SET_PROPERTY(SOURCE ${ALL_MICROKERNEL_SRCS} APPEND_STRING PROPERTY COMPILE_FLAGS " -marm ") diff --git a/cmake/vcpkg-ports/cpuinfo/portfile.cmake b/cmake/vcpkg-ports/cpuinfo/portfile.cmake index e61308bf643b4..6722f10a72857 100644 --- a/cmake/vcpkg-ports/cpuinfo/portfile.cmake +++ b/cmake/vcpkg-ports/cpuinfo/portfile.cmake @@ -6,8 +6,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pytorch/cpuinfo - REF 8a1772a0c5c447df2d18edf33ec4603a8c9c04a6 - SHA512 b94ccbfa886221d6bb16513d074675af0a72928a9dd9485dcacdc1124a8a60aacbbe91913a1579e766dfb024f0be1d52eeead40342004ff0238a8b94a095ed08 + REF de0ce7c7251372892e53ce9bc891750d2c9a4fd8 + SHA512 0fde9210b700d2648d37c8deeb0d5c0d007d8ca5689578dd3bce4c460886b20d7649f0194d2ea06b02238fe9d4f06193599ec3ab5cafb19f1f860b00404264fa HEAD_REF master ) diff --git a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch index 97fd1ac7a2bb1..cf7df0ea22980 100644 --- a/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch +++ b/cmake/vcpkg-ports/pthreadpool/fix-cmakelists.patch @@ -1,8 +1,8 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f06aada..3c6c6e2 100644 +index efff8cc..1a0f7e9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -31,8 +31,6 @@ IF(CCACHE_BINARY) +@@ -41,8 +41,6 @@ IF(CMAKE_C_COMPILER_ID STREQUAL "MSVC") ENDIF() # ---[ Options. @@ -11,7 +11,7 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_ALLOW_DEPRECATED_API "Enable deprecated API functions" ON) SET(PTHREADPOOL_SYNC_PRIMITIVE "default" CACHE STRING "Synchronization primitive (condvar, futex, gcd, event, or default) for worker threads") SET_PROPERTY(CACHE PTHREADPOOL_SYNC_PRIMITIVE PROPERTY STRINGS default condvar futex gcd event) -@@ -41,7 +39,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") +@@ -51,7 +49,7 @@ IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i[3-6]86|AMD64|x86(_64)?)$") ELSE() OPTION(PTHREADPOOL_ENABLE_FASTPATH "Enable fast path using atomic decrement instead of atomic compare-and-swap" OFF) ENDIF() @@ -20,8 +20,8 @@ index f06aada..3c6c6e2 100644 OPTION(PTHREADPOOL_BUILD_TESTS "Build pthreadpool unit tests" ON) OPTION(PTHREADPOOL_BUILD_BENCHMARKS "Build pthreadpool micro-benchmarks" ON) ELSE() -@@ -67,7 +65,8 @@ MACRO(PTHREADPOOL_TARGET_ENABLE_CXX11 target) - ENDMACRO() +@@ -71,7 +69,8 @@ IF(PTHREADPOOL_BUILD_TESTS) + ENDIF() # ---[ Download deps -IF(NOT DEFINED FXDIV_SOURCE_DIR) @@ -30,7 +30,7 @@ index f06aada..3c6c6e2 100644 MESSAGE(STATUS "Downloading FXdiv to ${CMAKE_BINARY_DIR}/FXdiv-source (define FXDIV_SOURCE_DIR to avoid it)") CONFIGURE_FILE(cmake/DownloadFXdiv.cmake "${CMAKE_BINARY_DIR}/FXdiv-download/CMakeLists.txt") EXECUTE_PROCESS(COMMAND "${CMAKE_COMMAND}" -G "${CMAKE_GENERATOR}" . -@@ -118,21 +117,13 @@ ELSE() +@@ -122,21 +121,13 @@ ELSE() ENDIF() ADD_LIBRARY(pthreadpool_interface INTERFACE) @@ -54,7 +54,7 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_SYNC_PRIMITIVE STREQUAL "condvar") TARGET_COMPILE_DEFINITIONS(pthreadpool PRIVATE PTHREADPOOL_USE_FUTEX=0) -@@ -181,18 +172,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") +@@ -182,18 +173,22 @@ IF(CMAKE_SYSTEM_NAME STREQUAL "Linux") ENDIF() # ---[ Configure FXdiv @@ -80,3 +80,4 @@ index f06aada..3c6c6e2 100644 IF(PTHREADPOOL_BUILD_TESTS) # ---[ Build google test + diff --git a/cmake/vcpkg-ports/pthreadpool/portfile.cmake b/cmake/vcpkg-ports/pthreadpool/portfile.cmake index 9400e5e886639..449459feb33cc 100644 --- a/cmake/vcpkg-ports/pthreadpool/portfile.cmake +++ b/cmake/vcpkg-ports/pthreadpool/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/pthreadpool - REF 4e80ca24521aa0fb3a746f9ea9c3eaa20e9afbb0 - SHA512 776017cc5d2aa94337292f2f4fbd54d099ef29abf736ab8147f07f98f12b7654cbd2fe38d34646a479a519c261ac253bbaf19c6dcbb0ec4cc0859de70f7e6472 + REF dcc9f28589066af0dbd4555579281230abbf74dd + SHA512 61853fa8f6c3297d8760be3af1df3f2a00583c1e0e58bdd03cd9cb915e8660a4f2817b22e6463cf53f10de902a1c6204ec6054fcbeada72eeee9e44baeb97178 PATCHES fix-cmakelists.patch ) diff --git a/cmake/vcpkg-ports/xnnpack/fix-build.patch b/cmake/vcpkg-ports/xnnpack/fix-build.patch index b867377d2ff9e..3da8825e2b57d 100644 --- a/cmake/vcpkg-ports/xnnpack/fix-build.patch +++ b/cmake/vcpkg-ports/xnnpack/fix-build.patch @@ -1,21 +1,17 @@ diff --git a/CMakeLists.txt b/CMakeLists.txt -index f0b3410ae..ba54c3bfe 100644 +index 9f6fb5e256..4387298e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt -@@ -1047,9 +1047,11 @@ ENDIF() - IF(XNNPACK_BUILD_ALL_MICROKERNELS) - TARGET_INCLUDE_DIRECTORIES(microkernels-all PRIVATE include src) +@@ -1125,7 +1125,7 @@ ELSE() ENDIF() -+ - TARGET_INCLUDE_DIRECTORIES(datatype PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microkernels-prod PRIVATE include src) --TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) -+TARGET_INCLUDE_DIRECTORIES(hardware-config PRIVATE include src) -+ - TARGET_INCLUDE_DIRECTORIES(indirection PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(microparams-init PRIVATE include src) - TARGET_INCLUDE_DIRECTORIES(normalization PRIVATE include src) -@@ -1104,14 +1106,9 @@ IF(NOT TARGET cpuinfo) + + INCLUDE_DIRECTORIES(.) +-TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src ${CPUINFO_SOURCE_DIR}/include) ++TARGET_INCLUDE_DIRECTORIES(xnnpack-hardware-config PRIVATE include src) + IF(XNNPACK_BUILD_LIBRARY) + TARGET_INCLUDE_DIRECTORIES(XNNPACK PUBLIC include) + IF(WIN32) +@@ -1164,14 +1164,9 @@ IF(NOT TARGET cpuinfo) "${CPUINFO_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/cpuinfo") ELSE() @@ -33,7 +29,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() IF(XNNPACK_BUILD_LIBRARY) -@@ -1129,16 +1126,12 @@ IF(NOT TARGET pthreadpool) +@@ -1189,16 +1184,12 @@ IF(NOT TARGET pthreadpool) "${PTHREADPOOL_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/pthreadpool") ELSE() @@ -53,7 +49,7 @@ index f0b3410ae..ba54c3bfe 100644 ENDIF() ENDIF() TARGET_LINK_LIBRARIES(xnnpack-base INTERFACE pthreadpool) -@@ -1152,12 +1145,12 @@ IF(NOT TARGET fxdiv) +@@ -1212,12 +1203,12 @@ IF(NOT TARGET fxdiv) "${FXDIV_SOURCE_DIR}" "${CMAKE_BINARY_DIR}/FXdiv") ELSE() diff --git a/cmake/vcpkg-ports/xnnpack/portfile.cmake b/cmake/vcpkg-ports/xnnpack/portfile.cmake index d63ad0fbd0cce..60b3566629e10 100644 --- a/cmake/vcpkg-ports/xnnpack/portfile.cmake +++ b/cmake/vcpkg-ports/xnnpack/portfile.cmake @@ -5,8 +5,8 @@ endif() vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO google/XNNPACK - REF 953dcb96cc1b21b4b966952f8ee67a9e1f0d3e71 - SHA512 8c12930ef3b2f832962682d73c362518c014bb4e56d0c5cad2b8b63a03c91dccf6e6a3fd0eb91931fc5872c7df9773e76bf08553fc9c3cc22c94636c74815e94 + REF 3cf85e705098622d59056dcb8f5f963ea7bb0a00 + SHA512 af10afde80def08dc3b20a35bd38e84f9f749865ecc4bc9733b5d99d8a2f0f30c19c3f23472d65462a907b3a58226e3b254354a92a6baa31031824f68012a055 HEAD_REF master PATCHES fix-build.patch diff --git a/cmake/vcpkg-ports/xnnpack/vcpkg.json b/cmake/vcpkg-ports/xnnpack/vcpkg.json index e0d0600902f36..643b5c4abe166 100644 --- a/cmake/vcpkg-ports/xnnpack/vcpkg.json +++ b/cmake/vcpkg-ports/xnnpack/vcpkg.json @@ -1,6 +1,6 @@ { "name": "xnnpack", - "version-date": "2025-01-23", + "version-date": "2025-06-22", "description": "High-efficiency floating-point neural network inference operators for mobile, server, and Web", "homepage": "https://github.com/google/XNNPACK", "license": "BSD-3-Clause", diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs index 098a18b7444cf..2467475b6b189 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Exceptions.shared.cs @@ -23,8 +23,8 @@ internal enum ErrorCode ModelLoaded = 8, NotImplemented = 9, InvalidGraph = 10, - ShapeInferenceNotRegistered = 11, - RequirementNotRegistered = 12, + ShapeInferenceNotRegistered = 11, // TODO: should be ORT_EP_FAIL + RequirementNotRegistered = 12, // TODO: should be ORT_MODEL_LOAD_CANCELED } /// diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index f3dcde1abe37a..9c6fc6ce57a20 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3089,7 +3089,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
k : int
Number of top experts to select from expert pool
normalize_routing_weights : int
@@ -3106,9 +3106,9 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T
-
3D input tensor with shape (num_experts, hidden_size, inter_size)
+
3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T
3D input tensor with shape (num_experts, inter_size, hidden_size)
fc2_experts_bias (optional) : T
@@ -3129,8 +3129,8 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16)
-
Constrain input and output types to float or float16 tensors.
+
T : tensor(float), tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
@@ -4523,7 +4523,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation_type : string
-
Activation function to use. Choose from relu, gelu, silu and identity. Default is relu
+
Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu
expert_weight_bits : int
Number of bits used in quantized weights. Default is 4 bits
k : int
@@ -4542,20 +4542,20 @@ This version of the operator has been available since version 1 of the 'com.micr
router_probs : T
2D input tensor with shape (num_rows, num_experts)
fc1_experts_weights : T1
-
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc1_scales : T
-
2D input tensor with shape (num_experts, inter_size)
+
3D input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).
+
fc1_scales : T2
+
2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc1_experts_bias (optional) : T
-
2D optional input tensor with shape (num_experts, inter_size)
+
2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu
fc2_experts_weights : T1
3D input tensor with shape (num_experts, inter_size, hidden_size) or (num_experts, inter_size, hidden_size / 2)
-
fc2_scales : T
+
fc2_scales : T2
2D input tensor with shape (num_experts, hidden_size)
fc2_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, hidden_size)
fc3_experts_weights (optional) : T1
3D optional input tensor with shape (num_experts, hidden_size, inter_size) or (num_experts, hidden_size, inter_size / 2)
-
fc3_scales (optional) : T
+
fc3_scales (optional) : T2
2D optional input tensor with shape (num_experts, inter_size)
fc3_experts_bias (optional) : T
2D optional input tensor with shape (num_experts, inter_size)
@@ -4571,10 +4571,12 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float16)
-
Constrain input and output types to float or float16 tensors.
+
T : tensor(float16), tensor(bfloat16)
+
Constrain input and output types to float tensors.
T1 : tensor(uint8)
Constrain weights type to uint8 tensors.
+
T2 : tensor(float), tensor(float16)
+
Constrain scales type to float tensors.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 3b70e5da8b3e4..26b701fea6fbb 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -949,7 +949,7 @@ Do not modify directly.* |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)
**T3** = tensor(bfloat16), tensor(float), tensor(float16), tensor(uint8)| -|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**QK**|1+|**QK** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -957,7 +957,7 @@ Do not modify directly.* |PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |PagedAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* key_cache:**T**
*in* value_cache:**T**
*in* cumulative_sequence_length:**S**
*in* past_seqlens:**S**
*in* block_table:**S**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* key_cache_out:**T**
*out* value_cache_out:**T**|1+|**S** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)| -|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)| +|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T2**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T2**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T2**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(float16)
**T1** = tensor(uint8)
**T2** = tensor(bfloat16), tensor(float16)| |QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)| |QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| diff --git a/include/onnxruntime/core/common/status.h b/include/onnxruntime/core/common/status.h index da9735aa4e418..8cf6420f2d0f7 100644 --- a/include/onnxruntime/core/common/status.h +++ b/include/onnxruntime/core/common/status.h @@ -46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 0d920ab7dac89..21aa797ce16eb 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -232,7 +232,7 @@ static Ort::Status GetOrtValueInfoTensorTypeShape(const OrtValueInfo& ort_value_ /*out*/ std::vector& dims, /*out*/ std::vector& symbolic_dims); static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, onnx::ValueInfoProto& value_info_proto); -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto); Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, onnx::GraphProto& graph_proto, @@ -379,7 +379,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& ort_graph, } onnx::AttributeProto* attr_proto = node_proto->add_attribute(); - ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_attr, *attr_proto)); + ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtOpAttrToProto(*ort_node, *ort_attr, *attr_proto)); } } @@ -652,7 +652,7 @@ static Ort::Status OrtValueInfoToProto(const OrtValueInfo& ort_value_info, return Ort::Status{nullptr}; } -static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { +static Ort::Status OrtOpAttrToProto(const OrtNode& ort_node, const OrtOpAttr& ort_attr, onnx::AttributeProto& attr_proto) { const OrtApi& ort_api = Ort::GetApi(); const char* attr_name = nullptr; @@ -758,6 +758,103 @@ static Ort::Status OrtOpAttrToProto(const OrtOpAttr& ort_attr, onnx::AttributePr break; } + case OrtOpAttrType::ORT_OP_ATTR_TENSOR: { + attr_proto.set_type(onnx::AttributeProto_AttributeType_TENSOR); + + onnx::TensorProto tensor_proto; + + // TensorProto as an attribute value doesn't require a name. + + OrtValue* ort_value = nullptr; + ORT_EP_UTILS_C_RETURN_IF_ERROR(ort_api.Node_GetTensorAttributeAsOrtValue(&ort_node, &ort_attr, &ort_value)); + + Ort::Value tensor(ort_value); + + // Get tensor type and shape info + Ort::TensorTypeAndShapeInfo type_shape_info = tensor.GetTensorTypeAndShapeInfo(); + + // Get tensor type + ONNXTensorElementDataType element_type = type_shape_info.GetElementType(); + + size_t element_size = 0; + switch (element_type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_FLOAT); + element_size = sizeof(float); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT8); + element_size = sizeof(uint8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT8); + element_size = sizeof(int8_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT16); + element_size = sizeof(uint16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT16); + element_size = sizeof(int16_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT32); + element_size = sizeof(int32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_INT64); + element_size = sizeof(int64_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_BOOL); + element_size = sizeof(bool); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_DOUBLE); + element_size = sizeof(double); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT32); + element_size = sizeof(uint32_t); + break; + } + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: { + tensor_proto.set_data_type(onnx::TensorProto_DataType_UINT64); + element_size = sizeof(uint64_t); + break; + } + default: { + std::string err_msg = "Unexpected ONNXTensorElementDataType with value " + std::to_string(static_cast(element_type)); + return Ort::Status(err_msg.c_str(), ORT_FAIL); + } + } + + auto shape = type_shape_info.GetShape(); + + for (auto& dim : shape) { + tensor_proto.add_dims(dim); + } + + size_t element_count = type_shape_info.GetElementCount(); + size_t data_bytes = element_count * element_size; + const void* data = tensor.GetTensorData(); + + // Copy the Ortvalue to TensorProto as raw data + tensor_proto.set_raw_data(data, data_bytes); + + *(attr_proto.mutable_t()) = std::move(tensor_proto); + break; + } default: { std::string err_msg = "Unexpected OrtOpAttrType with value " + std::to_string(static_cast(attr_type)); return Ort::Status(err_msg.c_str(), ORT_FAIL); diff --git a/include/onnxruntime/core/session/environment.h b/include/onnxruntime/core/session/environment.h index 7e49275e59b8b..306f81df38e48 100644 --- a/include/onnxruntime/core/session/environment.h +++ b/include/onnxruntime/core/session/environment.h @@ -20,7 +20,7 @@ #include "core/platform/threadpool.h" #include "core/session/abi_devices.h" -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" #include "core/session/onnxruntime_c_api.h" struct OrtThreadingOptions; diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2f0e4aa7ce108..2899a219bdda0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -264,6 +264,7 @@ typedef enum OrtErrorCode { ORT_EP_FAIL, ORT_MODEL_LOAD_CANCELED, ORT_MODEL_REQUIRES_COMPILATION, + ORT_NOT_FOUND, } OrtErrorCode; typedef enum OrtOpAttrType { @@ -275,6 +276,7 @@ typedef enum OrtOpAttrType { ORT_OP_ATTR_STRING, ORT_OP_ATTR_STRINGS, ORT_OP_ATTR_GRAPH, + ORT_OP_ATTR_TENSOR, } OrtOpAttrType; //! @} @@ -5846,14 +5848,13 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * - * Note: - * The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference + * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference * the same underlying graph. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. * \param[in] num_nodes Number of nodes. - * \param[out] dst_sub_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. + * \param[out] dst_graph An OrtGraph created from a given set of nodes. Must be released by calling ReleaseGraph. * * \snippet{doc} snippets.dox OrtStatus Return Value * @@ -6032,6 +6033,11 @@ struct OrtApi { * Typical usage sets this to the result of Node_GetNumAttributes(). An error status is * returned if `num_attributes` is less than the number of node attributes. * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. + * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. @@ -6043,14 +6049,36 @@ struct OrtApi { * * \param[in] node The OrtNode instance. * \param[in] attribute_name The name of the attribute - * \param[out] attribute Output the attribute if its name matches 'attribute_name', otherwise output nullptr. + * \param[out] attribute Output parameter set to the OrtOpAttr instance if an attribute by the given name exists. + * For an unset optional attribute, `attribute` is set to NULL and a non-error status is + * returned. For an invalid attribute name, `attribute` is set to NULL and an error status with + * code ORT_NOT_FOUND is returned. + * + * \note ONNX Runtime automatically sets optional (unset) attributes to their default values if the default value + * is a constant expression that does not depend on other tensor/model characteristics. Conv's 'kernel_shape' + * attribute is an example of an optional attribute that does not have a constant default value. This function + * does not provide any unset optional attributes without a constant default value. * * \snippet{doc} snippets.dox OrtStatus Return Value * * \since Version 1.23. */ ORT_API2_STATUS(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); + + /** \brief Get the OrtNode's 'TENSOR' attribute as an OrtValue. + * + * \param[in] node The OrtNode instance. + * \param[in] attribute The OrtOpAttr instance. + * \param[out] attr_tensor If successful, contains the 'TENSOR' attribute as a newly created OrtValue. + Must be freed with OrtApi::ReleaseValue. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); /** \brief Get the attribute type as OrtOpAttrType from an OrtOpAttr. * diff --git a/java/src/main/java/ai/onnxruntime/OrtException.java b/java/src/main/java/ai/onnxruntime/OrtException.java index 5ec58ea137124..06c3d3cbc770c 100644 --- a/java/src/main/java/ai/onnxruntime/OrtException.java +++ b/java/src/main/java/ai/onnxruntime/OrtException.java @@ -81,11 +81,17 @@ public enum OrtErrorCode { /** The ONNX graph is invalid. */ ORT_INVALID_GRAPH(10), /** The ORT execution provider failed. */ - ORT_EP_FAIL(11); + ORT_EP_FAIL(11), + /** Model load was canceled. */ + ORT_MODEL_LOAD_CANCELED(12), + /** Model requires compilation. */ + ORT_MODEL_REQUIRES_COMPILATION(13), + /** Item was not found. */ + ORT_NOT_FOUND(14); private final int value; - private static final OrtErrorCode[] values = new OrtErrorCode[12]; + private static final OrtErrorCode[] values = new OrtErrorCode[15]; static { for (OrtErrorCode ot : OrtErrorCode.values()) { diff --git a/java/src/main/native/OrtJniUtil.c b/java/src/main/native/OrtJniUtil.c index fe19015d642f0..5d8efd7b476cb 100644 --- a/java/src/main/native/OrtJniUtil.c +++ b/java/src/main/native/OrtJniUtil.c @@ -1051,6 +1051,12 @@ jint convertErrorCode(OrtErrorCode code) { return 10; case ORT_EP_FAIL: return 11; + case ORT_MODEL_LOAD_CANCELED: + return 12; + case ORT_MODEL_REQUIRES_COMPILATION: + return 13; + case ORT_NOT_FOUND: + return 14; default: return -1; // Unknown error code } diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 860424d31ffab..0bb1b09687018 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -97,6 +97,10 @@ export class WebNNBackend { * Temporary tensors for the current session. */ private temporarySessionTensorIds: Map = new Map(); + /** + * Maps from session id to MLOpSupportLimits. + */ + private mlOpSupportLimitsBySessionId = new Map(); constructor(env: Env) { configureLogger(env.logLevel!, !!env.debug); @@ -172,6 +176,10 @@ export class WebNNBackend { } sessionIds.add(sessionId); + if (!this.mlOpSupportLimitsBySessionId.has(sessionId)) { + this.mlOpSupportLimitsBySessionId.set(sessionId, mlContext.opSupportLimits()); + } + if (this.temporaryGraphInputs.length > 0) { this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs); this.temporaryGraphInputs = []; @@ -192,6 +200,7 @@ export class WebNNBackend { } this.tensorManager.releaseTensorsForSession(sessionId); this.mlContextBySessionId.delete(sessionId); + this.mlOpSupportLimitsBySessionId.delete(sessionId); const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; sessionIds.delete(sessionId); if (sessionIds.size === 0) { @@ -207,6 +216,10 @@ export class WebNNBackend { return this.mlContextBySessionId.get(sessionId); } + public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined { + return this.mlOpSupportLimitsBySessionId.get(sessionId); + } + public reserveTensorId(): TensorId { return this.tensorManager.reserveTensorId(); } @@ -399,17 +412,17 @@ export class WebNNBackend { } public isGraphInputOutputTypeSupported(sessionId: number, type: Tensor.Type, isInput = true): boolean { - const context = this.mlContextBySessionId.get(sessionId); const dataType = onnxDataTypeToWebnnDataType.get(tensorDataTypeStringToEnum(type)); + const opLimits = this.mlOpSupportLimitsBySessionId.get(sessionId); if (typeof dataType === 'undefined') { return false; } if (isInput) { - return !!context?.opSupportLimits().input.dataTypes.includes(dataType); + return !!opLimits?.input.dataTypes.includes(dataType); } else { - return !!context?.opSupportLimits().output.dataTypes.includes(dataType); + return !!opLimits?.output.dataTypes.includes(dataType); } } diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 784e5ab53c1fa..70a2a9a892470 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -336,11 +336,12 @@ class TensorIdTracker { copyOld: boolean, ): Promise { const context = this.tensorManager.getMLContext(sessionId); + const opLimits = this.tensorManager.getMLOpSupportLimits(sessionId); let fallbackDataType: MLOperandDataType | undefined; // Check if the context supports the data type. If not, try to use the fallback data type. - if (!context.opSupportLimits().input.dataTypes.includes(dataType)) { + if (!opLimits?.input.dataTypes.includes(dataType)) { fallbackDataType = webnnDataTypeToFallback.get(dataType); - if (!fallbackDataType || !context.opSupportLimits().input.dataTypes.includes(fallbackDataType)) { + if (!fallbackDataType || opLimits?.input.dataTypes.includes(fallbackDataType)) { throw new Error(`WebNN backend does not support data type: ${dataType}`); } LOG_DEBUG( @@ -460,6 +461,10 @@ class TensorManagerImpl implements TensorManager { return context; } + public getMLOpSupportLimits(sessionId: number): MLOpSupportLimits | undefined { + return this.backend.getMLOpSupportLimits(sessionId); + } + public reserveTensorId(): TensorId { const tensorId = createNewTensorId(); this.tensorTrackersById.set(tensorId, new TensorIdTracker(this)); diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index e2bb3b508ca7c..85a2cbaea0e44 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "core/common/cpuid_info.h" // for CPUIDInfo::GetCPUIDInfo().HasArm_SME() #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/mlas/inc/mlas.h" @@ -10,6 +11,7 @@ #include "core/util/math_cpuonly.h" #include "core/util/qmath.h" +#include #include #include @@ -169,43 +171,40 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { // only pack Matrix B if (input_idx == GetBIdx()) { const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_is_asymmetric = false; + bool b_quantization_might_be_asymmetric = false; - // zero point tensor could be provided as a direct input to the kernel and not as a constant so this - // test is not sufficient const OrtValue* b_zp; if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { b_zp_constant_tensor = &b_zp->Get(); } - // MlasDynamicQgemm requires symmetric quantization for B, so no zero point should exist or it should - // have a zero value - if (b_zp_constant_tensor != nullptr) { // Covers the case where tensor is not a constant - const auto& shape = b_zp_constant_tensor->Shape(); - const auto* zp_data = static_cast(b_zp_constant_tensor->DataRaw()); - size_t zp_size = static_cast(shape.Size()); - // MlasDynamicQgemm requires symmetric quantization: zp must be scalar 0 or 1D all-zero - if ((shape.NumDimensions() == 0) && (zp_data[0] == 0)) { - b_quantization_is_asymmetric = false; - } else if (shape.NumDimensions() == 1) { - b_quantization_is_asymmetric = false; - for (size_t i = 0; i < zp_size; ++i) { - if (zp_data[i] != 0) { - b_quantization_is_asymmetric = true; - break; - } - } - } else { - // Unsupported higher-rank zp tensor - b_quantization_is_asymmetric = true; - } + // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros + // or not provided. + if (b_zp_constant_tensor != nullptr) { + // B zero point is constant. Check if it is all zeros. + assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); + const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } else { + // B zero point input is not constant. If it exists, we can't assume symmetric quantization. + const auto input_defs = Info().node().InputDefs(); + const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); + b_quantization_might_be_asymmetric = b_zp_input_exists; } // MlasDynamicQgemm requires scale data to be available at packing stage const Tensor* b_scale_tensor = nullptr; const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - can_use_dynamic_quant_mlas_ = (!b_quantization_is_asymmetric && b_scale_available); + can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); + + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + // We check that here too before attempting to use them. + if (!CPUIDInfo::GetCPUIDInfo().HasArm_SME()) { + can_use_dynamic_quant_mlas_ = false; + } // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 1a4a63de38790..e8cdc50ed4ca7 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -78,8 +78,11 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0, "num_experts should be divisible by world_size"); - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index f3346d4513261..a50ee907c302b 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -92,7 +92,9 @@ class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, double, Crop); class CUDA_ONNX_OP_TYPED_CLASS_NAME(1, MLFloat16, Crop); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float, MoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, MoE); -class CUDA_MS_OP_CLASS_NAME(1, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, MoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, BFloat16, QMoE); +class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16, QMoE); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_float, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, float_MLFloat16, MultiHeadAttention); class CUDA_MS_OP_TYPED_CLASS_NAME(1, MLFloat16_float, MultiHeadAttention); @@ -307,7 +309,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h index 36127054cfd5e..d5ad8161e100e 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels.h @@ -52,6 +52,7 @@ enum class ActivationType { Gelu, GeGLU, ReGLU, SiGLU, + SwiGLU, Identity, InvalidType }; diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu new file mode 100644 index 0000000000000..5f0a71147b366 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_bf16.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, __nv_bfloat16>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu new file mode 100644 index 0000000000000..4a84581127156 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint4.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, cutlass::uint4b_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu new file mode 100644 index 0000000000000..6c23127955ac2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_bf16_uint8.cu @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#pragma warning(disable : 4244) +#pragma warning(disable : 4200) +#endif + +#include "contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +namespace ort_fastertransformer { +template class MoeGemmRunner<__nv_bfloat16, uint8_t>; +} // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index ef1f97b9e57a2..f855092670bc3 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -53,6 +53,8 @@ #include "cutlass_heuristic.h" #include "moe_gemm_kernels.h" +#include + #include #include #include @@ -66,8 +68,8 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, const int multi_processor_count, cudaStream_t stream, int* kernel_occupancy = nullptr) { - static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value, - "Specialized for half, float"); + static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || cutlass::platform::is_same::value, + "Specialized for half, float, bfloat16"); static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value || @@ -76,12 +78,11 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w // The cutlass type for the input elements. This is needed to convert to cutlass::half_t if necessary. using ElementType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, T>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, T>::type>::type; using ElementType = ElementType_; using CutlassWeightType_ = - typename cutlass::platform::conditional::value, cutlass::half_t, - WeightType>::type; + typename cutlass::platform::conditional::value, cutlass::half_t, typename cutlass::platform::conditional::value, cutlass::bfloat16_t, WeightType>::type>::type; using CutlassWeightType = CutlassWeightType_; @@ -391,12 +392,10 @@ void MoeGemmRunner::dispatch_to_arch(const T* A, con dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else if (sm_ >= 80 && sm_ < 90) { + } else if (sm_ >= 80) { // Hopper and Blackwell will fallback to use Ampere kernels. dispatch_moe_gemm_to_cutlass( A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, gemm_k, num_experts, gemm_config, sm_, multi_processor_count_, stream, occupancy); - } else { - ORT_THROW("[MoE][GEMM Dispatch] Arch unsupported for MoE GEMM"); } } @@ -478,6 +477,7 @@ void MoeGemmRunner::moe_gemm_bias_act(const T* A, const WeightTyp int64_t total_rows, int64_t gemm_n, int64_t gemm_k, int num_experts, ActivationType activation_type, cudaStream_t stream) { + // Swiglu will use Identity to call this function so we not need to handle it here. switch (activation_type) { case ActivationType::Relu: run_gemm(A, B, weight_scales, biases, C, total_rows_before_expert, total_rows, gemm_n, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index bfbe1d81b1c15..fc412a02e0383 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -38,12 +38,86 @@ #include "moe_kernel.h" +#include #include #include #include +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" + namespace ort_fastertransformer { static constexpr int WARP_SIZE = 32; + +// SwiGLU with interleaved is like the following python code using PyTorch: +// dim = x.shape[-1] +// x = x.view(-1, dim // 2, 2) +// x_glu, x_linear = x[..., 0], x[..., 1] +// y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) +template +__global__ void swiglu_kernel_interleaved(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[2 * i]; + T x_linear = row_input[2 * i + 1]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +// Non interleaved version of SwiGLU kernel, which splits each row into two chunks of same size. +template +__global__ void swiglu_kernel_chunked(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha) { + int const row = blockIdx.x; + if (row >= num_rows) { + return; + } + + T const* row_input = input + row * 2 * intermediate_size; + T* row_output = output + row * intermediate_size; + + for (int i = threadIdx.x; i < intermediate_size; i += blockDim.x) { + T x_glu = row_input[i]; + T x_linear = row_input[i + intermediate_size]; + + float sigmoid_arg = swiglu_alpha * static_cast(x_glu); + float sigmoid_out = 1.f / (1.f + expf(-sigmoid_arg)); + + float swish_out = static_cast(x_glu) * sigmoid_out; + row_output[i] = static_cast(swish_out * (static_cast(x_linear) + 1.f)); + } +} + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream) { + if (num_rows == 0) { + return; + } + dim3 block(std::min(intermediate_size, 1024)); + dim3 grid(num_rows); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("swiglu input", input, num_rows, 2 * intermediate_size); + + if constexpr (interleaved) { + swiglu_kernel_interleaved<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } else { + swiglu_kernel_chunked<<>>(output, input, intermediate_size, num_rows, swiglu_alpha); + } + + DUMP_TENSOR("swiglu output", output, num_rows, intermediate_size); +} + // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing the output // in the softmax kernel when we extend this module to support expert-choice routing. @@ -666,9 +740,14 @@ __global__ void dispatch_activations_kernel(int64_t* total_rows_before_expert, i } template -CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, bool has_fc3, +CutlassMoeFCRunner::CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer) - : has_fc3_(has_fc3), total_past_rows_(0), total_covered_rows_(0), normalize_routing_weights_(normalize_routing_weights), use_sparse_mixer_(use_sparse_mixer) { + : activation_type_(activation_type), + has_fc3_(has_fc3), + total_past_rows_(0), + total_covered_rows_(0), + normalize_routing_weights_(normalize_routing_weights), + use_sparse_mixer_(use_sparse_mixer) { moe_gemm_runner_.initialize(sm_version); } @@ -695,8 +774,16 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro total_ws_bytes += buf_size * sizeof(T); // permuted_data total_ws_bytes += padded_experts * sizeof(int64_t); // Hold total_rows_before_expert_ total_ws_bytes += num_softmax_outs * sizeof(T); - const size_t bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); - const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(num_rows)); + + size_t bytes_for_fc1_result; + if (activation_type_ == ActivationType::SwiGLU) { + // Space for both fc1_result_ and act_result_. + bytes_for_fc1_result = (2 * interbuf_size + interbuf_size) * sizeof(T); + } else { + bytes_for_fc1_result = has_fc3_ ? 2 * interbuf_size * sizeof(T) : interbuf_size * sizeof(T); + } + + const size_t sorter_ws_size_bytes = pad_to_multiple_of_16(sorter_.getWorkspaceSize(k * num_rows)); sorter_.update_num_experts(static_cast(num_experts)); size_t bytes_for_intermediate_and_sorting = bytes_for_fc1_result; @@ -705,7 +792,7 @@ size_t CutlassMoeFCRunner::getWorkspaceSize(size_t num_ro bytes_for_intermediate_and_sorting += remaining_bytes; } - total_ws_bytes += bytes_for_intermediate_and_sorting; // intermediate (fc1) output + cub sorting workspace + total_ws_bytes += bytes_for_intermediate_and_sorting; return total_ws_bytes; } @@ -725,27 +812,49 @@ void CutlassMoeFCRunner::configure_ws_ptrs(char* ws_ptr, total_rows_before_expert_ = reinterpret_cast(permuted_data_ + buf_size); + char* current_ptr = reinterpret_cast(total_rows_before_expert_ + padded_experts); + + if (activation_type_ == ActivationType::SwiGLU) { + // fc1_result_ is used for GEMM1 output (2 * inter_size) + fc1_result_ = reinterpret_cast(current_ptr); + current_ptr += 2 * interbuf_size * sizeof(T); + + // act_result_ is used for SwiGLU output (inter_size) + act_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); + + ORT_ENFORCE(!has_fc3_, "SwiGLU activation is not supported with fc3"); + } else { + fc1_result_ = reinterpret_cast(current_ptr); + act_result_ = nullptr; // No extra buffer for activation since it is done inplace. + current_ptr += interbuf_size * sizeof(T); + } + if (has_fc3_) { - fc3_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); - fc1_result_ = reinterpret_cast(fc3_result_ + interbuf_size); + fc3_result_ = reinterpret_cast(current_ptr); + current_ptr += interbuf_size * sizeof(T); } else { - fc1_result_ = reinterpret_cast(total_rows_before_expert_ + padded_experts); + fc3_result_ = nullptr; } const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); if (!is_pow_2 || num_experts > 256) { - softmax_out_ = reinterpret_cast(fc1_result_ + interbuf_size); + softmax_out_ = reinterpret_cast(current_ptr); } else { softmax_out_ = nullptr; } } namespace { - -struct __align__(8) Half4 { +typedef struct __CUDA_ALIGN__(8) { half2 x; half2 y; -}; +} half2_2; + +typedef struct __CUDA_ALIGN__(8) { + __nv_bfloat162 x; + __nv_bfloat162 y; +} __nv_bfloat162_2; // TODO(wy): move to common header template @@ -756,7 +865,11 @@ struct T4 { }; template <> struct T4 { - using Type = Half4; + using Type = half2_2; +}; +template <> +struct T4<__nv_bfloat16> { + using Type = __nv_bfloat162_2; }; template @@ -769,6 +882,10 @@ template <> struct T2 { using Type = half2; }; +template <> +struct T2<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; inline __device__ float2 operator*(const float2 a, const float2 b) { return make_float2(a.x * b.x, a.y * b.y); } @@ -785,15 +902,27 @@ inline __device__ half2 operator*(const half2 a, const half2 b) { return make_ha #endif // TODO(wy): use cuda common header and investigate pipeline build issue. -inline __device__ Half4 operator*(const Half4 a, const Half4 b) { +inline __device__ half2_2 operator*(const half2_2 a, const half2_2 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 530 && \ ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) - Half4 result; + half2_2 result; result.x = a.x * b.x; result.y = a.y * b.y; return result; #else - return Half4{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; + return half2_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; +#endif +} + +inline __device__ __nv_bfloat162_2 operator*(const __nv_bfloat162_2 a, const __nv_bfloat162_2 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 && \ + ((__CUDACC_VER_MAJOR__ < 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ < 2))) + __nv_bfloat162_2 result; + result.x = a.x * b.x; + result.y = a.y * b.y; + return result; +#else + return __nv_bfloat162_2{__hmul2(a.x, b.x), __hmul2(a.y, b.y)}; #endif } @@ -880,8 +1009,51 @@ void CutlassMoeFCRunner::run_moe_fc( stream); } - // moe_gemm_runner_.try_find_best_config(local_num_experts, hidden_size, inter_size, - // expanded_active_expert_rows); + if (fc1_activation_type == ActivationType::SwiGLU) { + T* gemm1_output_buffer = fc1_result_; + T* swiglu_output_buffer = act_result_; + + moe_gemm_runner_.moe_gemm_bias_act( + permuted_data_ + total_past_rows_ * hidden_size, + fc1_expert_weights, + fc1_scales, + fc1_expert_biases, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + 2 * inter_size, + hidden_size, + local_num_experts, + ActivationType::Identity, + stream); + + constexpr bool swiglu_interleaved = true; + constexpr float swiglu_alpha = 1.702f; + invokeSwiGLU( + swiglu_output_buffer + total_past_rows_ * inter_size, + gemm1_output_buffer + total_past_rows_ * 2 * inter_size, + inter_size, + static_cast(total_covered_rows_), + swiglu_alpha, + stream); + + moe_gemm_runner_.moe_gemm( + swiglu_output_buffer + total_past_rows_ * inter_size, + fc2_expert_weights, + fc2_scales, + nullptr, + fc2_result + total_past_rows_ * hidden_size, + total_rows_before_expert_ + local_experts_start_index, + expanded_active_expert_rows, + hidden_size, + inter_size, + local_num_experts, + stream); + + // No fc3 for SwiGLU + return; + } + moe_gemm_runner_.moe_gemm_bias_act( permuted_data_ + total_past_rows_ * hidden_size, fc1_expert_weights, fc1_scales, fc1_expert_biases, fc1_result_ + total_past_rows_ * inter_size, total_rows_before_expert_ + local_experts_start_index, @@ -1151,18 +1323,26 @@ template void topk_gating_softmax_kernelLauncher(const float*, const bool*, floa int, bool, bool, cudaStream_t); template void topk_gating_softmax_kernelLauncher(const half*, const bool*, half*, half*, int*, int*, int, int, int, bool, bool, cudaStream_t); +template void topk_gating_softmax_kernelLauncher(const __nv_bfloat16*, const bool*, __nv_bfloat16*, __nv_bfloat16*, int*, int*, int, int, + int, bool, bool, cudaStream_t); // ==================== Variable batched GEMM specializations ================================== template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, __nv_bfloat16>; +// For qMoE: template class CutlassMoeFCRunner; template class CutlassMoeFCRunner; +template class CutlassMoeFCRunner<__nv_bfloat16, cutlass::uint4b_t>; +template class CutlassMoeFCRunner<__nv_bfloat16, uint8_t>; // ===================== Specializations for init routing ========================= template void initialize_moe_routing_kernelLauncher(const float*, float*, const int*, int*, int, int, int, int, cudaStream_t); template void initialize_moe_routing_kernelLauncher(const half*, half*, const int*, int*, int, int, int, int, cudaStream_t); +template void initialize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const int*, int*, int, int, int, int, + cudaStream_t); // ==================== Specializations for final routing =================================== template void finalize_moe_routing_kernelLauncher(const float*, float*, const float*, const float*, const int*, @@ -1177,5 +1357,10 @@ template void finalize_moe_routing_kernelLauncher(const float*, float*, const fl const float*, const int*, const int*, int, int, int, cudaStream_t); template void finalize_moe_routing_kernelLauncher(const half*, half*, const half*, const half*, const half*, const half*, const int*, const int*, int, int, int, cudaStream_t); +template void finalize_moe_routing_kernelLauncher(const __nv_bfloat16*, __nv_bfloat16*, const __nv_bfloat16*, + const __nv_bfloat16*, const int*, const int*, int, int, int, cudaStream_t); + +template void invokeSwiGLU(float*, float const*, int, int, float, cudaStream_t); +template void invokeSwiGLU(half*, half const*, int, int, float, cudaStream_t); } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h index c457b608decbf..de11d357a8c07 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.h @@ -54,7 +54,10 @@ static inline size_t pad_to_multiple_of_16(size_t input) { template void topk_gating_softmax_kernelLauncher(const T* input, const bool* finished, T* output, T* softmax_temp_out, int* indices, int* source_row, int num_rows, int num_experts, int k, - cudaStream_t stream); + bool normalize_routing_weights, bool use_sparse_mixer, cudaStream_t stream); + +template +void invokeSwiGLU(T* output, T const* input, int intermediate_size, int num_rows, float swiglu_alpha, cudaStream_t stream); class CubKeyValueSorter { public: @@ -109,7 +112,7 @@ template class CutlassMoeFCRunner { public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); + CutlassMoeFCRunner(int sm_version, ActivationType activation_type, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k); @@ -157,8 +160,10 @@ class CutlassMoeFCRunner { int64_t* total_rows_before_expert_; T* fc1_result_; + T* act_result_; T* fc3_result_; + ActivationType activation_type_; bool has_fc3_; bool normalize_routing_weights_; bool use_sparse_mixer_; @@ -173,14 +178,4 @@ class CutlassMoeFCRunner { std::vector total_rows_before_expert_host_; }; -template -class CutlassMoeFCRunner::value>> { - public: - CutlassMoeFCRunner(int sm_version, bool has_fc3, bool normalize_routing_weights, bool use_sparse_mixer); - - size_t getWorkspaceSize(size_t num_rows, size_t hidden_size, size_t inter_size, size_t num_experts, size_t k) { - return 0; - } -}; - } // namespace ort_fastertransformer diff --git a/onnxruntime/contrib_ops/cuda/moe/moe.cc b/onnxruntime/contrib_ops/cuda/moe/moe.cc index c5352d931ce2c..6409a6e12afc6 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe.cc +++ b/onnxruntime/contrib_ops/cuda/moe/moe.cc @@ -3,6 +3,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_type_conversion.h" #include "moe.h" using namespace onnxruntime::cuda; @@ -20,6 +21,7 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) +REGISTER_KERNEL_TYPED(BFloat16) template MoE::MoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { @@ -42,14 +44,17 @@ Status MoE::ComputeInternal(OpKernelContext* context) const { fc1_experts_bias_optional, fc2_experts_weights, fc2_experts_bias_optional, fc3_experts_weights_optional, fc3_experts_bias_optional)); - typedef typename ToCudaType::MappedType CudaT; + using CudaT = typename OrtToCudaType::type; auto stream = context->GetComputeStream(); auto& device_prop = GetDeviceProp(); const int sm = device_prop.major * 10 + device_prop.minor; - ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, fc3_experts_weights_optional != nullptr, - normalize_routing_weights_, use_sparse_mixer_); + ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, + fc3_experts_weights_optional != nullptr, + normalize_routing_weights_, + use_sparse_mixer_); size_t ws_size = moe_runner.getWorkspaceSize( static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size), diff --git a/onnxruntime/contrib_ops/cuda/moe/moe_base.h b/onnxruntime/contrib_ops/cuda/moe/moe_base.h index 6b65557444a66..194f33acbeb59 100644 --- a/onnxruntime/contrib_ops/cuda/moe/moe_base.h +++ b/onnxruntime/contrib_ops/cuda/moe/moe_base.h @@ -76,15 +76,16 @@ class MoEBase { } const int64_t coe = quant_type == MoEQuantType::UINT4 ? 2 : 1; - if (fc1_experts_weights_dims[2] != inter_size / coe) { + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_weights_dims[2] != act * inter_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_weights_dims[2] must be equal to inter_size, got ", - fc1_experts_weights_dims[2], " and ", inter_size); + "fc1_experts_weights_dims[2] is ", + fc1_experts_weights_dims[2], " expected ", act * inter_size / coe); } if (fc2_experts_weights_dims[2] != hidden_size / coe) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc2_experts_weights_dims[2] must be equal to hidden_size, got ", - fc2_experts_weights_dims[2], " and ", hidden_size); + "fc2_experts_weights_dims[2] is ", + fc2_experts_weights_dims[2], " expected ", hidden_size / coe); } if (router_probs_dims.size() != 2) { @@ -116,10 +117,10 @@ class MoEBase { "fc2_experts_bias_dims[0] must be equal to num_experts, got ", fc2_experts_bias_dims[0], " and ", num_experts); } - if (fc1_experts_bias_dims[1] != inter_size) { + if (fc1_experts_bias_dims[1] != act * inter_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "fc1_experts_bias_dims[1] must be equal to inter_size, got ", fc1_experts_bias_dims[1], - " and ", inter_size); + "fc1_experts_bias_dims[1] is ", fc1_experts_bias_dims[1], + ", expected ", act * inter_size); } if (fc2_experts_bias_dims[1] != hidden_size) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -182,10 +183,14 @@ class MoEBase { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[0] must be equal to num_experts, got ", fc1_experts_scales_dims[0], " and ", num_experts); } - if (fc1_experts_scales_dims[1] != inter_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to inter_size, got ", - fc1_experts_scales_dims[1], " and ", inter_size); + + // The activation type affects the output dimension of the first FC layer. + const int64_t act = activation_type_ == ort_fastertransformer::ActivationType::SwiGLU ? 2 : 1; + if (fc1_experts_scales_dims[1] != act * inter_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc1_experts_scales[1] must be equal to act * inter_size, got ", + fc1_experts_scales_dims[1], " and ", act * inter_size); } + if (fc2_experts_scales_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "fc2_experts_scales must be 2D, got ", fc2_experts_scales->Shape().GetDims().size()); @@ -219,6 +224,8 @@ class MoEBase { activation_type_ = ort_fastertransformer::ActivationType::Gelu; } else if (activation_type_str == "silu") { activation_type_ = ort_fastertransformer::ActivationType::Silu; + } else if (activation_type_str == "swiglu") { + activation_type_ = ort_fastertransformer::ActivationType::SwiGLU; } else if (activation_type_str == "identity") { activation_type_ = ort_fastertransformer::ActivationType::Identity; } else { diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc index 4dd5a079d1a29..aef31c7e9ed3a 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.cc @@ -5,6 +5,7 @@ #include "core/common/safeint.h" #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/quantization/moe_quantization.h" +#include "core/providers/cuda/cuda_type_conversion.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -14,16 +15,6 @@ namespace onnxruntime { namespace contrib { namespace cuda { -#define REGISTER_KERNEL() \ - ONNX_OPERATOR_KERNEL_EX(QMoE, kMSDomain, 1, kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(0, 0) \ - .TypeConstraint("T", BuildKernelDefConstraints()) \ - .TypeConstraint("T1", BuildKernelDefConstraints()), \ - QMoE); - -REGISTER_KERNEL() - namespace { template struct ToCudaTypeWrapper : public ToCudaType {}; @@ -40,27 +31,29 @@ struct ToCudaTypeWrapper { } // anonymous namespace -QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { +template +QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoEBase(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("expert_weight_bits", &expert_weight_bits_).IsOK()); ORT_ENFORCE(expert_weight_bits_ == 8 || expert_weight_bits_ == 4, "expert_weight_bits must be 4 or 8, but got ", expert_weight_bits_); } +template template -Status QMoE::QuantizedMoEImpl(OpKernelContext* context, - MoEParameters& moe_params, - const Tensor* input, - const Tensor* router_probs, - const Tensor* fc1_experts_weights, - const Tensor* fc1_experts_bias_optional, - const Tensor* fc2_experts_weights, - const Tensor* fc2_experts_bias_optional, - const Tensor* fc3_experts_weights_optional, - const Tensor* fc3_experts_bias_optional, - const Tensor* fc1_scales, - const Tensor* fc2_scales, - const Tensor* fc3_scales_optional, - const cudaDeviceProp& device_prop) const { +Status QMoE::QuantizedMoEImpl(OpKernelContext* context, + MoEParameters& moe_params, + const Tensor* input, + const Tensor* router_probs, + const Tensor* fc1_experts_weights, + const Tensor* fc1_experts_bias_optional, + const Tensor* fc2_experts_weights, + const Tensor* fc2_experts_bias_optional, + const Tensor* fc3_experts_weights_optional, + const Tensor* fc3_experts_bias_optional, + const Tensor* fc1_scales, + const Tensor* fc2_scales, + const Tensor* fc3_scales_optional, + const cudaDeviceProp& device_prop) const { auto stream = context->GetComputeStream(); const int sm = device_prop.major * 10 + device_prop.minor; @@ -68,10 +61,10 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); - using T = MLFloat16; - using CudaT = typename ToCudaType::MappedType; + using CudaT = typename OrtToCudaType::type; ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm, + activation_type_, fc3_experts_weights_optional != nullptr, normalize_routing_weights_, use_sparse_mixer_); @@ -136,7 +129,8 @@ Status QMoE::QuantizedMoEImpl(OpKernelContext* context, return Status::OK(); } -Status QMoE::ComputeInternal(OpKernelContext* context) const { +template +Status QMoE::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* router_probs = context->Input(1); const Tensor* fc1_experts_weights = context->Input(2); @@ -183,6 +177,32 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { #endif } +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + QMoE, + kMSDomain, + 1, + BFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .MayInplace(0, 0) + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + QMoE); + } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h index c0164576d7c7f..c4698a1f277ef 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h +++ b/onnxruntime/contrib_ops/cuda/quantization/moe_quantization.h @@ -14,6 +14,7 @@ namespace cuda { using namespace onnxruntime::cuda; +template class QMoE final : public CudaKernel, public MoEBase { public: explicit QMoE(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h index 71161c120a306..39b0ccdc7fe6a 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/webgpu/bert/attention_common.h @@ -50,6 +50,7 @@ struct WebgpuAttentionParameters { v_hidden_size_(parameters.kv_hidden_size), v_head_size_(parameters.kv_hidden_size / parameters.kv_num_heads), num_heads_(parameters.num_heads), + is_unidirectional_(true), do_rotary_(parameters.do_rotary), scale_(parameters.scale), seqlen_past_kv_cache_(parameters.seqlen_past_kv_cache), diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index c9e182bf10f2f..dbe2614099be1 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -275,8 +275,23 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { var previous_max : q_element_t = min_value; var previous_denom : q_element_t = 0; +)MAIN_FN"; - for(var k_start = 0u; k_start < uniforms.total_sequence_length; k_start+=capped_sg_size) + if (is_unidirectional_) { + // If attention is unidirectional, set the loop bound to enforce causal masking. + shader.MainFunctionBody() << R"MAIN_FN( + let max_causal_len_for_workgroup = uniforms.past_sequence_length + + (workgroup_idx % uniforms.num_seq_tile + 1) * workgroup_size_x; + let loop_bound = min(uniforms.total_sequence_length, max_causal_len_for_workgroup); +)MAIN_FN"; + } else { + shader.MainFunctionBody() << R"MAIN_FN( + let loop_bound = uniforms.total_sequence_length; +)MAIN_FN"; + } + + shader.MainFunctionBody() << R"MAIN_FN( + for(var k_start = 0u; k_start < loop_bound; k_start+=capped_sg_size) { workgroupBarrier(); loadk(k_start, head_idx / uniforms.n_reps, local_idx, capped_sg_size); @@ -337,7 +352,7 @@ Status FlashAttentionProgram::GenerateShaderCode(ShaderHelper& shader) const { qk_4 = qk_4 + loadAttentionBias(q_idx_global, k_start+12, head_idx); } - let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_gqa > 0); + let seq_causal_length = select(uniforms.total_sequence_length, uniforms.past_sequence_length + q_idx_global + 1, uniforms.is_unidirectional > 0); // Neuter qk values where K is out of bounds. qk_1[0] = select(min_value, qk_1[0], k_start+0 < seq_causal_length); qk_1[1] = select(min_value, qk_1[1], k_start+1 < seq_causal_length); @@ -903,7 +918,13 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co bool has_attention_bias = attention_bias != nullptr; bool is_qualcomm = context.AdapterInfo().vendor == std::string_view{"qualcomm"}; bool is_fp16 = (Q->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16); - FlashAttentionProgram program{"FlashAttention", has_attention_bias, is_qualcomm, is_fp16, parameters.head_size_, parameters.num_heads_}; + FlashAttentionProgram program{"FlashAttention", + has_attention_bias, + is_qualcomm, + is_fp16, + parameters.head_size_, + parameters.num_heads_, + parameters.is_unidirectional_}; program.AddInputs({{Q, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_key, ProgramTensorMetadataDependency::TypeAndRank, 4}, {present_value, ProgramTensorMetadataDependency::TypeAndRank, 4}}); @@ -916,12 +937,12 @@ Status ApplyFlashAttention(const Tensor* Q, const Tensor* K, const Tensor* V, co const uint32_t num_seq_tile = (parameters.sequence_length_ + tile_size - 1) / tile_size; program.SetDispatchGroupSize(parameters.num_heads_ * num_seq_tile) .SetWorkgroupSize(tile_size) - .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, is_qualcomm) + .CacheHint(has_attention_bias, parameters.head_size_, parameters.num_heads_, parameters.is_unidirectional_, is_qualcomm) .AddUniformVariables({{static_cast(parameters.sequence_length_)}, {static_cast(parameters.total_sequence_length_)}, {static_cast(present_sequence_length)}, {static_cast(parameters.total_sequence_length_ - parameters.kv_sequence_length_)}, - {static_cast(parameters.is_gqa_ ? 1 : 0)}, + {static_cast(parameters.is_unidirectional_)}, {static_cast(parameters.n_reps)}, {alpha}, {num_seq_tile}}); diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 3f79b80fb73bc..9908b33a38372 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -39,13 +39,15 @@ class FlashAttentionProgram final : public Program { bool is_qualcomm, bool is_fp16, int qkv_head_size, - int qkv_num_heads) + int qkv_num_heads, + bool is_unidirectional) : Program{kernel_name}, has_attention_bias_(has_attention_bias), is_qualcomm_(is_qualcomm), is_fp16_(is_fp16), qkv_head_size_(qkv_head_size), - qkv_num_heads_(qkv_num_heads) { + qkv_num_heads_(qkv_num_heads), + is_unidirectional_(is_unidirectional) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -54,7 +56,7 @@ class FlashAttentionProgram final : public Program { {"total_sequence_length", ProgramUniformVariableDataType::Uint32}, {"present_sequence_length", ProgramUniformVariableDataType::Uint32}, {"past_sequence_length", ProgramUniformVariableDataType::Uint32}, - {"is_gqa", ProgramUniformVariableDataType::Uint32}, + {"is_unidirectional", ProgramUniformVariableDataType::Uint32}, {"n_reps", ProgramUniformVariableDataType::Uint32}, {"alpha", ProgramUniformVariableDataType::Float32}, {"num_seq_tile", ProgramUniformVariableDataType::Uint32}); @@ -65,6 +67,7 @@ class FlashAttentionProgram final : public Program { bool is_fp16_; int qkv_head_size_; int qkv_num_heads_; + bool is_unidirectional_; }; class FlashAttentionDecodeQKTProgram final : public Program { diff --git a/onnxruntime/core/graph/abi_graph_types.h b/onnxruntime/core/graph/abi_graph_types.h index 6383d29d7a2bc..504b102e782fd 100644 --- a/onnxruntime/core/graph/abi_graph_types.h +++ b/onnxruntime/core/graph/abi_graph_types.h @@ -251,6 +251,16 @@ struct OrtNode { /// A status indicating success or an error. virtual onnxruntime::Status GetAttributes(gsl::span attrs) const = 0; + /// + /// Gets the node's 'TENSOR' attribute as an OrtValue. + /// + /// Node's 'TENSOR' attribute. + /// Output parameter is set to a newly created OrtValue containing the 'TENSOR' attribute value, + /// only if the attribute is of type 'TENSOR' + /// A status indicating success or an error. + virtual onnxruntime::Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attr, + OrtValue*& value) const = 0; + /// /// Gets the number of node subgraphs. /// diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5511275239e45..686ebfb1f6fb5 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1392,20 +1392,20 @@ constexpr const char* MoE_ver1_doc = R"DOC( ONNX_MS_OPERATOR_SET_SCHEMA(MoE, 1, OpSchema() .SetDoc(MoE_ver1_doc) - .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) + .Attr("activation_type", "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", "Number of top experts to select from expert pool", AttributeProto::INT, static_cast(1)) .Attr("normalize_routing_weights", "Whether to normalize routing weights", AttributeProto::INT, static_cast(0)) .Attr("use_sparse_mixer", "Whether to use sparse mixer", AttributeProto::INT, static_cast(0)) .Input(0, "input", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") .Input(1, "router_probs", "2D input tensor with shape (num_rows, num_experts)", "T") - .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size)", "T") - .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size), or (num_experts, hidden_size, 2 * inter_size) for swiglu", "T") + .Input(3, "fc1_experts_bias", "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(4, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size)", "T") .Input(5, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", "T", OpSchema::Optional) .Input(6, "fc3_experts_weights", "3D optional input tensor with shape (num_experts, hidden_size, inter_size)", "T", OpSchema::Optional) .Input(7, "fc3_experts_bias", "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) .Output(0, "output", "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape (batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA( @@ -1413,7 +1413,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema() .SetDoc("Quantized MoE") .Attr("activation_type", - "Activation function to use. Choose from relu, gelu, silu and identity. Default is relu", + "Activation function to use. Choose from relu, gelu, silu, swiglu and identity. Default is relu", AttributeProto::STRING, std::string("relu")) .Attr("k", @@ -1438,18 +1438,18 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(2, "fc1_experts_weights", "3D input tensor with shape (num_experts, hidden_size, inter_size) " - "or (num_experts, hidden_size, inter_size / 2)", + "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).", "T1") - .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size)", "T") + .Input(3, "fc1_scales", "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T2") .Input(4, "fc1_experts_bias", - "2D optional input tensor with shape (num_experts, inter_size)", "T", OpSchema::Optional) + "2D optional input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu", "T", OpSchema::Optional) .Input(5, "fc2_experts_weights", "3D input tensor with shape (num_experts, inter_size, hidden_size) " "or (num_experts, inter_size, hidden_size / 2)", "T1") - .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T") + .Input(6, "fc2_scales", "2D input tensor with shape (num_experts, hidden_size)", "T2") .Input(7, "fc2_experts_bias", "2D optional input tensor with shape (num_experts, hidden_size)", @@ -1464,7 +1464,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .Input(9, "fc3_scales", "2D optional input tensor with shape (num_experts, inter_size)", - "T", + "T2", OpSchema::Optional) .Input(10, "fc3_experts_bias", @@ -1476,8 +1476,9 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "2D input tensor with shape (num_rows, hidden_size) or 3D input tensor with shape " "(batch_size, sequence_length, hidden_size)", "T") - .TypeConstraint("T", {"tensor(float16)"}, "Constrain input and output types to float or float16 tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("T1", {"tensor(uint8)"}, "Constrain weights type to uint8 tensors.") + .TypeConstraint("T2", {"tensor(float)", "tensor(float16)"}, "Constrain scales type to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 4ceadb6191a9b..eb7fb6937c29e 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -87,6 +87,24 @@ static void ConvertNodeArgsToValueInfos(const EpGraph* ep_graph, } } +#if !defined(ORT_MINIMAL_BUILD) +static bool IsOptionalAttribute(const Node& node, const std::string& attr_name) { + const ONNX_NAMESPACE::OpSchema* op_schema = node.Op(); + if (op_schema == nullptr) { + return false; + } + + auto attr_schema_iter = op_schema->attributes().find(attr_name); + if (attr_schema_iter == op_schema->attributes().end()) { + return false; // Not an attribute for this operator type. + } + + const ONNX_NAMESPACE::OpSchema::Attribute& attr_schema = attr_schema_iter->second; + + return !attr_schema.required; +} +#endif // !defined(ORT_MINIMAL_BUILD) + // // EpNode // @@ -230,6 +248,32 @@ Status EpNode::GetAttributes(gsl::span dst) const { return Status::OK(); } +Status EpNode::GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, OrtValue*& result) const { + const auto* attr_proto = reinterpret_cast(attribute); + + if (attr_proto->type() != onnx::AttributeProto::TENSOR) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "This OrtOpAttr instance is not a 'TENSOR' attribute"); + } + + const auto& graph_viewer = ep_graph_->GetGraphViewer(); + const auto& tensor_proto = attr_proto->t(); + + // Check that TensorProto is valid. + ORT_ENFORCE(utils::HasDataType(tensor_proto), "Tensor proto doesn't have data type."); + ORT_ENFORCE(ONNX_NAMESPACE::TensorProto::DataType_IsValid(tensor_proto.data_type()), "Tensor proto has invalid data type."); + ORT_ENFORCE(!utils::HasExternalData(tensor_proto), + "Tensor proto with external data for value attribute is not supported."); + + // Initialize OrtValue for tensor attribute. + auto tensor_attribute_value = std::make_unique(); + AllocatorPtr tensor_attribute_allocator = CPUAllocator::DefaultInstance(); + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue(Env::Default(), graph_viewer.ModelPath(), tensor_proto, + tensor_attribute_allocator, *tensor_attribute_value)); + + result = tensor_attribute_value.release(); + return Status::OK(); +} + Status EpNode::GetNumSubgraphs(size_t& num_subgraphs) const { num_subgraphs = subgraphs_.size(); return Status::OK(); @@ -268,13 +312,20 @@ gsl::span EpNode::GetOutputsSpan() const { return outputs_; } -const OrtOpAttr* EpNode::GetAttribute(const std::string& name) const { +const OrtOpAttr* EpNode::GetAttribute(const std::string& name, bool& is_unset_optional_attr) const { auto iter = attributes_map_.find(name); - if (iter == attributes_map_.end()) { - return nullptr; - } else { + if (iter != attributes_map_.end()) { + is_unset_optional_attr = false; return reinterpret_cast(iter->second.get()); } + +#if !defined(ORT_MINIMAL_BUILD) + is_unset_optional_attr = IsOptionalAttribute(node_, name); +#else + // This is not properly set in a minimal build because it does not have access to the operator schema. + is_unset_optional_attr = false; +#endif // !defined(ORT_MINIMAL_BUILD) + return nullptr; } const std::string& EpNode::GetEpName() const { diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index 243bdc2944ffb..be78d77360cb8 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -183,6 +183,9 @@ struct EpNode : public OrtNode { // Gets the node's attributes. Status GetAttributes(gsl::span attrs) const override; + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* attribute, + OrtValue*& attr_tensor) const override; + // Gets the number of subgraphs contained by this node. Status GetNumSubgraphs(size_t& num_subgraphs) const override; @@ -209,8 +212,9 @@ struct EpNode : public OrtNode { // Helper that returns this node's outputs as a span of EpValueInfo pointers. gsl::span GetOutputsSpan() const; - // Helper that gets the node's attributes by name. - const OrtOpAttr* GetAttribute(const std::string& name) const; + // Helper that gets the node's attributes by name. If the attribute is not set, returns NULL and sets the + // output parameter `is_unset_optional_attr` to true if this is an unset optional attribute. + const OrtOpAttr* GetAttribute(const std::string& name, bool& is_unset_optional_attr) const; // Helper that gets the execution provider name that this node is assigned to run on. const std::string& GetEpName() const; diff --git a/onnxruntime/core/graph/model_editor_api_types.h b/onnxruntime/core/graph/model_editor_api_types.h index 5d84e48182bfe..d3795d911b22f 100644 --- a/onnxruntime/core/graph/model_editor_api_types.h +++ b/onnxruntime/core/graph/model_editor_api_types.h @@ -137,6 +137,11 @@ struct ModelEditorNode : public OrtNode { "OrtModelEditorApi does not support getting attribute OrtOpAttr for OrtNode"); } + Status GetTensorAttributeAsOrtValue(const OrtOpAttr* /*attribute*/, OrtValue*& /*attr_tensor*/) const override { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "OrtModelEditorApi does not support getting 'TENSOR' attribute for OrtNode"); + } + Status GetNumSubgraphs(size_t& /*num_subgraphs*/) const override { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "OrtModelEditorApi does not support getting the subgraphs for OrtNode"); diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index dbf86e2bb7fc7..7aba3b9549f23 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -67,6 +67,7 @@ #include "core/optimizer/qdq_transformer/avx2_weight_s8_to_u8.h" #endif #include "core/optimizer/qdq_transformer/weight_bias_quantization.h" +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" #include "core/optimizer/qdq_transformer/ensure_unique_dq_for_node_unit.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" @@ -271,6 +272,7 @@ InlinedVector> GenerateTransformers( // It runs unconditionally in InferenceSession::TransformGraph() prior to Level1 optimizers. // We also put it here with other Level1 optimizers so that it can fix things up after their changes. transformers.emplace_back(std::make_unique()); + transformers.emplace_back(std::make_unique()); } // add __backwardpass attribute to nodes after YieldOp, ROCm-only diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc new file mode 100644 index 0000000000000..a8b9814f1020c --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.cc @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" + +#include "core/framework/tensorprotoutils.h" +#include "core/common/common.h" +#include "core/util/qmath.h" +#include "core/graph/graph_utils.h" +#include "core/graph/graph_viewer.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" +#include "core/optimizer/qdq_transformer/qdq_util.h" + +namespace onnxruntime { +bool WhereDummyDq::SatisfyCondition(const Graph& graph, const Node& node) const { + if (!(node.OpType() == "Where")) { + return false; + } + const auto& where_inputs = node.InputDefs(); + const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name()); + const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name()); + + bool is_p1_dq = (parent_node_1 && parent_node_1->OpType() == QDQ::DQOpName); + bool is_p2_dq = (parent_node_2 && parent_node_2->OpType() == QDQ::DQOpName); + + // WhereDummyDq focus on WhereOp with one DQ input and one scalar initializer input + if (is_p1_dq && !parent_node_2) { + return (where_inputs[2]->Shape()->dim_size() == 0); + } + if (!parent_node_1 && is_p2_dq) { + return (where_inputs[1]->Shape()->dim_size() == 0); + } + return false; +} + +Status WhereDummyDq::InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const { + const auto& where_inputs = node.InputDefs(); + const Node* parent_node_1 = graph.GetProducerNode(where_inputs[1]->Name()); + const Node* parent_node_2 = graph.GetProducerNode(where_inputs[2]->Name()); + + // With SatisfyCondition, we must have one DQ and one initializer + const Node* dq_node = parent_node_1 ? parent_node_1 : parent_node_2; + int const_idx = parent_node_1 ? 2 : 1; + + const ONNX_NAMESPACE::TensorProto* dq_node_scale_proto = nullptr; + graph.GetInitializedTensor(dq_node->InputDefs()[1]->Name(), dq_node_scale_proto); + const ONNX_NAMESPACE::TensorProto* dq_node_zp_proto = nullptr; + graph.GetInitializedTensor(dq_node->InputDefs()[2]->Name(), dq_node_zp_proto); + + // Dummy data initializer. + ONNX_NAMESPACE::TensorProto dummy_data_proto; + dummy_data_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_data")); + // Set data type to dq node's zp dtype + dummy_data_proto.set_data_type(dq_node_zp_proto->data_type()); + + // Dummy zero point initializer. + ONNX_NAMESPACE::TensorProto dummy_zp_proto; + dummy_zp_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_zp")); + dummy_zp_proto.set_data_type(dq_node_zp_proto->data_type()); + + switch (dummy_zp_proto.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_INT8: { + int8_t zp = 0; + int8_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 1); + dummy_data_proto.set_raw_data(&dummy_data, 1); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { + uint8_t zp = 0; + uint8_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 1); + dummy_data_proto.set_raw_data(&dummy_data, 1); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + int16_t zp = 0; + int16_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 2); + dummy_data_proto.set_raw_data(&dummy_data, 2); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + uint16_t zp = 0; + uint16_t dummy_data = 1; + dummy_zp_proto.set_raw_data(&zp, 2); + dummy_data_proto.set_raw_data(&dummy_data, 2); + break; + } + default: + LOGS(logger, WARNING) << "Currently support existing DQ's zero point with INT8, UINT8, INT16, UINT16"; + return Status::OK(); + } + + // Set dummy scale to the original value + const ONNX_NAMESPACE::TensorProto* const_node_data_proto = nullptr; + graph.GetInitializedTensor(where_inputs[const_idx]->Name(), const_node_data_proto); + Initializer initializer(graph, *const_node_data_proto, graph.ModelPath()); + if (dq_node_scale_proto->data_type() != const_node_data_proto->data_type()) { + // WhereDummyDq fills the const value to the dummy DQ's scale + LOGS(logger, WARNING) << "Currently only support existing DQ's scale with same datatype as scalar"; + return Status::OK(); + } + + // Dummy scale initializer. + ONNX_NAMESPACE::TensorProto dummy_scale_proto; + dummy_scale_proto.set_name(graph.GenerateNodeArgName(node.Name() + "_dummy_scale")); + dummy_scale_proto.set_data_type(dq_node_scale_proto->data_type()); + switch (initializer.data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: { + float* where_const_scalar = initializer.data(); + dummy_scale_proto.set_raw_data(where_const_scalar, sizeof(float)); + break; + } + default: + LOGS(logger, WARNING) << "Currently support scalar with FLOAT"; + return Status::OK(); + } + + // Start editing the graph + NodeArg& dummy_data_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_data_proto); + NodeArg& dummy_scale_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_scale_proto); + NodeArg& dummy_zp_arg = graph_utils::AddInitializerWithExternalData(graph, dummy_zp_proto); + + ONNX_NAMESPACE::TypeProto dummy_dq_type_proto = utils::TypeProtoFromTensorProto(*const_node_data_proto); + dummy_dq_type_proto.mutable_tensor_type()->set_elem_type(const_node_data_proto->data_type()); + NodeArg& dummy_dq_arg = + graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), &dummy_dq_type_proto); + Node& dummy_dq_node = + graph.AddNode( + graph.GenerateNodeArgName(node.Name() + "_dummy_dq"), + QDQ::DQOpName, + "DeQuantizeLinear from WhereDummyDq GraphTransformer", + {&dummy_data_arg, &dummy_scale_arg, &dummy_zp_arg}, + {&dummy_dq_arg}, + nullptr, + dq_node->Domain()); + + node.MutableInputDefs()[const_idx] = &dummy_dq_arg; + if (graph.GetConsumerNodes(where_inputs[const_idx]->Name()).size() == 0) { + graph.RemoveInitializedTensor(where_inputs[const_idx]->Name()); + } + graph.AddEdge(dummy_dq_node.Index(), node.Index(), 0, const_idx); + modified = true; + + return Status::OK(); +} + +Status WhereDummyDq::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { + const GraphViewer graph_viewer{graph}; + const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder(); + for (const auto node_idx : node_indices) { + auto* node_ptr = graph.GetNode(node_idx); + if (!node_ptr) { + continue; + } + + Node& node = *node_ptr; + ORT_RETURN_IF_ERROR(Recurse(node, modified, graph_level, logger)); + + if (this->SatisfyCondition(graph, node)) { + ORT_RETURN_IF_ERROR(WhereDummyDq::InsertDummyDQ(node, graph, modified, logger)); + } + } + + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h new file mode 100644 index 0000000000000..3260a865f8c4b --- /dev/null +++ b/onnxruntime/core/optimizer/qdq_transformer/where_dummy_dq.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/graph_transformer.h" + +namespace onnxruntime { + +/** + @Class WhereDummyDq + + Graph transformer that inserts a dummy DQ on Where node's initializer input + to form Node Unit when Where node has one DQ and one scalar initializer input +*/ +class WhereDummyDq : public GraphTransformer { + public: + WhereDummyDq() noexcept : GraphTransformer("WhereDummyDq") {} + + private: + Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; + + bool SatisfyCondition(const Graph& graph, const Node& node) const; + Status InsertDummyDQ(Node& node, Graph& graph, bool& modified, const logging::Logger& logger) const; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index 9cf89a04f031c..6cbbdd4e0a7ef 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -40,13 +40,16 @@ void Telemetry::SetLanguageProjection(uint32_t projection) const { void Telemetry::LogProcessInfo() const { } -void Telemetry::LogSessionCreationStart() const { +void Telemetry::LogSessionCreationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStop() const { +void Telemetry::LogEvaluationStop(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } -void Telemetry::LogEvaluationStart() const { +void Telemetry::LogEvaluationStart(uint32_t session_id) const { + ORT_UNUSED_PARAMETER(session_id); } void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index cb7a6176e5aec..b60345e1b8a80 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -48,11 +48,11 @@ class Telemetry { virtual void LogProcessInfo() const; - virtual void LogSessionCreationStart() const; + virtual void LogSessionCreationStart(uint32_t session_id) const; - virtual void LogEvaluationStop() const; + virtual void LogEvaluationStop(uint32_t session_id) const; - virtual void LogEvaluationStart() const; + virtual void LogEvaluationStart(uint32_t session_id) const; virtual void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 44ef44a3f5aff..2e5d334856278 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -194,7 +194,7 @@ void WindowsTelemetry::LogProcessInfo() const { process_info_logged = true; } -void WindowsTelemetry::LogSessionCreationStart() const { +void WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; @@ -203,23 +203,26 @@ void WindowsTelemetry::LogSessionCreationStart() const { TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingLevel(WINEVENT_LEVEL_INFO)); } -void WindowsTelemetry::LogEvaluationStop() const { +void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStop"); + "EvaluationStop", + TraceLoggingUInt32(session_id, "sessionId")); } -void WindowsTelemetry::LogEvaluationStart() const { +void WindowsTelemetry::LogEvaluationStart(uint32_t session_id) const { if (global_register_count_ == 0 || enabled_ == false) return; TraceLoggingWrite(telemetry_provider_handle, - "EvaluationStart"); + "EvaluationStart", + TraceLoggingUInt32(session_id, "sessionId")); } void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index 7281063d50c2e..261d14a7fed8c 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -41,11 +41,11 @@ class WindowsTelemetry : public Telemetry { void LogProcessInfo() const override; - void LogSessionCreationStart() const override; + void LogSessionCreationStart(uint32_t session_id) const override; - void LogEvaluationStop() const override; + void LogEvaluationStop(uint32_t session_id) const override; - void LogEvaluationStart() const override; + void LogEvaluationStart(uint32_t session_id) const override; void LogSessionCreation(uint32_t session_id, int64_t ir_version, const std::string& model_producer_name, const std::string& model_producer_version, const std::string& model_domain, diff --git a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc index cf9f44f4cd8f0..1cac133ab0c2c 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_allocator.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_allocator.cc @@ -18,7 +18,7 @@ void MIGraphXAllocator::CheckDevice() const { int current_device; auto hip_err = hipGetDevice(¤t_device); if (hip_err == hipSuccess) { - ORT_ENFORCE(current_device == Info().id); + ORT_ENFORCE(current_device == Info().device.Id()); } #endif } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 84c7fe3e4d4ab..e9b1fe0f39da5 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -151,10 +151,10 @@ struct MIGraphX_Provider : Provider { const OrtSessionOptions& session_options, const OrtLogger& logger, std::unique_ptr& ep) override { + ORT_UNUSED_PARAMETER(num_devices); const ConfigOptions* config_options = &session_options.GetConfigOptions(); std::array configs_array = {&provider_options, config_options}; - const void* arg = reinterpret_cast(&configs_array); auto ep_factory = CreateExecutionProviderFactory(&provider_options); ep = ep_factory->CreateProvider(session_options, logger); @@ -181,26 +181,47 @@ struct MigraphXEpFactory : OrtEpFactory { const char* ep_name, OrtHardwareDeviceType hw_type, const OrtLogger& default_logger_in) - : ort_api{ort_api_in}, ep_name{ep_name}, ort_hw_device_type{hw_type}, default_logger{default_logger_in} { + : ort_api{ort_api_in}, default_logger{default_logger_in}, ep_name{ep_name}, ort_hw_device_type{hw_type} { + ort_version_supported = ORT_API_VERSION; // set to the ORT version we were compiled with GetName = GetNameImpl; GetVendor = GetVendorImpl; + GetVendorId = GetVendorIdImpl; + GetVersion = GetVersionImpl; + GetSupportedDevices = GetSupportedDevicesImpl; CreateEp = CreateEpImpl; ReleaseEp = ReleaseEpImpl; + + CreateAllocator = CreateAllocatorImpl; + ReleaseAllocator = ReleaseAllocatorImpl; + CreateDataTransfer = CreateDataTransferImpl; + + IsStreamAware = IsStreamAwareImpl; + CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; } // Returns the name for the EP. Each unique factory configuration must have a unique name. // Ex: a factory that supports NPU should have a different than a factory that supports GPU. - static const char* GetNameImpl(const OrtEpFactory* this_ptr) { + static const char* GetNameImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->ep_name.c_str(); } - static const char* GetVendorImpl(const OrtEpFactory* this_ptr) { + static const char* GetVendorImpl(const OrtEpFactory* this_ptr) noexcept { const auto* factory = static_cast(this_ptr); return factory->vendor.c_str(); } + static uint32_t GetVendorIdImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->vendor_id; + } + + static const char* GetVersionImpl(const OrtEpFactory* this_ptr) noexcept { + const auto* factory = static_cast(this_ptr); + return factory->version.c_str(); + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of MigraphX EP is not currently setup to partition a model among @@ -212,7 +233,7 @@ struct MigraphXEpFactory : OrtEpFactory { size_t num_devices, OrtEpDevice** ep_devices, size_t max_ep_devices, - size_t* p_num_ep_devices) { + size_t* p_num_ep_devices) noexcept { size_t& num_ep_devices = *p_num_ep_devices; auto* factory = static_cast(this_ptr); @@ -237,18 +258,56 @@ struct MigraphXEpFactory : OrtEpFactory { _In_ size_t /*num_devices*/, _In_ const OrtSessionOptions* /*session_options*/, _In_ const OrtLogger* /*logger*/, - _Out_ OrtEp** /*ep*/) { + _Out_ OrtEp** /*ep*/) noexcept { return onnxruntime::CreateStatus(ORT_INVALID_ARGUMENT, "[MigraphX/AMDGPU EP] EP factory does not support this method."); } - static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) { + static void ReleaseEpImpl(OrtEpFactory* /*this_ptr*/, OrtEp* /*ep*/) noexcept { // no-op as we never create an EP here. } + static OrtStatus* ORT_API_CALL CreateAllocatorImpl(OrtEpFactory* this_ptr, + const OrtMemoryInfo* /*memory_info*/, + const OrtKeyValuePairs* /*allocator_options*/, + OrtAllocator** allocator) noexcept { + auto* factory = static_cast(this_ptr); + + *allocator = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, + "CreateAllocator should not be called as we did not add OrtMemoryInfo to our OrtEpDevice."); + } + + static void ORT_API_CALL ReleaseAllocatorImpl(OrtEpFactory* /*this_ptr*/, OrtAllocator* /*allocator*/) noexcept { + // should never be called as we don't implement CreateAllocator + } + + static OrtStatus* ORT_API_CALL CreateDataTransferImpl(OrtEpFactory* /*this_ptr*/, + OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; // not implemented + return nullptr; + } + + static bool ORT_API_CALL IsStreamAwareImpl(const OrtEpFactory* /*this_ptr*/) noexcept { + return false; + } + + static OrtStatus* ORT_API_CALL CreateSyncStreamForDeviceImpl(OrtEpFactory* this_ptr, + const OrtMemoryDevice* /*memory_device*/, + const OrtKeyValuePairs* /*stream_options*/, + OrtSyncStreamImpl** stream) noexcept { + auto* factory = static_cast(this_ptr); + + *stream = nullptr; + return factory->ort_api.CreateStatus( + ORT_INVALID_ARGUMENT, "CreateSyncStreamForDevice should not be called as IsStreamAware returned false."); + } + const OrtApi& ort_api; const OrtLogger& default_logger; const std::string ep_name; const std::string vendor{"AMD"}; + const std::string version{"1.0.0"}; // MigraphX EP version const uint32_t vendor_id{0x1002}; const OrtHardwareDeviceType ort_hw_device_type; // Supported OrtHardwareDevice diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e236cccaaaa77..d23d50549b2c5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -557,6 +557,67 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return ORT_VERSION; } + /** + * @brief Checks if a given OrtHardwareDevice is a supported NVIDIA GPU. + * + * This function verifies if the provided hardware device corresponds to a physical + * NVIDIA GPU that meets the minimum compute capability requirements for this execution provider. + * + * The check is performed by: + * 1. Extracting the LUID (Locally Unique Identifier) from the device's metadata. + * 2. Converting the string LUID to a 64-bit integer. + * 3. Iterating through all available CUDA devices on the system. + * 4. For each CUDA device, constructing its 64-bit LUID from its properties. + * 5. Comparing the LUIDs. If a match is found, it checks if the device's + * compute capability is at least 8.0 (Ampere) or newer. + * + * @param device The OrtHardwareDevice to check. + * @return True if the device is a supported NVIDIA GPU, false otherwise. + */ + bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) { + const auto& metadata_entries = device.metadata.Entries(); + const auto it = metadata_entries.find("LUID"); + if (it == metadata_entries.end()) { + return false; + } + + uint64_t target_luid; + try { + target_luid = std::stoull(it->second); + } catch (const std::exception&) { + return false; + } + + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + return false; + } + + for (int i = 0; i < device_count; ++i) { + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, i) != cudaSuccess) { + continue; + } + + // The LUID is an 8-byte value, valid on Windows when luidDeviceNodeMask is non-zero. + // We reconstruct the 64-bit integer representation from the raw bytes. + if (prop.luidDeviceNodeMask == 0) { + continue; + } + + // Ensure the LUID is 8 bytes and reinterpret it directly as a uint64_t for comparison. + static_assert(sizeof(prop.luid) == sizeof(uint64_t), "cudaDeviceProp::luid should be 8 bytes"); + uint64_t current_luid = *reinterpret_cast(prop.luid); + + if (current_luid == target_luid) { + // Ampere architecture or newer is required. + return prop.major >= 8; + } + } + + return false; + } + // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. // An EP created with this factory is expected to be able to execute a model with *all* supported // hardware devices at once. A single instance of NvTensorRtRtx EP is not currently setup to partition a model among @@ -579,11 +640,12 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { int16_t device_id = 0; for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { const OrtHardwareDevice& device = *devices[i]; + if (factory->ort_api.HardwareDevice_Type(&device) == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU && - factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id) { + factory->ort_api.HardwareDevice_VendorId(&device) == factory->vendor_id && + factory->IsOrtHardwareDeviceSupported(device)) { OrtKeyValuePairs* ep_options = nullptr; OrtKeyValuePairs* ep_metadata = nullptr; - factory->ort_api.CreateKeyValuePairs(&ep_options); factory->ort_api.CreateKeyValuePairs(&ep_metadata); factory->ort_api.AddKeyValuePair(ep_options, "device_id", std::to_string(device_id).c_str()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index 0152ad27c0ba2..e248034f225ec 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -8,6 +8,13 @@ namespace onnxruntime { namespace qnn { +namespace { +bool IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) { + const NodeArg& arg = node_io_def.node_arg; + return !arg.Exists() || arg.Name().empty(); +} +} // namespace + std::string BaseOpBuilder::GetOpBuilderType() const { return op_builder_type_; } @@ -46,12 +53,18 @@ Status BaseOpBuilder::ProcessDataTypes(QnnModelWrapper& qnn_model_wrapper, const auto& inputs = node_unit.Inputs(); const auto& outputs = node_unit.Outputs(); for (auto input : inputs) { + if (IsOptionalNodeUnitIODef(input)) { + continue; + } TensorInfo tensor_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input, tensor_info)); Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; input_qnn_dtypes.push_back(qnn_data_type); } for (auto output : outputs) { + if (IsOptionalNodeUnitIODef(output)) { + continue; + } TensorInfo tensor_info = {}; ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output, tensor_info)); Qnn_DataType_t qnn_data_type = tensor_info.qnn_data_type; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc index 9db0b5202dcd4..51c38b4483cb9 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/einsum_op_builder.cc @@ -45,13 +45,7 @@ std::optional ParseEquation(std::string_view equation_string) { if (term_1.empty() || term_2.empty()) { return std::nullopt; } - if (term_1.size() < 2) { - return std::nullopt; - } - if (term_1.size() != term_2.size()) { - return std::nullopt; - } - if (term_1.size() != result.size()) { + if (term_1.size() < 2 || term_2.size() < 2 || result.size() < 2) { return std::nullopt; } if (!std::all_of(term_1.begin(), term_1.end(), [](unsigned char c) { return std::islower(c); })) { @@ -154,6 +148,94 @@ bool IsEquationMatMulTransposeAll(const Equation& equation) { return true; } +bool IsEquationMatMulBroadcastTransposeY(const Equation& equation) { + // E.g., bhwc,hkc->bhwk + const auto& [term_1, term_2, result] = equation; + const size_t term1_dims = term_1.size(); + if (term1_dims != 4) { + return false; + } + const size_t term2_dims = term_2.size(); + if (term2_dims != 3) { + return false; + } + const size_t result_dims = result.size(); + if (result_dims != 4) { + return false; + } + // Check matrix multiplication dimensions + char term_1_m = term_1[term1_dims - 2]; + char term_1_k = term_1[term1_dims - 1]; + char term_2_k = term_2[term2_dims - 1]; + char term_2_n = term_2[term2_dims - 2]; + char result_m = result[result_dims - 2]; + char result_n = result[result_dims - 1]; + if (term_1_m != result_m) { + return false; + } + if (term_1_k != term_2_k) { + return false; + } + if (term_2_n != result_n) { + return false; + } + // Check batch dimensions + if (term_1[0] != result[0]) { + return false; + } + if (term_1[1] != result[1]) { + return false; + } + if (term_2[0] != result[1]) { + return false; + } + return true; +} + +bool IsEquationReduceSumMulBroadcastX(const Equation& equation) { + // E.g., bhwc,wkc->bhwk + const auto& [term_1, term_2, result] = equation; + if (term_1.size() != 4) { + return false; + } + if (term_2.size() != 3) { + return false; + } + if (result.size() != 4) { + return false; + } + + // Check contraction over last axis (c) + char c1 = term_1[3]; + char c2 = term_2[2]; + if (c1 != c2) { + return false; + } + + // Check w axis alignment + if (term_1[2] != term_2[0]) { + return false; + } + if (term_1[2] != result[2]) { + return false; + } + + // Check k axis alignment + if (term_2[1] != result[3]) { + return false; + } + + // Check batch dimensions + if (term_1[0] != result[0]) { + return false; + } + if (term_1[1] != result[1]) { + return false; + } + + return true; +} + /** * @brief Sets the parameter tensor names for a MatMul op. * @@ -267,6 +349,113 @@ Status CreateMatMulTransposeAll( return Status::OK(); } +/** + * @brief Creates a ReduceSum, Multiply on broadcasted input X and original input Y. + * + * @param qnn_model_wrapper Pointer to the QnnModelWrapper instance used to manage the QNN model. + * @param node_unit The NodeUnit representing the ONNX node to be converted. + * @param do_op_validation A boolean flag indicating whether to perform operation validation. + * @return Status indicating success or failure of the operation. + */ +Status CreateReduceSumMulBroadcastX( + onnxruntime::qnn::QnnModelWrapper* qnn_model_wrapper, + const onnxruntime::NodeUnit& node_unit, + std::vector&& input_names, + bool do_op_validation) { + // Reshape in0 to shape (b, h, w, 1, c) to expand dimension before the contraction axis 'c'. + // Allowing broadcast with in1 for multiplication, aligning the contraction axis for reduce. + onnxruntime::qnn::TensorInfo tensor_info_in0{}, tensor_info_in1{}, tensor_info_out{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[0], tensor_info_in0)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Inputs()[1], tensor_info_in1)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper->GetTensorInfo(node_unit.Outputs()[0], tensor_info_out)); + const std::vector& shape_in0 = tensor_info_in0.shape; + const std::vector& shape_in1 = tensor_info_in1.shape; + ORT_RETURN_IF_NOT(shape_in0.size() == 4, "CreateReduceSumMulBroadcastX expects input 0 to be rank 4"); + ORT_RETURN_IF_NOT(shape_in1.size() == 3, "CreateReduceSumMulBroadcastX expects input 1 to be rank 3"); + const std::vector new_shape_in0{shape_in0[0], shape_in0[1], shape_in0[2], 1, shape_in0[3]}; + const std::string reshape_out_name = input_names[0] + "_reshaped"; + ORT_RETURN_IF_ERROR(qnn_model_wrapper->AddReshapeNode( + /*input_name=*/input_names[0], + /*output_name=*/reshape_out_name, + /*input_shape=*/shape_in0, + /*output_shape=*/new_shape_in0, + /*tensor_data_type=*/tensor_info_in0.qnn_data_type, + /*quantize_param=*/tensor_info_in0.quant_param.Copy(), + /*do_op_validation=*/do_op_validation, + /*is_for_input=*/qnn_model_wrapper->IsGraphInput(input_names[0]))); + + // Multiply: reshaped in0 * in1 + // The output shape of the multiplication is determined by broadcasting the reshaped in0 of + // (b, h, w, 1, c) and in1 (w, k, c) along the matching axes, resulting in (b, h, w, k, c). + const std::string mul_out_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_mul"; + std::vector shape_out_mul{new_shape_in0[0], new_shape_in0[1], new_shape_in0[2], shape_in1[1], new_shape_in0[4]}; + onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_mul(mul_out_name, + QNN_TENSOR_TYPE_NATIVE, + tensor_info_in0.qnn_data_type, + tensor_info_in0.quant_param.Copy(), + std::move(shape_out_mul)); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_mul)), + "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( + /*qnn_node_name=*/mul_out_name, + /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, + /*qnn_node_type=*/QNN_OP_ELEMENT_WISE_MULTIPLY, + /*input_names=*/{reshape_out_name, input_names[1]}, + /*output_names=*/{mul_out_name}, + /*param_tensor_names=*/{}, + /*do_op_validation=*/do_op_validation), + "CreateReduceSumMulBroadcastX: failed to create Mul node"); + + std::vector param_tensor_names{}; + + // ReduceSum on last axes={4}, keep_dims=False + // Axis '4' corresponds to the last dimension ('c') of the reshaped tensor (b, h, w, k, c), + // which is the contraction axis for reduce sum op in the einsum equation (bhwc,wkc->bhwk). + std::vector axes_shape{SafeInt(1)}; + std::vector axes_value{SafeInt(4)}; + onnxruntime::qnn::QnnParamWrapper param_axes(node_unit.Index(), + node_unit.Name(), + QNN_OP_REDUCE_SUM_PARAM_AXES, + std::move(axes_shape), + std::move(axes_value)); + param_tensor_names.push_back(param_axes.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_axes)), + "CreateReduceSumMulBroadcastX: failed to add param axes"); + + Qnn_Scalar_t keep_dims_scalar = QNN_SCALAR_INIT; + keep_dims_scalar.dataType = QNN_DATATYPE_BOOL_8; + keep_dims_scalar.bool8Value = SafeInt(0); + onnxruntime::qnn::QnnParamWrapper param_keep_dims(node_unit.Index(), + node_unit.Name(), + QNN_OP_REDUCE_SUM_PARAM_KEEP_DIMS, + keep_dims_scalar); + param_tensor_names.push_back(param_keep_dims.GetParamTensorName()); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddParamWrapper(std::move(param_keep_dims)), + "CreateReduceSumMulBroadcastX: failed to add param keep_dims"); + + const std::string out_name = node_unit.Outputs()[0].node_arg.Name(); + Qnn_TensorType_t out_tensor_type = qnn_model_wrapper->IsGraphOutput(out_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + onnxruntime::qnn::QnnTensorWrapper tensor_wrapper_out(out_name, + out_tensor_type, + tensor_info_out.qnn_data_type, + tensor_info_out.quant_param.Copy(), + std::move(tensor_info_out.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper->AddTensorWrapper(std::move(tensor_wrapper_out)), + "CreateReduceSumMulBroadcastX: failed to AddTensorWrapper"); + + ORT_RETURN_IF_NOT(qnn_model_wrapper->CreateQnnNode( + /*qnn_node_name=*/out_name, + /*package_name=*/QNN_OP_PACKAGE_NAME_QTI_AISW, + /*qnn_node_type=*/QNN_OP_REDUCE_SUM, + /*input_names=*/{mul_out_name}, + /*output_names=*/{out_name}, + /*param_tensor_names=*/std::move(param_tensor_names), + /*do_op_validation=*/do_op_validation), + "CreateReduceSumMulBroadcastX: failed to create ReduceSum node"); + + return Status::OK(); +} + } // namespace namespace onnxruntime { @@ -317,9 +506,21 @@ Status EinsumOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, } if (!IsEquationMatMul(parsed_equation.value()) && !IsEquationMatMulTransposeY(parsed_equation.value()) && - !IsEquationMatMulTransposeAll(parsed_equation.value())) { + !IsEquationMatMulBroadcastTransposeY(parsed_equation.value()) && + !IsEquationMatMulTransposeAll(parsed_equation.value()) && + !IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } + if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { + if (IsGpuBackend(qnn_model_wrapper.GetQnnBackendType())) { + // QAIRT 3.36.1: Failed to validate on GPU. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " on backend GPU"); + } + if (node_unit.Inputs()[0].quant_param.has_value()) { + // QAIRT 3.36.1: Failed to finalize QNN graph 1002. + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation + " for quantized inputs"); + } + } return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } @@ -353,7 +554,8 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*logger=*/logger, /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); - } else if (IsEquationMatMulTransposeY(parsed_equation.value())) { + } else if (IsEquationMatMulTransposeY(parsed_equation.value()) || + IsEquationMatMulBroadcastTransposeY(parsed_equation.value())) { std::vector param_tensor_names = SetMatMulParamTensorNames( &qnn_model_wrapper, node_unit, /*transpose_in0=*/false, /*transpose_in1=*/true); ORT_RETURN_IF_ERROR(ProcessOutputs(/*qnn_model_wrapper=*/qnn_model_wrapper, @@ -364,7 +566,15 @@ Status EinsumOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w /*do_op_validation=*/do_op_validation, /*qnn_op_type=*/QNN_OP_MAT_MUL)); } else if (IsEquationMatMulTransposeAll(parsed_equation.value())) { - ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(&qnn_model_wrapper, node_unit, std::move(input_names), do_op_validation)); + ORT_RETURN_IF_ERROR(CreateMatMulTransposeAll(/*qnn_model_wrapper=*/&qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*do_op_validation=*/do_op_validation)); + } else if (IsEquationReduceSumMulBroadcastX(parsed_equation.value())) { + ORT_RETURN_IF_ERROR(CreateReduceSumMulBroadcastX(/*qnn_model_wrapper=*/&qnn_model_wrapper, + /*node_unit=*/node_unit, + /*input_names=*/std::move(input_names), + /*do_op_validation=*/do_op_validation)); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, node_unit.OpType() + " unsupported equation: " + equation); } diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc index 1ad910be810bb..03dba4e98a901 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/gemm_op_builder.cc @@ -55,8 +55,12 @@ Status GemmOpBuilder::ExplictOpCheck(const NodeUnit& node_unit) const { auto transB = node_helper.Get("transB", static_cast(0)); auto M = (transB == 0) ? inputB_shape.at(1) : inputB_shape.at(0); if (inputC_shape.size() == 0 || (inputC_shape.size() == 1 && inputC_shape.at(0) != M) || - (inputC_shape.size() == 2 && (inputC_shape.at(0) != 1 || inputC_shape.at(1) != M))) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [M]."); + (inputC_shape.size() == 2 && inputC_shape.at(1) != M)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support C with shape [N, M]."); + } + + if (inputC_shape.size() == 2 && node_unit.Inputs()[2].quant_param.has_value() && inputC_shape.at(0) != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN FullyConnected Op only support quantized C with shape [1, M]."); } } @@ -133,7 +137,8 @@ Status GemmOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, qnn_model_wrapper.IsGraphInput(node_input_name))); } - if (2 == input_i && 2 == input_shape.size()) { + // Reshape [1, M] shape Bias. + if (2 == input_i && 2 == input_shape.size() && input_shape[0] == 1) { input_shape[0] = input_shape[1]; input_shape.resize(1); } @@ -199,8 +204,70 @@ Status GemmOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { - ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {}, - logger, do_op_validation, GetQnnOpType(node_unit.OpType()))); + // FullyConnected dosen't support 2d bias with shape [N, M], In this case, decompose Gemm into FullyConnected + Add for compatibility. + bool split_gemm = false; + if (node_unit.Inputs().size() == 3) { + auto& input_c = node_unit.Inputs()[2]; + std::vector input_c_shape; + QnnModelWrapper::GetOnnxShape(input_c.node_arg, input_c_shape); + + // Split when input_c has 2d shape and not [1, M] + split_gemm = (input_c_shape.size() == 2 && input_c_shape.at(0) != 1); + } + + if (split_gemm) { + // If split_gemm, input and output of Gemm must at least 2d. + const std::string& org_output_name = node_unit.Outputs()[0].node_arg.Name(); + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info)); + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); + std::vector output_shape = output_info.shape; + QnnQuantParamsWrapper op_output_quant_param = output_info.quant_param.Copy(); + + const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(org_output_name); + + // Create FullyConnected Node + std::vector gemm_input_0_1; + gemm_input_0_1.push_back(input_names[0]); + gemm_input_0_1.push_back(input_names[1]); + std::string split_fully_connected_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected"; + std::string split_fully_connected_output_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_FullyConnected_output"; + QnnTensorWrapper fully_connected_output(split_fully_connected_output_name, QNN_TENSOR_TYPE_NATIVE, input_info.qnn_data_type, + QnnQuantParamsWrapper(), std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(fully_connected_output)), + "Failed to add FullyConnected output tensor."); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_fully_connected_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_FULLY_CONNECTED, + std::move(gemm_input_0_1), + {split_fully_connected_output_name}, + {}, + do_op_validation), + "Failed to add FullyConnected node."); + + // Create Add Node + Qnn_TensorType_t op_output_tensor_type = is_graph_output ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + std::string split_add_name = onnxruntime::qnn::utils::GetNodeName(node_unit) + "_split_add"; + QnnTensorWrapper op_output_tensor_wrapper(org_output_name, op_output_tensor_type, output_info.qnn_data_type, + op_output_quant_param.Copy(), std::vector(output_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(op_output_tensor_wrapper)), + "Failed to add ElementWiseAdd output tensor."); + std::string bias_name = input_names[2]; + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(split_add_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_ADD, + {split_fully_connected_output_name, bias_name}, // FullyConnected output as input + {org_output_name}, // Original output as output + {}, + do_op_validation), + "Failed to add ElementWiseAdd node."); + } else { + ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), {}, + logger, do_op_validation, GetQnnOpType(node_unit.OpType()))); + } + return Status::OK(); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 14f12c906f11a..5fc0b8900730b 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -580,6 +580,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.RemoveInitializedTensor(tensor_name); }; the_global_api.graph_reverse_dfs_from_preemp = vaip::graph_reverse_dfs_from; + the_global_api.graph_save_string = vaip::graph_save_string; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 20ae1cfbfa2c1..028ee7fa8c5ce 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -15,7 +15,7 @@ #include "vaip/node.h" #include "vaip/node_arg.h" - +#include "./tensor_proto.h" namespace vaip { struct NodeEdgeT { @@ -205,6 +205,29 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri vai_assert(result, "model serialize to ostream error"); } +vaip_core::DllSafe graph_save_string(const Graph& graph) { + auto model_proto = const_cast(graph.GetModel()).ToProto(); + auto graph_proto_subgraph = graph.ToGraphProto(); + *model_proto->mutable_graph() = *graph_proto_subgraph; + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model = Model::Create(std::move(*model_proto), graph.ModelPath(), nullptr, logger); + model_proto = model->ToProto(); + auto& metadata = model->MetaData(); + if (!metadata.empty()) { + auto metadata_props = model_proto->mutable_metadata_props(); + metadata_props->Clear(); + for (auto& m : metadata) { + auto prop = metadata_props->Add(); + *prop->mutable_key() = m.first; + *prop->mutable_value() = m.second; + } + } + std::string graph_string; + bool result = model_proto->SerializeToString(graph_string); + vai_assert(result, "model serialize to string error"); + return vaip_core::DllSafe(graph_string); +} + Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, @@ -286,7 +309,14 @@ Model* model_clone(const Model& original_model, int64_t external_data_threshold) cloned_tensor->add_dims(dim); size = size * dim; } - if (size >= external_data_threshold) { + auto ORT_MEM_ADDR_tag = process_ext_address(*original_tensor); + if (!ORT_MEM_ADDR_tag.empty()) { + cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto external_data = cloned_tensor->mutable_external_data(); + auto p = external_data->Add(); + *p->mutable_key() = "location"; + *p->mutable_value() = std::string("<") + graph_ptr; + } else if (size >= external_data_threshold) { cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); auto external_data = cloned_tensor->mutable_external_data(); auto p = external_data->Add(); diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc index bb942c69003a1..2f1478bf1326b 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.cc @@ -10,7 +10,7 @@ namespace vaip { using namespace onnxruntime; -static gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor) { auto tensor_proto = const_cast(&tensor); auto file = std::string(); uintptr_t offset = 0; diff --git a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h index 73015d3411a54..a7c90ac18b44e 100644 --- a/onnxruntime/core/providers/vitisai/imp/tensor_proto.h +++ b/onnxruntime/core/providers/vitisai/imp/tensor_proto.h @@ -37,4 +37,5 @@ ONNX_NAMESPACE::TensorProto* tensor_proto_new_fp16(const std::string& name, cons const std::vector& data); ONNX_NAMESPACE::TensorProto* tensor_proto_new_doubles(const std::string& name, const std::vector& shape, const std::vector& data); +gsl::span process_ext_address(const ONNX_NAMESPACE::TensorProto& tensor); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index bd8d0229d627c..440b8295da658 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ -14,6 +14,7 @@ Node& graph_add_node(Graph& graph, const std::string& name, const std::string& o const NodeAttributes& attributes, const std::string& domain); void graph_save(const Graph& graph, const std::string& filename, const std::string& dat_filename, size_t initializer_size_threshold); +vaip_core::DllSafe graph_save_string(const Graph& graph); Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 63949116507e4..acb258894e11c 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (17u) +#define VAIP_ORT_API_MAJOR (18u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -252,10 +252,11 @@ struct OrtApiForVaip { stop); // [103] void (*graph_set_name)(Graph& graph, const char* name); // [104] void (*graph_infer_shapes_from_filepath)( - const std::string& m, const std::string& save_path); // [105] - GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106] - void (*graph_proto_delete)(GraphProto* p); // [107] - void (*graph_infer_shapes)(ModelProto& m); // [108] + const std::string& m, const std::string& save_path); // [105] + GraphProto* (*graph_to_graph_proto)(const Graph& graph); // [106] + void (*graph_proto_delete)(GraphProto* p); // [107] + void (*graph_infer_shapes)(ModelProto& m); // [108] + DllSafe (*graph_save_string)(const Graph& graph); // [109] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index a22d21d8d798b..bdeea726a2cf5 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -491,16 +491,29 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector& sha ss << ","; } - auto alignment = (data_type == ProgramUniformVariableDataType::Float16 && length > 4) ? "@align(16) " : ""; - ss << "\n " << alignment << name << ": "; + // The actual variable type for the uniform variable depends on the data type (T) and length (N). + // + // For T in [i32, u32, f32]: + // - If N == 1, the type is simply i32, u32, or f32. + // - If 2 < N <= 4, the type is vecN, vecN, or vecN where N is the length. + // - If N > 4, the type is array, ceil(N / 4)>. + // + // For T is f16: + // - If N == 1 or N == 2, the type is u32. + // - If 2 < N <= 8, the type is vecX where X is ceil(N / 2). + // - If N > 8, the type is array, X> where X is ceil(N / 8). + // + // Note: Using f16 type in uniforms is not generally supported on all devices. We use a u32 variable to represent + // 2 f16 values. + + if (data_type == ProgramUniformVariableDataType::Float16) { + data_type = ProgramUniformVariableDataType::Uint32; // f16 is represented as u32 + length = (length + 1) / 2; // each u32 can hold 2 f16 values + } + ss << "\n " << name << ": "; if (length > 4) { - if (data_type == ProgramUniformVariableDataType::Float16) { - size_t array_size = (length + 7) / 8; - ss << "array, " << array_size << ">"; - } else { - size_t array_size = (length + 3) / 4; - ss << "array, " << array_size << ">"; - } + size_t array_size = (length + 3) / 4; + ss << "array, " << array_size << ">"; } else if (length > 1) { ss << "vec" << length << "<" << data_type << ">"; } else { diff --git a/onnxruntime/core/providers/webgpu/shader_variable.h b/onnxruntime/core/providers/webgpu/shader_variable.h index 2aba2a59d157f..78c98ab26f5b8 100644 --- a/onnxruntime/core/providers/webgpu/shader_variable.h +++ b/onnxruntime/core/providers/webgpu/shader_variable.h @@ -17,18 +17,34 @@ template || std::is_same_v>> std::string GetElementAt(std::string_view var, const TIdx& idx, TRank rank, bool is_f16 = false) { - // "std::string::rfind(str, 0) == 0" is equivalent to "std::string::starts_with(str)" before C++20. - if (var.rfind("uniforms.", 0) == 0) { - if (rank > 4) { - if constexpr (std::is_integral_v) { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[", idx / 8, "][", (idx % 8) / 4, "][", (idx % 8) % 4, "]"); + if (var.starts_with("uniforms.")) { + if (is_f16) { + if (rank > 8) { + // array, N> + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 8, "][", (idx % 8) / 2, "])[", (idx % 8) % 2, "]"); } else { - return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 8][((", idx, ") % 8) / 2])[((", idx, ") % 8) % 2]"); + } + } else if (rank > 2) { + // vecN + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, "[", idx / 2, "])[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, "[(", idx, ") / 2])[(", idx, ") % 2]"); } } else { - if (is_f16) { - return MakeStringWithClassicLocale(var, "[(", idx, ") / 8][(", idx, ") % 8 / 4][(", idx, ") % 8 % 4]"); + // u32 + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale("bitcast>(", var, ")[", idx % 2, "]"); + } else { + return MakeStringWithClassicLocale("bitcast>(", var, ")[(", idx, ") % 2]"); + } + } + } else { + if (rank > 4) { + if constexpr (std::is_integral_v) { + return MakeStringWithClassicLocale(var, "[", idx / 4, "][", idx % 4, "]"); } else { return MakeStringWithClassicLocale(var, "[(", idx, ") / 4][(", idx, ") % 4]"); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 4bd79a627df22..a9557f7b9aa87 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -373,26 +373,57 @@ Status WebGpuContext::Run(ComputeContext& context, const ProgramBase& program) { continue; } - bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; - - size_t element_size = ProgramUniformVariableDataTypeSize[static_cast(uniform.data_type)]; + // Calculate the size and alignment of the uniform variable. + // // https://www.w3.org/TR/WGSL/#alignof - size_t base_alignment = is_f16 - ? (length > 4 ? 16 : length > 2 ? 8 - : length * element_size) - : (length > 2 ? 16 : length * element_size); - size_t struct_size = is_f16 && length <= 4 ? length * element_size : 16; - - current_offset = (current_offset + base_alignment - 1) / base_alignment * base_alignment; + // + // For f16: + // - length > 8 : array, N> (align 16) (size 16 * N, N = ceil(length / 8)) + // - length == 7 or 8: vec4 (align 16) (size 16) + // - length == 5 or 6: vec3 (align 16) (size 12) + // - length == 3 or 4: vec2 (align 8) (size 8) + // - length == 1 or 2: u32 (align 4) (size 4) + // + // For other types (i32, u32, f32): + // - length > 4 : array, N> (align 16) (size 16 * N, N = ceil(length / 4)) + // - length == 4 : vec4 (align 16) (size 16) + // - length == 3 : vec3 (align 16) (size 12) + // - length == 2 : vec2 (align 8) (size 8) + // - length == 1 : T (align 4) (size 4) + // + + const bool is_f16 = uniform.data_type == ProgramUniformVariableDataType::Float16; + + size_t variable_alignment = 4; // default alignment for scalar types + size_t variable_size = 4; // default size for scalar types + + if (is_f16) { + if (length > 6) { + variable_alignment = 16; + variable_size = 16 * ((length + 7) / 8); + } else if (length > 4) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 2) { + variable_alignment = 8; + variable_size = 8; + } + } else { + if (length > 3) { + variable_alignment = 16; + variable_size = 16 * ((length + 3) / 4); + } else if (length > 2) { + variable_alignment = 16; + variable_size = 12; + } else if (length > 1) { + variable_alignment = 8; + variable_size = 8; + } + } + current_offset = (current_offset + variable_alignment - 1) / variable_alignment * variable_alignment; uniform_and_offsets.emplace_back(uniform, current_offset); - // For non-float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N * SizeOf(vec4). - // For float16 type, when length > 4, the uniform variable is of type array,N>, where - // N = ceil(data.length / 8) and SizeOf(mat2x4) = 16. The total byte length is N * SizeOf(mat2x4). - size_t element_per_struct = is_f16 ? 8 : 4; - current_offset += - length > 4 ? (length + element_per_struct - 1) / element_per_struct * struct_size : length * element_size; + current_offset += variable_size; } // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 6e09f494f4a8d..bca41b7851c28 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -18,8 +18,11 @@ #include "core/framework/data_transfer_manager.h" #include "core/framework/fallback_cpu_capability.h" #include "core/framework/kernel_registry.h" +#include "core/framework/run_options.h" #include "core/graph/function_utils.h" #include "core/graph/indexed_sub_graph.h" +#include "core/session/onnxruntime_run_options_config_keys.h" +#include "core/common/parse_string.h" #include "core/providers/webgpu/webgpu_context.h" #include "core/providers/webgpu/data_transfer.h" @@ -692,7 +695,6 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -907,7 +909,7 @@ Status WebGpuExecutionProvider::OnSessionInitializationEnd() { return Status::OK(); } -Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { if (context_.ValidationMode() >= ValidationMode::Basic) { context_.PushErrorScope(); } @@ -916,20 +918,32 @@ Status WebGpuExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_ context_.StartProfiling(); } - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { - context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + if (IsGraphCaptureEnabled()) { + auto graph_annotation_str = run_options.config_options.GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + int graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(onnxruntime::TryParseStringWithClassicLocale(*graph_annotation_str, graph_annotation_id), + "Failed to parse the graph annotation id: ", + *graph_annotation_str); + } + + if (graph_annotation_id != -1 && IsGraphCaptureAllowed() && !IsGraphCaptured(graph_annotation_id)) { + context_.CaptureBegin(&captured_commands_, *graph_buffer_mgr_); + } + m_current_graph_annotation_id = graph_annotation_id; } return Status::OK(); } -Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /*run_options*/) { +Status WebGpuExecutionProvider::OnRunEnd(bool /* sync_stream */, const onnxruntime::RunOptions& /* run_options */) { context_.Flush(BufferManager()); - if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { - if (IsGraphCaptureAllowed()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(m_current_graph_annotation_id)) { + if (m_current_graph_annotation_id != -1 && IsGraphCaptureAllowed()) { context_.CaptureEnd(); is_graph_captured_ = true; + ORT_RETURN_IF_ERROR(ReplayGraph(m_current_graph_annotation_id)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -952,12 +966,12 @@ bool WebGpuExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool WebGpuExecutionProvider::IsGraphCaptured(int) const { - return is_graph_captured_; +bool WebGpuExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return is_graph_captured_ && graph_annotation_id != -1; } -Status WebGpuExecutionProvider::ReplayGraph(int) { - ORT_ENFORCE(IsGraphCaptured(0)); +Status WebGpuExecutionProvider::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); context_.Replay(captured_commands_, *graph_buffer_mgr_); return Status::OK(); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 2567be2a1eb18..3bbec164a0190 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -99,6 +99,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + int m_current_graph_annotation_id = 0; webgpu::GpuBufferAllocator* allocator_ = nullptr; // Buffer manager specifically for graph capture mode diff --git a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc index 9b287b7b7df99..bc5c755160cb5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_pix_frame_generator.cc @@ -36,6 +36,7 @@ WebGpuPIXFrameGenerator::WebGpuPIXFrameGenerator(wgpu::Instance instance, wgpu:: format = capabilities.formats[0]; wgpu::SurfaceConfiguration config; + config.presentMode = capabilities.presentModes[0]; config.device = device; config.format = format; config.width = kWidth; diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index 17369e6fbc75d..6b5bed4ecdeae 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -35,19 +35,17 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); - if (trace) { - console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnUploadTensor")); - } } else { auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor"); auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); webnnDownloadTensor(reinterpret_cast(src_data), subarray).await(); - if (trace) { - console.call("timeEnd", emscripten::val("ORT::DataTransfer::webnnDownloadTensor")); - } } } + if (trace) { + console.call("timeEnd", emscripten::val("ORT::DataTransfer::CopyTensor")); + } + return Status::OK(); } diff --git a/onnxruntime/core/providers/xnnpack/math/gemm.cc b/onnxruntime/core/providers/xnnpack/math/gemm.cc index a3ff3b585ae45..9b78e943122de 100644 --- a/onnxruntime/core/providers/xnnpack/math/gemm.cc +++ b/onnxruntime/core/providers/xnnpack/math/gemm.cc @@ -139,7 +139,6 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, // flags - 1 - for no transpose - 0 for transpose uint32_t flags = trans_B_ == CblasTrans ? 0 : XNN_FLAG_TRANSPOSE_WEIGHTS; - auto code_cache = GetCodeCache(); auto weights_cache = GetWeightsCache(); xnn_status status = xnn_status::xnn_status_uninitialized; struct xnn_operator* p = nullptr; @@ -159,7 +158,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (op_compute_type_ == OpComputeType::op_compute_type_fp16) { const MLFloat16* bias_data = nullptr; @@ -175,7 +174,7 @@ Status Gemm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr, bias_data, // const float* bias, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/math/matmul.cc b/onnxruntime/core/providers/xnnpack/math/matmul.cc index 9083b9c22f64a..7870bcff298f2 100644 --- a/onnxruntime/core/providers/xnnpack/math/matmul.cc +++ b/onnxruntime/core/providers/xnnpack/math/matmul.cc @@ -102,10 +102,8 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, } #ifdef XNN_CACHE_ENABLE - xnn_code_cache_t code_cache = GetCodeCache(); xnn_weights_cache_t weight_cache = GetWeightsCache(); #else - xnn_code_cache_t code_cache = nullptr; xnn_weights_cache_t weight_cache = nullptr; #endif @@ -122,7 +120,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { @@ -136,7 +133,6 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, foutput_min, foutput_max, flags, - code_cache, weight_cache, &p); } diff --git a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc index f320274f65db3..963dfa5fa26d7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/average_pool.cc +++ b/onnxruntime/core/providers/xnnpack/nn/average_pool.cc @@ -17,7 +17,6 @@ namespace { Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, const std::optional>& clip_min_max, struct xnn_operator*& p, - const OpQuantParam& quant_param, OpComputeType avgpool_type) { uint32_t input_padding_top = narrow(pool_attrs.pads[0]); uint32_t input_padding_left = narrow(pool_attrs.pads[1]); @@ -48,20 +47,6 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, pooling_height, pooling_width, stride_height, stride_width, foutput_min, foutput_max, flags, &p); - } else if (avgpool_type == OpComputeType::op_compute_type_qu8) { - const float output_scale = quant_param[1].first[0]; - const uint8_t output_zero_point = quant_param[1].second; - const uint8_t output_min = xnn_u8s8_quantize(foutput_min, output_scale, output_zero_point); - const uint8_t output_max = xnn_u8s8_quantize(foutput_max, output_scale, output_zero_point); - status = xnn_create_average_pooling2d_nhwc_qu8(input_padding_top, input_padding_right, - input_padding_bottom, input_padding_left, - pooling_height, pooling_width, - stride_height, stride_width, - quant_param[0].second, - quant_param[0].first[0], - quant_param[1].second, - quant_param[1].first[0], - output_min, output_max, flags, &p); } if (status != xnn_status_success) { @@ -72,9 +57,9 @@ Status CreateXnnpackKernel(const PoolAttributes& pool_attrs, } bool IsQuantAvgPoolSupported(const NodeUnit& node_unit, const GraphViewer& graph) { - TensorQuantType x_input_type = GetTensorQuantType(node_unit, 0, false, graph); - TensorQuantType output_type = GetTensorQuantType(node_unit, 0, true, graph); - return (x_input_type == TensorTypeUint8 && output_type == TensorTypeUint8); + (void)node_unit; + (void)graph; + return false; } bool IsQuantizedAvgPool(QuantizedOpType quant_op_type) { @@ -209,14 +194,10 @@ AveragePool::AveragePool(const OpKernelInfo& info) avgpool_type_ = OpComputeType::op_compute_type_fp32; } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) { avgpool_type_ = OpComputeType::op_compute_type_fp16; - } else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - // the order of input tensor, x,x_scale, x_zp, y_scale, y_zp - quant_param = ParseQuantParamForOp(info, input_dtype, 1); - avgpool_type_ = OpComputeType::op_compute_type_qu8; } struct xnn_operator* p; auto ret = CreateXnnpackKernel(pool_attrs_, clip_min_max_, p, - quant_param, avgpool_type_); + avgpool_type_); ORT_ENFORCE(ret.IsOK(), ret.ErrorMessage()); op0_.reset(p); } @@ -242,23 +223,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { pthreadpool_t threadpool = GetThreadPool(); - // setup allocator/automated dellocate for workspace - size_t workspace_size = 0; - size_t workspace_alignment = 0; - xnn_allocator* allocator = GetStoredAllocator().second; - auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; - - std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_average_pooling2d_nhwc_f32; if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { reshape_fn = xnn_reshape_average_pooling2d_nhwc_f16; - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_average_pooling2d_nhwc_qu8; } auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); @@ -267,17 +237,12 @@ Status AveragePool::Compute(OpKernelContext* context) const { " returned ", status); } - workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); - if (avgpool_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f32(op0_.get(), X.Data(), + Y.MutableData()); } else if (avgpool_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); - } else if (avgpool_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_average_pooling2d_nhwc_qu8(op0_.get(), workspace.get(), - X.Data(), Y.MutableData()); + status = xnn_setup_average_pooling2d_nhwc_f16(op0_.get(), X.Data(), + Y.MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv.cc b/onnxruntime/core/providers/xnnpack/nn/conv.cc index 4e6b308e28ae5..3ef0c1a7cf495 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv.cc @@ -91,7 +91,6 @@ Status Conv::Compute(OpKernelContext* context) const { // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); @@ -108,7 +107,7 @@ Status Conv::Compute(OpKernelContext* context) const { } auto status = reshape_fn(op0_.get(), N, H, W, - &workspace_size, &workspace_alignment, + &workspace_size, /*output_height_out=*/nullptr, /*output_width_out=*/nullptr, threadpool); if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc index 44962c1796631..9742f397315a7 100644 --- a/onnxruntime/core/providers/xnnpack/nn/conv_base.cc +++ b/onnxruntime/core/providers/xnnpack/nn/conv_base.cc @@ -24,7 +24,6 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, const std::optional>& clip_min_max, const Tensor& Weight, const Tensor* Bias, XnnpackOperator& op_uptr, - xnn_code_cache_t code_cache, xnn_weights_cache_t weights_cache, const OpQuantParam& quant_param, OpComputeType conv_type, @@ -79,7 +78,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, C, M, // input channel stride, output channel stride Weight.Data(), B_data, foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_fp16) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -97,7 +96,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, Weight.Data(), B_data, // kernel, bias foutput_min, foutput_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8) { const float output_scale = quant_param[2].first[0]; @@ -121,7 +120,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qs8_per_channel) { auto* B_data = Bias ? Bias->Data() : nullptr; @@ -145,7 +144,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } else if (conv_type == OpComputeType::op_compute_type_qu8) { const auto* B_data = Bias ? Bias->Data() : nullptr; @@ -170,7 +169,7 @@ Status CreateXnnpackKernel(const ConvAttributes& conv_attrs, quant_param[2].second, quant_param[2].first[0], output_min, output_max, flags, - code_cache, weights_cache, + weights_cache, &p); } @@ -521,7 +520,7 @@ ConvBase::ConvBase(const OpKernelInfo& info, bool is_transpose) Status ConvBase::CreateKernel() { auto ret = CreateXnnpackKernel(convbase_attrs_ref_, C_, M_, kernel_shape_, clip_min_max_, packed_w_, B_, op0_, - GetCodeCache(), GetWeightsCache(), + GetWeightsCache(), quant_param_, conv_type_, is_transpose_); return ret; } diff --git a/onnxruntime/core/providers/xnnpack/tensor/resize.cc b/onnxruntime/core/providers/xnnpack/tensor/resize.cc index 0bb1194643743..32d91084d3507 100644 --- a/onnxruntime/core/providers/xnnpack/tensor/resize.cc +++ b/onnxruntime/core/providers/xnnpack/tensor/resize.cc @@ -228,13 +228,13 @@ Resize::Resize(const OpKernelInfo& info) : UpsampleBase(info), XnnpackKernel{inf auto out_h = output_dims_[1]; auto out_w = output_dims_[2]; if (op_type_ == OpComputeType::op_compute_type_fp32) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f32(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp32, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - xstatus = xnn_create_resize_bilinear2d_nhwc_f16(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_fp16, out_h, out_w, flags, &p); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - xstatus = xnn_create_resize_bilinear2d_nhwc_u8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_quint8, out_h, out_w, flags, &p); } else { - xstatus = xnn_create_resize_bilinear2d_nhwc_s8(out_h, out_w, flags, &p); + xstatus = xnn_create_resize_bilinear2d_nhwc(xnn_datatype_qint8, out_h, out_w, flags, &p); } ORT_ENFORCE(xstatus == xnn_status_success, "xnn_create_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " failed. Status:", @@ -257,22 +257,14 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, // setup allocator/automated dellocate for workspace size_t workspace_size = 0; - size_t workspace_alignment = 0; xnn_allocator* allocator = GetStoredAllocator().second; auto deallocator = [allocator](void* ptr) { allocator->aligned_deallocate(allocator->context, ptr); }; std::unique_ptr workspace(nullptr, deallocator); - auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f32; - if (op_type_ == OpComputeType::op_compute_type_fp16) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_f16; - } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_u8; - } else if (op_type_ == OpComputeType::op_compute_type_qs8) { - reshape_fn = xnn_reshape_resize_bilinear2d_nhwc_s8; - } + auto reshape_fn = xnn_reshape_resize_bilinear2d_nhwc; auto status = reshape_fn(op0_.get(), N, H, W, C, C, C, - &workspace_size, &workspace_alignment, threadpool); + &workspace_size, threadpool); if (status != xnn_status_success) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "xnn_reshape_resize_bilinear2d_nhwc_", OpTypeToString(op_type_), " returned ", status); @@ -281,17 +273,17 @@ Status Resize::ComputeInternal(OpKernelContext* ctx, const Tensor* input, workspace.reset(allocator->aligned_allocate(allocator->context, XNN_ALLOCATION_ALIGNMENT, workspace_size)); if (op_type_ == OpComputeType::op_compute_type_fp32) { - status = xnn_setup_resize_bilinear2d_nhwc_f32(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_fp16) { - status = xnn_setup_resize_bilinear2d_nhwc_f16(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else if (op_type_ == OpComputeType::op_compute_type_qu8) { - status = xnn_setup_resize_bilinear2d_nhwc_u8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } else { - status = xnn_setup_resize_bilinear2d_nhwc_s8(op0_.get(), workspace.get(), input->Data(), - output->MutableData()); + status = xnn_setup_resize_bilinear2d_nhwc(op0_.get(), workspace.get(), input->Data(), + output->MutableData()); } if (status != xnn_status_success) { diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h index 31512586be19d..1779f51046c59 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h +++ b/onnxruntime/core/providers/xnnpack/xnnpack_kernel.h @@ -24,8 +24,6 @@ class XnnpackKernel : public OpKernel { } // see comment below about enabling code cache - // xnn_code_cache_t GetCodeCache() { return caches_.auto_code_cache.get();} - xnn_code_cache_t GetCodeCache() { return nullptr; } xnn_weights_cache_t GetWeightsCache() { return caches_.auto_weights_cache.get(); } private: @@ -42,11 +40,6 @@ class XnnpackKernel : public OpKernel { if (enable) { #ifdef XNN_CACHE_ENABLE xnn_status status = xnn_status_success; -#if XNN_PLATFORM_JIT - // status = xnn_init_code_cache(&code_cache_); - // ORT_ENFORCE(status == xnn_status_success, "Failed to initialize XNNPACK code cache");) - // auto_code_cache.reset(&code_cache_); -#endif // status = xnn_init_weights_cache(&weights_cache_); xnn_weights_cache_t weights_cache_provider = nullptr; status = xnn_create_weights_cache(&weights_cache, 0); diff --git a/onnxruntime/core/session/environment.cc b/onnxruntime/core/session/environment.cc index 450a8bad09392..2b553aecbca6c 100644 --- a/onnxruntime/core/session/environment.cc +++ b/onnxruntime/core/session/environment.cc @@ -16,10 +16,10 @@ #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/inference_session.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_library_internal.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/ort_apis.h" #include "core/session/utils.h" diff --git a/onnxruntime/core/session/ep_library_internal.cc b/onnxruntime/core/session/ep_library_internal.cc deleted file mode 100644 index 986ccb1fa17fc..0000000000000 --- a/onnxruntime/core/session/ep_library_internal.cc +++ /dev/null @@ -1,281 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_internal.h" - -#include "core/framework/error_code_helper.h" -#include "core/framework/ortmemoryinfo.h" -#include "core/framework/session_options.h" -#include "core/providers/cpu/cpu_execution_provider.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_logger.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api.h" -#include "core/session/ort_apis.h" - -#if defined(USE_DML) -#include "core/providers/dml/dml_provider_factory_creator.h" -#endif - -#if defined(USE_WEBGPU) -#include "core/providers/webgpu/webgpu_provider_factory_creator.h" -#endif - -namespace onnxruntime { - -class CpuEpFactory : public EpFactoryInternalImpl { - public: - CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { - ORT_API_RETURN_IF_ERROR( - OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "CPU EP factory currently only supports one device at a time."); - } - - CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; - *ep = std::make_unique(epi); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } -}; - -std::unique_ptr EpLibraryInternal::CreateCpuEp() { - auto cpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} - -#if defined(USE_DML) -class DmlEpFactory : public EpFactoryInternalImpl { - public: - DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - std::unique_ptr ep_options; - - // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is - // associated with a specific device. - // How would we know what options should not allow user overrides if set in OrtEpDevice? - int32_t device_id = 0; // If no device_id was found default to 0 - if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { - ep_options = std::make_unique(); - device_id = std::stoi(it->second); - } - - ep_options->Add("device_id", std::to_string(device_id)); - - auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, ep_options.get(), - &ep_devices[num_ep_devices]); - - if (device_memory_infos.size() < device_id + 1) { - device_memory_infos.resize(device_id + 1); - device_allocators.resize(device_id + 1); - } - - if (device_memory_infos[device_id] == nullptr) { - // Create memory info for the device if it doesn't already exist - device_memory_infos[device_id] = std::make_unique( - "DML", OrtAllocatorType::OrtDeviceAllocator, - OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, - narrow(device_id))); - } - - // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. - // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], - // device_memory_infos[device_id].get()); - - if (api_status != nullptr) { - return api_status; - } - - ++num_ep_devices; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "DML EP factory currently only supports one device at a time."); - } - - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, - ep_options); - - *ep = dml_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* /*memory_info*/, - const OrtKeyValuePairs* /*allocator_options*/, - OrtAllocator** allocator) noexcept override { - // TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That - // requires pulling lots of things out of the DML EP to get the D3D12 device and create a - // BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp - //*allocator = device_allocators[memory_info->device.Id()].get(); - *allocator = nullptr; - return nullptr; - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - - std::vector> device_memory_infos; // memory info for each device - std::vector> device_allocators; // allocators for each device -}; - -std::unique_ptr EpLibraryInternal::CreateDmlEp() { - auto dml_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(dml_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -#if defined(USE_WEBGPU) -class WebGpuEpFactory : public EpFactoryInternalImpl { - public: - WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* p_num_ep_devices) noexcept override { - size_t& num_ep_devices = *p_num_ep_devices; - num_ep_devices = 0; - - for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { - const OrtHardwareDevice& device = *devices[i]; - if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { - // TODO: any metadata or options to add? - ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, - &device, nullptr, nullptr, - &ep_devices[num_ep_devices++])); - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - *ep = nullptr; - - if (num_devices != 1) { - return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, - "WebGPU EP factory currently only supports one device at a time."); - } - - auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); - *ep = webgpu_ep_factory->CreateProvider(); - (*ep)->SetLogger(session_logger->ToInternal()); - - return nullptr; - } - - /* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of - an InferenceSession. - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - *allocator = device_allocators[memory_info->device.Id()].get(); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. - *data_transfer = nullptr; - return nullptr; - } - */ -}; - -std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { - auto webgpu_factory_impl = std::make_unique(); - auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); - return std::make_unique(std::move(internal_factory)); -} -#endif - -std::vector> EpLibraryInternal::CreateInternalEps() { - std::vector> internal_eps; - internal_eps.reserve(4); - - // CPU EP - internal_eps.push_back(CreateCpuEp()); - -#if defined(USE_WEBGPU) - internal_eps.push_back(CreateWebGpuEp()); -#endif - -#if defined(USE_DML) - internal_eps.push_back(CreateDmlEp()); -#endif - - return internal_eps; -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.cc b/onnxruntime/core/session/ep_library_provider_bridge.cc deleted file mode 100644 index ae553891beaa7..0000000000000 --- a/onnxruntime/core/session/ep_library_provider_bridge.cc +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/session/ep_library_provider_bridge.h" - -#include "core/common/status.h" -#include "core/framework/error_code_helper.h" -#include "core/framework/session_options.h" -#include "core/providers/cuda/cuda_provider_options.h" -#include "core/providers/shared_library/provider_host_api.h" -#include "core/session/abi_devices.h" -#include "core/session/abi_session_options_impl.h" -#include "core/session/ep_factory_internal.h" - -namespace onnxruntime { -class ProviderBridgeEpFactory : public EpFactoryInternalImpl { - public: - ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) - : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), - ep_factory.GetVendor(&ep_factory), - ep_factory.GetVendorId(&ep_factory)), - ep_factory_{ep_factory}, - provider_library_{provider_library} { - } - - private: - OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - const OrtHardwareDevice* const* devices, - size_t num_devices, - OrtEpDevice** ep_devices, - size_t max_ep_devices, - size_t* num_ep_devices) noexcept override { - ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, - max_ep_devices, num_ep_devices)); - - // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. - for (size_t i = 0; i < *num_ep_devices; ++i) { - auto* ep_device = ep_devices[i]; - if (ep_device) { - ep_device->ep_factory = &ep_factory; - } - } - - return nullptr; - } - - OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, - const OrtKeyValuePairs* const* ep_metadata_pairs, - size_t num_devices, - const OrtSessionOptions* session_options, - const OrtLogger* session_logger, - std::unique_ptr* ep) noexcept override { - // get the provider specific options - auto ep_options = GetOptionsFromSessionOptions(session_options->value); - auto& provider = provider_library_.Get(); - - auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, - ep_options, *session_options, *session_logger, *ep); - - return ToOrtStatus(status); - } - - OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, - const OrtKeyValuePairs* allocator_options, - OrtAllocator** allocator) noexcept override { - return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); - } - - void ReleaseAllocator(OrtAllocator* allocator) noexcept override { - ep_factory_.ReleaseAllocator(&ep_factory_, allocator); - } - - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { - return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); - } - - bool IsStreamAware() const noexcept override { - return ep_factory_.IsStreamAware(&ep_factory_); - } - - OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, - const OrtKeyValuePairs* stream_options, - OrtSyncStreamImpl** stream) noexcept override { - return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); - } - - OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP - ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP -}; - -Status EpLibraryProviderBridge::Load() { - std::lock_guard lock{mutex_}; - - if (!factories_.empty()) { - // already loaded - return Status::OK(); - } - - // if we have been unloaded we can't just be reloaded. - if (!ep_library_plugin_ || !provider_library_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "EpLibraryProviderBridge has been unloaded. " - "Please create a new instance using LoadPluginOrProviderBridge."); - } - - // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. - // use GetSupportedDevices from the library's factory. - // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. - // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can - // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. - for (const auto& factory : ep_library_plugin_->GetFactories()) { - auto factory_impl = std::make_unique(*factory, *provider_library_); - auto internal_factory = std::make_unique(std::move(factory_impl)); - - factory_ptrs_.push_back(internal_factory.get()); - internal_factory_ptrs_.push_back(internal_factory.get()); - factories_.push_back(std::move(internal_factory)); - } - - return Status::OK(); -} - -Status EpLibraryProviderBridge::Unload() { - std::lock_guard lock{mutex_}; - - internal_factory_ptrs_.clear(); - factory_ptrs_.clear(); - factories_.clear(); - - // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. - ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); - ep_library_plugin_ = nullptr; - - provider_library_->Unload(); - provider_library_ = nullptr; - - return Status::OK(); -} - -} // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index f4f76a389030e..e0542768aef2f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2016,7 +2016,7 @@ common::Status InferenceSession::Initialize() { ORT_TRY { LOGS(*session_logger_, INFO) << "Initializing session."; const Env& env = Env::Default(); - env.GetTelemetryProvider().LogSessionCreationStart(); + env.GetTelemetryProvider().LogSessionCreationStart(session_id_); bool have_cpu_ep = false; @@ -2955,7 +2955,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation start to trace logging provider - env.GetTelemetryProvider().LogEvaluationStart(); + env.GetTelemetryProvider().LogEvaluationStart(session_id_); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateInputs(feed_names, feeds)); ORT_RETURN_IF_ERROR_SESSIONID_(ValidateOutputs(output_names, p_fetches)); @@ -3108,7 +3108,7 @@ Status InferenceSession::Run(const RunOptions& run_options, } // log evaluation stop to trace logging provider - env.GetTelemetryProvider().LogEvaluationStop(); + env.GetTelemetryProvider().LogEvaluationStop(session_id_); // send out profiling events (optional) if (session_profiler_.IsEnabled()) { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 27f81b18be0c9..4c7b4d7b29c2f 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -38,8 +38,8 @@ #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" -#include "core/session/ep_api.h" -#include "core/session/ep_library_internal.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_library_internal.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/IOBinding.h" @@ -2993,7 +2993,8 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributes, _In_ const OrtNode* node, API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, _Outptr_ const OrtOpAttr** attribute) { +ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, + _Outptr_result_maybenull_ const OrtOpAttr** attribute) { API_IMPL_BEGIN if (attribute == nullptr) { return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'attribute' argument is NULL"); @@ -3004,14 +3005,30 @@ ORT_API_STATUS_IMPL(OrtApis::Node_GetAttributeByName, _In_ const OrtNode* node, return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Node_GetAttributeByName."); } - *attribute = ep_node->GetAttribute(attribute_name); + bool is_unset_optional_attr = false; + *attribute = ep_node->GetAttribute(attribute_name, is_unset_optional_attr); - if (*attribute) { + if (*attribute || is_unset_optional_attr) { return nullptr; } else { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Attribute does not exist."); + std::ostringstream oss; + oss << "Node attribute does not exist: " << attribute_name; + return OrtApis::CreateStatus(OrtErrorCode::ORT_NOT_FOUND, oss.str().c_str()); + } + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, _Outptr_result_maybenull_ OrtValue** attr_tensor) { + API_IMPL_BEGIN + if (attr_tensor == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attr_tensor argument is null"); + } + if (attribute == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "attribute argument is null"); } + ORT_API_RETURN_IF_STATUS_NOT_OK(node->GetTensorAttributeAsOrtValue(attribute, *attr_tensor)); + return nullptr; API_IMPL_END } @@ -3052,6 +3069,10 @@ ORT_API_STATUS_IMPL(OrtApis::OpAttr_GetType, _In_ const OrtOpAttr* attribute, _O *type = OrtOpAttrType::ORT_OP_ATTR_GRAPH; break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + *type = OrtOpAttrType::ORT_OP_ATTR_TENSOR; + break; + } default: return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "Unexpected attribute type."); } @@ -4034,6 +4055,7 @@ static constexpr OrtApi ort_api_1_to_23 = { &OrtApis::Node_GetNumAttributes, &OrtApis::Node_GetAttributes, &OrtApis::Node_GetAttributeByName, + &OrtApis::Node_GetTensorAttributeAsOrtValue, &OrtApis::OpAttr_GetType, &OrtApis::OpAttr_GetName, &OrtApis::Node_GetNumSubgraphs, diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index d2f22397bf82c..3eee174ff81f4 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -678,7 +678,9 @@ ORT_API_STATUS_IMPL(Node_GetNumAttributes, _In_ const OrtNode* node, _Out_ size_ ORT_API_STATUS_IMPL(Node_GetAttributes, _In_ const OrtNode* node, _Out_writes_(num_attributes) const OrtOpAttr** attributes, _In_ size_t num_attributes); ORT_API_STATUS_IMPL(Node_GetAttributeByName, _In_ const OrtNode* node, _In_ const char* attribute_name, - _Outptr_ const OrtOpAttr** attribute); + _Outptr_result_maybenull_ const OrtOpAttr** attribute); +ORT_API_STATUS_IMPL(Node_GetTensorAttributeAsOrtValue, _In_ const OrtNode* node, _In_ const OrtOpAttr* attribute, + _Outptr_result_maybenull_ OrtValue** attr_tensor); ORT_API_STATUS_IMPL(OpAttr_GetType, _In_ const OrtOpAttr* attribute, _Out_ OrtOpAttrType* type); ORT_API_STATUS_IMPL(OpAttr_GetName, _In_ const OrtOpAttr* attribute, _Outptr_ const char** name); ORT_API_STATUS_IMPL(Node_GetNumSubgraphs, _In_ const OrtNode* node, _Out_ size_t* num_subgraphs); diff --git a/onnxruntime/core/session/ep_api.cc b/onnxruntime/core/session/plugin_ep/ep_api.cc similarity index 99% rename from onnxruntime/core/session/ep_api.cc rename to onnxruntime/core/session/plugin_ep/ep_api.cc index 8fd1fc198374f..cae0b086af66c 100644 --- a/onnxruntime/core/session/ep_api.cc +++ b/onnxruntime/core/session/plugin_ep/ep_api.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_api.h" +#include "core/session/plugin_ep/ep_api.h" #include #include diff --git a/onnxruntime/core/session/ep_api.h b/onnxruntime/core/session/plugin_ep/ep_api.h similarity index 100% rename from onnxruntime/core/session/ep_api.h rename to onnxruntime/core/session/plugin_ep/ep_api.h diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc new file mode 100644 index 0000000000000..7e6d0dd2ae5df --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.cc @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_cpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/graph/constants.h" +#include "core/providers/cpu/cpu_execution_provider.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* CpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_CPU) { + ORT_API_RETURN_IF_ERROR( + OrtExecutionProviderApi::CreateEpDevice(&ep_factory, &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* CpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "CPU EP factory currently only supports one device at a time."); + } + + CPUExecutionProviderInfo epi{session_options->value.enable_cpu_mem_arena}; + *ep = std::make_unique(epi); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h new file mode 100644 index 0000000000000..fba9bac976bb2 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_cpu.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class CpuEpFactory : public EpFactoryInternalImpl { + public: + CpuEpFactory() : EpFactoryInternalImpl(kCpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc new file mode 100644 index 0000000000000..2f12ffa394537 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_dml.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/dml/dml_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* DmlEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + auto ep_options = std::make_unique(); + + // TODO: Should we ignore a user provided 'device_id' when they select an OrtEpDevice as that is + // associated with a specific device. + // How would we know what options should not allow user overrides if set in OrtEpDevice? + int32_t device_id = 0; // If no device_id was found default to 0 + if (auto it = device.metadata.Entries().find("DxgiAdapterNumber"); it != device.metadata.Entries().end()) { + device_id = std::stoi(it->second); + } + + ep_options->Add("device_id", std::to_string(device_id)); + + auto* api_status = OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, ep_options.get(), + &ep_devices[num_ep_devices]); + + if (device_memory_infos.size() < device_id + 1) { + device_memory_infos.resize(device_id + 1); + device_allocators.resize(device_id + 1); + } + + if (device_memory_infos[device_id] == nullptr) { + // Create memory info for the device if it doesn't already exist + device_memory_infos[device_id] = std::make_unique( + "DML", OrtAllocatorType::OrtDeviceAllocator, + OrtDevice(OrtDevice::DML, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::MICROSOFT, + narrow(device_id))); + } + + // This is what we need to add once CreateAllocator is implemented to create a shared allocator for the device. + // OrtExecutionProviderApi::EpDevice_AddAllocatorInfo(ep_devices[num_ep_devices], + // device_memory_infos[device_id].get()); + + if (api_status != nullptr) { + return api_status; + } + + ++num_ep_devices; + } + } + + return nullptr; +} + +OrtStatus* DmlEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "DML EP factory currently only supports one device at a time."); + } + + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto dml_ep_factory = DMLProviderFactoryCreator::CreateFromProviderOptions(session_options->value.config_options, + ep_options); + + *ep = dml_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* +// TODO: This needs to create an allocator for the specific device so it's available as a shared allocator. That +// requires pulling lots of things out of the DML EP to get the D3D12 device and create a +// BucketizedBufferAllocator. See providers\dml\DmlExecutionProvider\src\ExecutionProvider.cpp +OrtStatus* DmlEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept { +} + +// TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. +OrtStatus* DmlEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { +} +*/ +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_dml.h b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h new file mode 100644 index 0000000000000..1cdd172901942 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_dml.h @@ -0,0 +1,40 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_DML) + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class DmlEpFactory : public EpFactoryInternalImpl { + public: + DmlEpFactory() : EpFactoryInternalImpl(kDmlExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + std::vector> device_memory_infos; // memory info for each device + std::vector> device_allocators; // allocators for each device +}; + +} // namespace onnxruntime + +#endif // USE_DML diff --git a/onnxruntime/core/session/ep_factory_internal.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc similarity index 58% rename from onnxruntime/core/session/ep_factory_internal.cc rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.cc index 9804aa6a5c42d..3610b0f797a46 100644 --- a/onnxruntime/core/session/ep_factory_internal.cc +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.cc @@ -1,18 +1,16 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" -#include "core/session/ep_api_utils.h" +#include "core/session/plugin_ep/forward_to_factory_impl.h" #include "core/session/ort_apis.h" -#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { - -using Forward = ForwardToFactory; +using Forward = ForwardToFactoryImpl; EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl) : impl_{std::move(impl)} { @@ -32,38 +30,6 @@ EpFactoryInternal::EpFactoryInternal(std::unique_ptr impl OrtEpFactory::CreateSyncStreamForDevice = Forward::CreateSyncStreamForDevice; } -const char* EpFactoryInternal::GetVersion() const noexcept { - return ORT_VERSION; -} - -OrtStatus* EpFactoryInternal::CreateEp(const OrtHardwareDevice* const* /*devices*/, - const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, - size_t /*num_devices*/, - const OrtSessionOptions* /*api_session_options*/, - const OrtLogger* /*api_logger*/, - OrtEp** /*ep*/) { - ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); -} - -// Prior to addition to SessionOptions the EP options do not have a prefix. -// They are prefixed with 'ep..' when added to SessionOptions. -// -// Use this function to get the options without the prefix from SessionOptions. -// Required by the option parsing for multiple existing EPs. -ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { - const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); - ProviderOptions ep_options; - - for (const auto& [key, value] : session_options.config_options.configurations) { - if (key.find(option_prefix) == 0) { - // remove the prefix and add - ep_options[key.substr(option_prefix.length())] = value; - } - } - - return ep_options; -} - InternalExecutionProviderFactory::InternalExecutionProviderFactory(EpFactoryInternal& ep_factory, gsl::span ep_devices) : ep_factory_{ep_factory} { diff --git a/onnxruntime/core/session/ep_factory_internal.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h similarity index 50% rename from onnxruntime/core/session/ep_factory_internal.h rename to onnxruntime/core/session/plugin_ep/ep_factory_internal.h index ae450efa394e8..0e34fef0ff74c 100644 --- a/onnxruntime/core/session/ep_factory_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal.h @@ -7,85 +7,16 @@ #include #include "core/common/common.h" -#include "core/framework/execution_provider.h" #include "core/providers/providers.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/ort_apis.h" +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "onnxruntime_config.h" // for ORT_VERSION namespace onnxruntime { -class EpFactoryInternal; -class EpLibraryInternal; struct SessionOptions; - -// class with virtual methods that are implemented for each internal EP -class EpFactoryInternalImpl { - public: - EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) - : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { - } - - const char* GetName() const noexcept { return ep_name_.c_str(); } - const char* GetVendor() const noexcept { return vendor_.c_str(); } - uint32_t GetVendorId() const noexcept { return vendor_id_; } - const char* GetVersion() const noexcept; - - virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices) noexcept = 0; - - virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, - _Out_ std::unique_ptr* ep) = 0; - - virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, - _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, - _Outptr_ OrtAllocator** allocator) noexcept { - // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned - // so this should never be called - *allocator = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); - } - - virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { - // we don't create any allocators so we don't need to release any - } - - virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { - *data_transfer = nullptr; - return nullptr; // Default implementation does nothing - } - - virtual bool IsStreamAware() const { - return false; - } - - virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, - _In_opt_ const OrtKeyValuePairs* /*stream_options*/, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { - *stream = nullptr; - return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, - "CreateSyncStreamForDevice is not implemented for this EP factory."); - } - - // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* ep); - - virtual ~EpFactoryInternalImpl() = default; - - protected: - ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; - - private: - const std::string ep_name_; // EP name library was registered with - const std::string vendor_; // EP vendor name - const uint32_t vendor_id_; // EP vendor ID -}; +class EpFactoryInternalImpl; // this class can't have any virtual methods as they break using it as an OrtEpFactory* in OrtEpDevice. class EpFactoryInternal : public OrtEpFactory { @@ -95,7 +26,7 @@ class EpFactoryInternal : public OrtEpFactory { const char* GetName() const noexcept { return impl_->GetName(); } const char* GetVendor() const noexcept { return impl_->GetVendor(); } uint32_t GetVendorId() const noexcept { return impl_->GetVendorId(); } - const char* GetVersion() const noexcept; + const char* GetVersion() const noexcept { return ORT_VERSION; } OrtStatus* GetSupportedDevices(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, _In_ size_t num_devices, @@ -106,11 +37,14 @@ class EpFactoryInternal : public OrtEpFactory { } // we don't implement this. CreateIExecutionProvider should be used. - OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Out_ OrtEp** ep); + OrtStatus* CreateEp(_In_reads_(num_devices) const OrtHardwareDevice* const* /*devices*/, + _In_reads_(num_devices) const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + _In_ size_t /*num_devices*/, + _In_ const OrtSessionOptions* /*session_options*/, + _In_ const OrtLogger* /*logger*/, + _Out_ OrtEp** /*ep*/) { + ORT_THROW("Internal error. CreateIExecutionProvider should be used for EpFactoryInternal."); + } // same input args as CreateEp in case we need something from device or ep_metadata_pairs in the future. OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, @@ -132,24 +66,23 @@ class EpFactoryInternal : public OrtEpFactory { return impl_->ReleaseAllocator(allocator); } - OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) { + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { return impl_->CreateDataTransfer(data_transfer); } - bool IsStreamAware() const { + bool IsStreamAware() const noexcept { return impl_->IsStreamAware(); } OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* memory_device, _In_opt_ const OrtKeyValuePairs* stream_options, - _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) { + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { return impl_->CreateSyncStreamForDevice(memory_device, stream_options, stream); } // Function ORT calls to release an EP instance. - void ReleaseEp(OrtEp* /*ep*/) { + void ReleaseEp(OrtEp* /*ep*/) noexcept { // we never create an OrtEp so we should never be trying to release one - ORT_THROW("Internal error. No ReleaseEp call is required for EpFactoryInternal."); } private: diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc new file mode 100644 index 0000000000000..e61804d842859 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.cc @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" + +namespace onnxruntime { + +// Prior to addition to SessionOptions the EP options do not have a prefix. +// They are prefixed with 'ep..' when added to SessionOptions. +// +// Use this function to get the options without the prefix from SessionOptions. +// Required by the option parsing for multiple existing EPs. +ProviderOptions EpFactoryInternalImpl::GetOptionsFromSessionOptions(const SessionOptions& session_options) const { + const std::string option_prefix = OrtSessionOptions::GetProviderOptionPrefix(GetName()); + ProviderOptions ep_options; + + for (const auto& [key, value] : session_options.config_options.configurations) { + if (key.find(option_prefix) == 0) { + // remove the prefix and add + ep_options[key.substr(option_prefix.length())] = value; + } + } + + return ep_options; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h new file mode 100644 index 0000000000000..bd0b76b21511f --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/execution_provider.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { +class EpFactoryInternal; +struct SessionOptions; + +// class with virtual methods that are implemented for each internal EP +class EpFactoryInternalImpl { + public: + EpFactoryInternalImpl(const std::string& ep_name, const std::string& vendor, uint32_t vendor_id) + : ep_name_(ep_name), vendor_(vendor), vendor_id_(vendor_id) { + } + + const char* GetName() const noexcept { return ep_name_.c_str(); } + const char* GetVendor() const noexcept { return vendor_.c_str(); } + uint32_t GetVendorId() const noexcept { return vendor_id_; } + const char* GetVersion() const noexcept; + + virtual OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices) noexcept = 0; + + virtual OrtStatus* CreateIExecutionProvider(_In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, + _Out_ std::unique_ptr* ep) = 0; + + virtual OrtStatus* CreateAllocator(_In_ const OrtMemoryInfo* /*memory_info*/, + _In_opt_ const OrtKeyValuePairs* /*allocator_options*/, + _Outptr_ OrtAllocator** allocator) noexcept { + // default implementation does not add OrtMemoryInfo to OrtEpDevice instances returned + // so this should never be called + *allocator = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "CreateAllocator is not implemented for this EP factory."); + } + + virtual void ReleaseAllocator(_In_ OrtAllocator* /*allocator*/) noexcept { + // we don't create any allocators so we don't need to release any + } + + virtual OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept { + *data_transfer = nullptr; + return nullptr; // Default implementation does nothing + } + + virtual bool IsStreamAware() const noexcept { + return false; + } + + virtual OrtStatus* CreateSyncStreamForDevice(_In_ const OrtMemoryDevice* /*memory_device*/, + _In_opt_ const OrtKeyValuePairs* /*stream_options*/, + _Outptr_result_maybenull_ OrtSyncStreamImpl** stream) noexcept { + *stream = nullptr; + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, + "CreateSyncStreamForDevice is not implemented for this EP factory."); + } + + // Function ORT calls to release an EP instance. + void ReleaseEp(OrtEp* ep); + + virtual ~EpFactoryInternalImpl() = default; + + protected: + ProviderOptions GetOptionsFromSessionOptions(const SessionOptions& session_options) const; + + private: + const std::string ep_name_; // EP name library was registered with + const std::string vendor_; // EP vendor name + const uint32_t vendor_id_; // EP vendor ID +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc new file mode 100644 index 0000000000000..d6e51a44c1c69 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.cc @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +#include "core/providers/shared_library/provider_host_api.h" + +namespace onnxruntime { +OrtStatus* ProviderBridgeEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept { + ORT_API_RETURN_IF_ERROR(ep_factory_.GetSupportedDevices(&ep_factory_, devices, num_devices, ep_devices, + max_ep_devices, num_ep_devices)); + + // add the EpFactoryInternal layer back in so that we can redirect to CreateIExecutionProvider. + for (size_t i = 0; i < *num_ep_devices; ++i) { + auto* ep_device = ep_devices[i]; + if (ep_device) { + ep_device->ep_factory = &ep_factory; + } + } + + return nullptr; +} + +OrtStatus* ProviderBridgeEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + // get the provider specific options + auto ep_options = GetOptionsFromSessionOptions(session_options->value); + auto& provider = provider_library_.Get(); + + auto status = provider.CreateIExecutionProvider(devices, ep_metadata_pairs, num_devices, + ep_options, *session_options, *session_logger, *ep); + + return ToOrtStatus(status); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h new file mode 100644 index 0000000000000..437af62dc2c0c --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/error_code_helper.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/provider_bridge_library.h" + +namespace onnxruntime { +class ProviderBridgeEpFactory : public EpFactoryInternalImpl { + public: + ProviderBridgeEpFactory(OrtEpFactory& ep_factory, ProviderLibrary& provider_library) + : EpFactoryInternalImpl(ep_factory.GetName(&ep_factory), + ep_factory.GetVendor(&ep_factory), + ep_factory.GetVendorId(&ep_factory)), + ep_factory_{ep_factory}, + provider_library_{provider_library} { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; + + OrtStatus* CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + return ep_factory_.CreateAllocator(&ep_factory_, memory_info, allocator_options, allocator); + } + + void ReleaseAllocator(OrtAllocator* allocator) noexcept override { + ep_factory_.ReleaseAllocator(&ep_factory_, allocator); + } + + OrtStatus* CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) noexcept override { + return ep_factory_.CreateDataTransfer(&ep_factory_, data_transfer); + } + + bool IsStreamAware() const noexcept override { + return ep_factory_.IsStreamAware(&ep_factory_); + } + + OrtStatus* CreateSyncStreamForDevice(const OrtMemoryDevice* device, + const OrtKeyValuePairs* stream_options, + OrtSyncStreamImpl** stream) noexcept override { + return ep_factory_.CreateSyncStreamForDevice(&ep_factory_, device, stream_options, stream); + } + + OrtEpFactory& ep_factory_; // OrtEpFactory from the provider bridge EP + ProviderLibrary& provider_library_; // ProviderLibrary from the provider bridge EP +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc new file mode 100644 index 0000000000000..0f955e0bab248 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.cc @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +#include "core/framework/error_code_helper.h" +#include "core/providers/webgpu/webgpu_provider_factory_creator.h" +#include "core/session/abi_devices.h" +#include "core/session/abi_logger.h" +#include "core/session/abi_session_options_impl.h" +#include "core/session/plugin_ep/ep_api.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/ort_apis.h" + +namespace onnxruntime { + +OrtStatus* WebGpuEpFactory::GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept { + size_t& num_ep_devices = *p_num_ep_devices; + num_ep_devices = 0; + + for (size_t i = 0; i < num_devices && num_ep_devices < max_ep_devices; ++i) { + const OrtHardwareDevice& device = *devices[i]; + if (device.type == OrtHardwareDeviceType::OrtHardwareDeviceType_GPU) { + // TODO: any metadata or options to add? + ORT_API_RETURN_IF_ERROR(OrtExecutionProviderApi::CreateEpDevice(&ep_factory, + &device, nullptr, nullptr, + &ep_devices[num_ep_devices++])); + } + } + + return nullptr; +} + +OrtStatus* WebGpuEpFactory::CreateIExecutionProvider(const OrtHardwareDevice* const* /*devices*/, + const OrtKeyValuePairs* const* /*ep_metadata_pairs*/, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept { + *ep = nullptr; + + if (num_devices != 1) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "WebGPU EP factory currently only supports one device at a time."); + } + + auto webgpu_ep_factory = WebGpuProviderFactoryCreator::Create(session_options->value.config_options); + *ep = webgpu_ep_factory->CreateProvider(); + (*ep)->SetLogger(session_logger->ToInternal()); + + return nullptr; +} + +/* TODO: Implement CreateAllocator and CreateDataTransfer to support shared allocators and data transfer outside of + an InferenceSession. +OrtStatus* WebGpuEpFactory::CreateAllocator(const OrtMemoryInfo* memory_info, + const OrtKeyValuePairs* allocator_options, + OrtAllocator** allocator) noexcept override { + *allocator = device_allocators[memory_info->device.Id()].get(); +} + +OrtStatus* WebGpuEpFactory::CreateDataTransfer(_Outptr_result_maybenull_ OrtDataTransferImpl** data_transfer) override { + // TODO: Wrap the IDataTransfer implementation so we can copy to device using OrtApi CopyTensors. + *data_transfer = nullptr; + return nullptr; +} +*/ +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h new file mode 100644 index 0000000000000..06ecfa744bbda --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_factory_webgpu.h @@ -0,0 +1,35 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if defined(USE_WEBGPU) +#include "core/session/plugin_ep/ep_factory_internal_impl.h" + +#include "core/graph/constants.h" + +namespace onnxruntime { + +class WebGpuEpFactory : public EpFactoryInternalImpl { + public: + WebGpuEpFactory() : EpFactoryInternalImpl(kWebGpuExecutionProvider, "Microsoft", OrtDevice::VendorIds::MICROSOFT) { + } + + private: + OrtStatus* GetSupportedDevices(EpFactoryInternal& ep_factory, + const OrtHardwareDevice* const* devices, + size_t num_devices, + OrtEpDevice** ep_devices, + size_t max_ep_devices, + size_t* p_num_ep_devices) noexcept override; + + OrtStatus* CreateIExecutionProvider(const OrtHardwareDevice* const* devices, + const OrtKeyValuePairs* const* ep_metadata_pairs, + size_t num_devices, + const OrtSessionOptions* session_options, + const OrtLogger* session_logger, + std::unique_ptr* ep) noexcept override; +}; +} // namespace onnxruntime + +#endif // USE_WEBGPU diff --git a/onnxruntime/core/session/ep_library.h b/onnxruntime/core/session/plugin_ep/ep_library.h similarity index 100% rename from onnxruntime/core/session/ep_library.h rename to onnxruntime/core/session/plugin_ep/ep_library.h diff --git a/onnxruntime/core/session/plugin_ep/ep_library_internal.cc b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc new file mode 100644 index 0000000000000..d4015e0bbd366 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.cc @@ -0,0 +1,52 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_internal.h" +#include "core/session/plugin_ep/ep_factory_cpu.h" +#include "core/session/plugin_ep/ep_factory_dml.h" +#include "core/session/plugin_ep/ep_factory_webgpu.h" + +namespace onnxruntime { + +std::unique_ptr EpLibraryInternal::CreateCpuEp() { + auto cpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(cpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} + +#if defined(USE_DML) + +std::unique_ptr EpLibraryInternal::CreateDmlEp() { + auto dml_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(dml_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +#if defined(USE_WEBGPU) +std::unique_ptr EpLibraryInternal::CreateWebGpuEp() { + auto webgpu_factory_impl = std::make_unique(); + auto internal_factory = std::make_unique(std::move(webgpu_factory_impl)); + return std::make_unique(std::move(internal_factory)); +} +#endif + +std::vector> EpLibraryInternal::CreateInternalEps() { + std::vector> internal_eps; + internal_eps.reserve(4); + + // CPU EP + internal_eps.push_back(CreateCpuEp()); + +#if defined(USE_WEBGPU) + internal_eps.push_back(CreateWebGpuEp()); +#endif + +#if defined(USE_DML) + internal_eps.push_back(CreateDmlEp()); +#endif + + return internal_eps; +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_internal.h b/onnxruntime/core/session/plugin_ep/ep_library_internal.h similarity index 94% rename from onnxruntime/core/session/ep_library_internal.h rename to onnxruntime/core/session/plugin_ep/ep_library_internal.h index ab529edc2507f..1587f01360e26 100644 --- a/onnxruntime/core/session/ep_library_internal.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_internal.h @@ -4,8 +4,8 @@ #pragma once #include "core/common/common.h" -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/provider_bridge_library.h" diff --git a/onnxruntime/core/session/ep_library_plugin.cc b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc similarity index 98% rename from onnxruntime/core/session/ep_library_plugin.cc rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.cc index 32ddd8a765b4c..ebfa364f4f1df 100644 --- a/onnxruntime/core/session/ep_library_plugin.cc +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_plugin.h" #include "core/common/logging/logging.h" #include "core/framework/error_code_helper.h" diff --git a/onnxruntime/core/session/ep_library_plugin.h b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h similarity index 96% rename from onnxruntime/core/session/ep_library_plugin.h rename to onnxruntime/core/session/plugin_ep/ep_library_plugin.h index e2b02ccc654da..e044e91b61e37 100644 --- a/onnxruntime/core/session/ep_library_plugin.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_plugin.h @@ -6,7 +6,7 @@ #include #include -#include "core/session/ep_library.h" +#include "core/session/plugin_ep/ep_library.h" namespace onnxruntime { /// diff --git a/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc new file mode 100644 index 0000000000000..06cf54aea4071 --- /dev/null +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/session/plugin_ep/ep_library_provider_bridge.h" + +#include "core/session/plugin_ep/ep_factory_provider_bridge.h" + +namespace onnxruntime { +Status EpLibraryProviderBridge::Load() { + std::lock_guard lock{mutex_}; + + if (!factories_.empty()) { + // already loaded + return Status::OK(); + } + + // if we have been unloaded we can't just be reloaded. + if (!ep_library_plugin_ || !provider_library_) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "EpLibraryProviderBridge has been unloaded. " + "Please create a new instance using LoadPluginOrProviderBridge."); + } + + // wrap the EpLibraryPlugin factories that were created via calling CreateEpFactories in the library. + // use GetSupportedDevices from the library's factory. + // to do this we need to capture `factory` and plug it in to is_supported_fn and create_fn. + // we also need to update any returned OrtEpDevice instances to swap the wrapper EpFactoryInternal in so that we can + // call Provider::CreateIExecutionProvider in EpFactoryInternal::CreateIExecutionProvider. + for (const auto& factory : ep_library_plugin_->GetFactories()) { + auto factory_impl = std::make_unique(*factory, *provider_library_); + auto internal_factory = std::make_unique(std::move(factory_impl)); + + factory_ptrs_.push_back(internal_factory.get()); + internal_factory_ptrs_.push_back(internal_factory.get()); + factories_.push_back(std::move(internal_factory)); + } + + return Status::OK(); +} + +Status EpLibraryProviderBridge::Unload() { + std::lock_guard lock{mutex_}; + + internal_factory_ptrs_.clear(); + factory_ptrs_.clear(); + factories_.clear(); + + // we loaded ep_library_plugin_ after provider_library_ in LoadPluginOrProviderBridge so do the reverse order here. + ORT_RETURN_IF_ERROR(ep_library_plugin_->Unload()); + ep_library_plugin_ = nullptr; + + provider_library_->Unload(); + provider_library_ = nullptr; + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/session/ep_library_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h similarity index 95% rename from onnxruntime/core/session/ep_library_provider_bridge.h rename to onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h index 0717ccd957de7..c7e8ebefc3785 100644 --- a/onnxruntime/core/session/ep_library_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_library_provider_bridge.h @@ -5,8 +5,8 @@ #include #include -#include "core/session/ep_library.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_library.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_bridge_library.h" namespace onnxruntime { diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc similarity index 99% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.cc rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index c7d7ea2e8a4ec..2aac1e1c21cc7 100644 --- a/onnxruntime/core/session/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include #include diff --git a/onnxruntime/core/session/ep_plugin_provider_interfaces.h b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h similarity index 100% rename from onnxruntime/core/session/ep_plugin_provider_interfaces.h rename to onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.h diff --git a/onnxruntime/core/session/ep_api_utils.h b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h similarity index 99% rename from onnxruntime/core/session/ep_api_utils.h rename to onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h index 77528565eced7..67b22779395ec 100644 --- a/onnxruntime/core/session/ep_api_utils.h +++ b/onnxruntime/core/session/plugin_ep/forward_to_factory_impl.h @@ -7,7 +7,7 @@ namespace onnxruntime { // helper to forward a call from the C API to an instance of the factory implementation. // used by EpFactoryInternal and EpFactoryProviderBridge. template -struct ForwardToFactory { +struct ForwardToFactoryImpl { static const char* ORT_API_CALL GetFactoryName(const OrtEpFactory* this_ptr) noexcept { return static_cast(this_ptr)->GetName(); } diff --git a/onnxruntime/core/session/provider_policy_context.cc b/onnxruntime/core/session/provider_policy_context.cc index 211bf8b2d15a4..6bcbda0f13b92 100644 --- a/onnxruntime/core/session/provider_policy_context.cc +++ b/onnxruntime/core/session/provider_policy_context.cc @@ -11,8 +11,8 @@ #include "core/framework/error_code_helper.h" #include "core/session/abi_devices.h" #include "core/session/abi_logger.h" -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "core/session/inference_session.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_c_api.h" diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 69039beb49363..f90ace95d6e58 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -19,10 +19,10 @@ #include "core/session/ort_env.h" #if !defined(ORT_MINIMAL_BUILD) -#include "core/session/ep_factory_internal.h" -#include "core/session/ep_plugin_provider_interfaces.h" -#include "core/session/ep_library_plugin.h" -#include "core/session/ep_library_provider_bridge.h" +#include "core/session/plugin_ep/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_library_plugin.h" +#include "core/session/plugin_ep/ep_library_provider_bridge.h" #include "core/session/model_compilation_options.h" #include "core/session/provider_policy_context.h" #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.cc b/onnxruntime/python/onnxruntime_pybind_exceptions.cc index 8f3b97c8c7786..6b3062205b52e 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.cc +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.cc @@ -37,6 +37,7 @@ void RegisterExceptions(pybind11::module& m) { pybind11::register_exception(m, "EPFail"); pybind11::register_exception(m, "ModelLoadCanceled"); pybind11::register_exception(m, "ModelRequiresCompilation"); + pybind11::register_exception(m, "NotFound"); } void OrtPybindThrowIfError(onnxruntime::common::Status status) { @@ -67,6 +68,8 @@ void OrtPybindThrowIfError(onnxruntime::common::Status status) { throw ModelLoadCanceled(std::move(msg)); case onnxruntime::common::StatusCode::MODEL_REQUIRES_COMPILATION: throw ModelRequiresCompilation(std::move(msg)); + case onnxruntime::common::StatusCode::NOT_FOUND: + throw NotFound(std::move(msg)); default: throw std::runtime_error(std::move(msg)); } diff --git a/onnxruntime/python/onnxruntime_pybind_exceptions.h b/onnxruntime/python/onnxruntime_pybind_exceptions.h index 86bc4a5da8d46..7680c06c59d79 100644 --- a/onnxruntime/python/onnxruntime_pybind_exceptions.h +++ b/onnxruntime/python/onnxruntime_pybind_exceptions.h @@ -50,6 +50,9 @@ struct ModelLoadCanceled : std::runtime_error { struct ModelRequiresCompilation : std::runtime_error { explicit ModelRequiresCompilation(const std::string& what) : std::runtime_error(what) {} }; +struct NotFound : std::runtime_error { + explicit NotFound(const std::string& what) : std::runtime_error(what) {} +}; void RegisterExceptions(pybind11::module& m); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ec4d8c6330c8d..acf0681cf8752 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -46,7 +46,7 @@ #if !defined(ORT_MINIMAL_BUILD) #include "core/session/abi_devices.h" -#include "core/session/ep_factory_internal.h" +#include "core/session/plugin_ep/ep_factory_internal.h" #include "core/session/provider_policy_context.h" #include "core/session/utils.h" #endif diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index c9a7116bf8052..2918e4baf86a4 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -82,21 +82,48 @@ static void CalculateDynamicQuantizeMatMul(const int64_t M, const int64_t N, con } } +struct TestDynamicQuantizeMatMulOptions { + bool is_matrix_b_constant = true; + + bool per_column = false; + + bool is_scale_constant = false; + + bool has_zp = true; + bool is_zp_constant = false; + bool is_zp_zero = false; + + bool has_bias = false; + bool is_bias_constant = false; + + bool empty_input = false; +}; + template -void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, - bool per_column = false, - bool has_zp = true, - bool has_bias = false, - bool empty_input = false) { +void TestDynamicQuantizeMatMul(const TestDynamicQuantizeMatMulOptions& opts) { + static_assert(std::is_same_v || std::is_same_v); + + SCOPED_TRACE(MakeString( + "b data type:", (std::is_same_v ? "uint8" : "int8"), + ", is_matrix_b_constant:", opts.is_matrix_b_constant, + ", per_column:", opts.per_column, + ", is_scale_constant:", opts.is_scale_constant, + ", has_zp:", opts.has_zp, + ", is_zp_constant:", opts.is_zp_constant, + ", is_zp_zero:", opts.is_zp_zero, + ", has_bias:", opts.has_bias, + ", is_bias_constant:", opts.is_bias_constant, + ", empty_input:", opts.empty_input)); + // create rand inputs RandomValueGenerator random{1668426375}; - int64_t M = empty_input ? 1 : 4; + int64_t M = opts.empty_input ? 1 : 4; int64_t N = 128; int64_t K = 128; - std::vector A_dims{empty_input ? 0 : M, K}; + std::vector A_dims{opts.empty_input ? 0 : M, K}; std::vector B_dims{K, N}; - std::vector Y_dims{empty_input ? 0 : M, K}; + std::vector Y_dims{opts.empty_input ? 0 : M, K}; std::vector A_data = random.Uniform(A_dims, -1.0f, 1.0f); std::vector B_data; std::vector tmp_B_data = random.Uniform(B_dims, @@ -106,101 +133,120 @@ void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, return static_cast(v); }); - int64_t b_scale_zp_size = per_column ? B_dims.back() : 1; + int64_t b_scale_zp_size = opts.per_column ? B_dims.back() : 1; std::vector B_scale = random.Uniform(AsSpan({b_scale_zp_size}), -0.1f, 0.1f); std::vector B_zero_point(b_scale_zp_size); - std::for_each(B_zero_point.begin(), - B_zero_point.end(), - [&random](T& zp) { - zp = static_cast(random.Uniform(std::array{1}, - std::numeric_limits::min(), - std::numeric_limits::max())[0]); - }); + if (!opts.is_zp_zero) { + std::for_each(B_zero_point.begin(), + B_zero_point.end(), + [&random](T& zp) { + zp = static_cast(random.Uniform(std::array{1}, + std::numeric_limits::min(), + std::numeric_limits::max())[0]); + }); + } std::vector Bias = random.Uniform(AsSpan({B_dims.back()}), -0.1f, 0.1f); OpTester test("DynamicQuantizeMatMul", 1, onnxruntime::kMSDomain); test.AddInput("A", A_dims, A_data); - test.AddInput("B", B_dims, B_data, is_matrix_b_constant); - test.AddInput("b_scale", {b_scale_zp_size}, B_scale); + test.AddInput("B", B_dims, B_data, opts.is_matrix_b_constant); + test.AddInput("b_scale", {b_scale_zp_size}, B_scale, opts.is_scale_constant); - if (has_zp) { - test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point); + if (opts.has_zp) { + test.AddInput("b_zero_point", {b_scale_zp_size}, B_zero_point, opts.is_zp_constant); } else { test.AddOptionalInputEdge(); } - if (has_bias) { - test.AddInput("bias", {B_dims.back()}, Bias); + if (opts.has_bias) { + test.AddInput("bias", {B_dims.back()}, Bias, opts.is_bias_constant); } else { test.AddOptionalInputEdge(); } std::vector Y_data(M * N); CalculateDynamicQuantizeMatMul(M, N, K, A_data, B_data, B_scale, B_zero_point, Bias, Y_data, - per_column, has_zp, has_bias); + opts.per_column, opts.has_zp, opts.has_bias); test.AddOutput("Y", Y_dims, Y_data); test.SetOutputRelErr("Y", 0.02f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kOpenVINOExecutionProvider}); } -template -void RunDynamicQuantizeMatMulTest() { - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - false, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(false, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); - - TestDynamicQuantizeMatMul(true, /*is_matrix_b_constant*/ - true, /*per_column*/ - HasZeroPoint, /*has_zp*/ - HasBias /*has_bias*/ - ); +template +void TestDynamicQuantizeMatMul(bool is_matrix_b_constant, + bool per_column = false, + bool has_zp = true, + bool has_bias = false, + bool empty_input = false) { + TestDynamicQuantizeMatMulOptions opts{}; + opts.is_matrix_b_constant = is_matrix_b_constant; + opts.per_column = per_column; + opts.has_zp = has_zp; + opts.has_bias = has_bias; + opts.empty_input = empty_input; + + TestDynamicQuantizeMatMul(opts); } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +template +void RunDynamicQuantizeMatMulTest() { + for (bool is_matrix_b_constant : {false, true}) { + for (bool per_column : {false, true}) { + for (bool has_zp : {false, true}) { + for (bool has_bias : {false, true}) { + TestDynamicQuantizeMatMul(is_matrix_b_constant, + per_column, + has_zp, + has_bias); + } + } + } + } } -TEST(DynamicQuantizeMatMul, HasZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, Int8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); +TEST(DynamicQuantizeMatMul, UInt8) { + RunDynamicQuantizeMatMulTest(); } -TEST(DynamicQuantizeMatMul, NoZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} +TEST(DynamicQuantizeMatMul, WithConstantBInputs) { + TestDynamicQuantizeMatMulOptions base_opts{}; + base_opts.is_matrix_b_constant = true; + base_opts.is_scale_constant = true; + base_opts.is_zp_constant = true; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // no zp + auto opts = base_opts; + opts.has_zp = false; -TEST(DynamicQuantizeMatMul, NoZeroPoint_NoBias_test_U8) { - RunDynamicQuantizeMatMulTest(); -} + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_S8) { - RunDynamicQuantizeMatMulTest(); -} + { + // zp that is zero (symmetric quantization) + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = true; -TEST(DynamicQuantizeMatMul, HasZeroPoint_HasBias_test_U8) { - RunDynamicQuantizeMatMulTest(); + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } + + { + // zp that is non-zero + auto opts = base_opts; + opts.has_zp = true; + opts.is_zp_zero = false; + + TestDynamicQuantizeMatMul(opts); + TestDynamicQuantizeMatMul(opts); + } } TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 45314f8f39eea..188edad572182 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -87,6 +87,92 @@ TEST(EpGraphTest, Check3LayerNestedSubgraphV2) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, GetAttributeByName) { + // Load model with a single Conv that has no explicit attributes set. + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_default_attrs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // + // Pre-check + // + + // Original Conv has no explicit attributes but Graph::Resolve() fills in default values for + // 'auto_pad' and 'group'. The other optional attributes (i.e. dilations, kernel_shape, pads, strides) do not + // have statically computable default values, so will not be filled in by Graph::Resolve(). + const OrtGraph& ort_graph = test_graph->GetOrtGraph(); + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_nodes = 0; + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNumNodes(&ort_graph, &num_nodes)); + ASSERT_EQ(num_nodes, 1); + + std::vector nodes(num_nodes); + ASSERT_ORTSTATUS_OK(ort_api.Graph_GetNodes(&ort_graph, nodes.data(), nodes.size())); + + const OrtNode* conv_node = nodes[0]; + const char* op_type = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetOperatorType(conv_node, &op_type)); + ASSERT_STREQ(op_type, "Conv"); + + size_t num_attrs = 0; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetNumAttributes(conv_node, &num_attrs)); + ASSERT_EQ(num_attrs, 2); + + std::vector attrs(num_attrs); + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributes(conv_node, attrs.data(), attrs.size())); + for (const OrtOpAttr* attr : attrs) { + const char* attr_name_cstr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetName(attr, &attr_name_cstr)); + std::string_view attr_name = attr_name_cstr; + ASSERT_TRUE(attr_name == "auto_pad" || attr_name == "group"); // Only 'auto_pad' and 'group' have been set + } + + // + // Test 1: Get optional attribute that is not set (e.g., dilations). Should not get an error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "dilations", &attr)}; + ASSERT_TRUE(status.IsOK()); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 2: Get attribute that does not exist in operator schema. Should get a ORT_NOT_FOUND error. + // + { + const OrtOpAttr* attr = nullptr; + Ort::Status status{ort_api.Node_GetAttributeByName(conv_node, "_does_not_exist_", &attr)}; + ASSERT_FALSE(status.IsOK()); + ASSERT_EQ(status.GetErrorCode(), ORT_NOT_FOUND); + ASSERT_EQ(attr, nullptr); + } + + // + // Test 3: Get attribute that is known to be set. + // + { + const OrtOpAttr* attr = nullptr; + ASSERT_ORTSTATUS_OK(ort_api.Node_GetAttributeByName(conv_node, "auto_pad", &attr)); + ASSERT_NE(attr, nullptr); + + OrtOpAttrType attr_type = ORT_OP_ATTR_UNDEFINED; + ASSERT_ORTSTATUS_OK(ort_api.OpAttr_GetType(attr, &attr_type)); + ASSERT_EQ(attr_type, ORT_OP_ATTR_STRING); + + std::string auto_pad_val; + + // First call to ReadOpAttr gets the total byte size. Second call reads the data. + size_t total_attr_bytes = 0; + Ort::Status status2{ort_api.ReadOpAttr(attr, attr_type, nullptr, 0, &total_attr_bytes)}; + auto_pad_val.resize(total_attr_bytes); + + ASSERT_ORTSTATUS_OK(ort_api.ReadOpAttr(attr, attr_type, auto_pad_val.data(), total_attr_bytes, + &total_attr_bytes)); + ASSERT_EQ(auto_pad_val, "NOTSET"); + } +} + // Check correctness of an OrtGraph that has external initializers. TEST(EpGraphTest, CheckModelExternalInitializers) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/conv_qdq_external_ini.onnx")); @@ -220,6 +306,39 @@ static void RunMNISTModel(const ORTCHAR_T* model_path, std::vector& outpu output_data.assign(output_values, output_values + num_output_elems); } +static void RunConstantOfShapeModel(const ORTCHAR_T* model_path, std::vector& output_data) { + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + Ort::SessionOptions sess_options; + Ort::Session session(*ort_env, model_path, sess_options); + + std::vector input_shape = {3}; + std::vector input_data = {2, 3, 4}; + std::vector ort_inputs; + std::vector ort_input_names; + + // Add 'x' + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input_data.data(), input_data.size(), input_shape.data(), input_shape.size())); + ort_input_names.push_back("x"); + + // Run session and get outputs + std::array output_names{"y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check output type and number of elements. + Ort::Value& ort_output = ort_outputs[0]; + auto output_type_shape = ort_output.GetTensorTypeAndShapeInfo(); + size_t num_output_elems = output_type_shape.GetElementCount(); + + ASSERT_EQ(output_type_shape.GetElementType(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + ASSERT_EQ(num_output_elems, 24); + + // Return output data. + const float* output_values = ort_output.GetTensorData(); + output_data.assign(output_values, output_values + num_output_elems); +} + // Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. // Checks that the outputs of the serialized and original models are identical. TEST(EpGraphTest, SerializeToProto_Mnist) { @@ -350,6 +469,65 @@ TEST(EpGraphTest, SerializeToProto_ExternalInitializersInMemory) { } } +// Test serializing an OrtGraph (MNIST) to GraphProto. Saves initializers to external file. +// Checks that the outputs of the serialized and original models are identical. +TEST(EpGraphTest, SerializeToProto_ConstantOfShape) { + const ORTCHAR_T* original_model_path = ORT_TSTR("testdata/ort_minimal_test_models/tensor_attribute.onnx"); + const ORTCHAR_T* serialized_model_path = ORT_TSTR("constant_of_shape.onnx"); + std::filesystem::remove(serialized_model_path); + + { + auto test_graph = TestGraph::Load(original_model_path); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + // Serialize OrtGraph to GraphProto. Save initializers to external file. + std::string ext_ini_file_path = "constant_of_shape_serialized.bin"; + std::filesystem::remove(ext_ini_file_path); + std::ofstream ext_ini_ofs(ext_ini_file_path, std::ios::binary); + auto handle_initializer_data = [&ext_ini_ofs, &ext_ini_file_path](const OrtValueInfo* value_info, + const void* data, size_t bytes, + bool& is_external, std::string& location, + int64_t& offset) -> Ort::Status { + // OrtValueInfo* could be used to query initializer's name, type, shape, + // node consumers, etc. + static_cast(value_info); + + if (bytes <= 127) { + is_external = false; // Keep small initializers stored inside the TensorProto. + return Ort::Status{nullptr}; + } + + offset = ext_ini_ofs.tellp(); + location = ext_ini_file_path; + ext_ini_ofs.write(static_cast(data), bytes); + ext_ini_ofs.flush(); + is_external = true; // True if is external initializer. + + return Ort::Status{nullptr}; + }; + + ONNX_NAMESPACE::ModelProto model_proto; + ASSERT_CXX_ORTSTATUS_OK(OrtEpUtils::OrtGraphToProto(test_graph->GetOrtGraph(), model_proto, + handle_initializer_data)); + + std::ofstream ofs(serialized_model_path, std::ios::binary); + model_proto.SerializeToOstream(&ofs); + ofs.flush(); + + ASSERT_TRUE(std::filesystem::exists(serialized_model_path)); + ASSERT_TRUE(std::filesystem::exists(ext_ini_file_path)); + } + + // Compare output of the original and serialized models. Should be identical. + std::vector output_original; + std::vector output_serialized; + + RunConstantOfShapeModel(original_model_path, output_original); + RunConstantOfShapeModel(serialized_model_path, output_serialized); + + EXPECT_EQ(output_serialized, output_original); +} + static void Run3LayerModel(const ORTCHAR_T* model_path, bool input_cond, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; @@ -892,6 +1070,10 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_GRAPH); break; } + case ONNX_NAMESPACE::AttributeProto_AttributeType::AttributeProto_AttributeType_TENSOR: { + ASSERT_EQ(api_node_attr_type, OrtOpAttrType::ORT_OP_ATTR_TENSOR); + break; + } default: // The unsupported type should be skipped by 'continue' above. It's unexpected so we force test to fail. ASSERT_ORTSTATUS_OK(ort_api.CreateStatus(ORT_FAIL, "The attribute type is not in AttributeProto_AttributeType and this case shouldn't be hit.")); diff --git a/onnxruntime/test/framework/ep_plugin_provider_test.cc b/onnxruntime/test/framework/ep_plugin_provider_test.cc index 4c5dcd2bd7580..35f7d06fb0912 100644 --- a/onnxruntime/test/framework/ep_plugin_provider_test.cc +++ b/onnxruntime/test/framework/ep_plugin_provider_test.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/session/ep_plugin_provider_interfaces.h" +#include "core/session/plugin_ep/ep_plugin_provider_interfaces.h" #include "gsl/gsl" #include "gtest/gtest.h" diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 98640bb2f6b4c..1baa6e529cbde 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -12,6 +12,7 @@ #include "core/mlas/inc/mlas.h" #include "core/optimizer/double_qdq_pairs_remover.h" #include "core/optimizer/qdq_transformer/weight_bias_quantization.h" +#include "core/optimizer/qdq_transformer/where_dummy_dq.h" #include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" #include "core/optimizer/qdq_transformer/qdq_propagation.h" #include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" @@ -3220,6 +3221,79 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 } +template +void TestWhereWithDqInput(bool is_dq_1, + bool is_dq_2, + int expected_num_where, + int expected_num_dq, + int expected_num_q, + bool expected_modified) { + auto& logger = DefaultLoggingManager().DefaultLogger(); + Model model("WhereDummyDqTester", false, logger); + Graph& graph = model.MainGraph(); + ModelTestBuilder builder(graph); + + NodeArg* where_in1 = nullptr; + NodeArg* where_in2 = nullptr; + if (is_dq_1) { + // DQ + auto* dq_Input = builder.MakeInput({4, 3, 32}, 0.0, 1.0); + auto* dq_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* dq_zp = builder.MakeInitializer({}, 0.0, 1.0); + where_in1 = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in1}); + } else { + where_in1 = builder.MakeInitializer({}, 0.0, 1.0); + } + if (is_dq_2) { + // DQ + auto* dq_Input = builder.MakeInput({4, 3, 32}, 0.0, 1.0); + auto* dq_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* dq_zp = builder.MakeInitializer({}, 0.0, 1.0); + where_in2 = builder.MakeIntermediate(); + builder.AddNode("DequantizeLinear", {dq_Input, dq_scale, dq_zp}, {where_in2}); + } else { + where_in2 = builder.MakeInitializer({}, 0.0, 1.0); + } + + // Where + auto* where_cond = builder.MakeInputBool({4, 3, 32}); + auto* where_out = builder.MakeIntermediate(); + builder.AddNode("Where", {where_cond, where_in1, where_in2}, {where_out}); + + // Q + auto* q_scale = builder.MakeInitializer({}, 0.0, 1.0); + auto* q_zp = builder.MakeInitializer({}, 0.0, 1.0); + auto* q_out = builder.MakeOutput(); + builder.AddNode("QuantizeLinear", {where_out, q_scale, q_zp}, {q_out}); + + builder.SetGraphOutputs(); + ASSERT_STATUS_OK(graph.Resolve()); + + auto where_optimizer = std::make_unique(); + bool modified = false; + ASSERT_STATUS_OK(where_optimizer->Apply(graph, modified, logger)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Where"], expected_num_where); + ASSERT_EQ(op_to_count["DequantizeLinear"], expected_num_dq); + ASSERT_EQ(op_to_count["QuantizeLinear"], expected_num_q); + ASSERT_EQ(modified, expected_modified); + + return; +}; + +TEST(QDQTransformerTests, WhereDummyDqTest) { + TestWhereWithDqInput(true, true, 1, 2, 1, false); + TestWhereWithDqInput(true, false, 1, 2, 1, true); + TestWhereWithDqInput(false, true, 1, 2, 1, true); + TestWhereWithDqInput(false, false, 1, 0, 1, false); + TestWhereWithDqInput(true, true, 1, 2, 1, false); + TestWhereWithDqInput(true, false, 1, 2, 1, true); + TestWhereWithDqInput(false, true, 1, 2, 1, true); + TestWhereWithDqInput(false, false, 1, 0, 1, false); +} + TEST(QDQTransformerTests, Concat) { auto test_case = [&](const std::vector>& input_shapes, int64_t axis, diff --git a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc index 5de7885a9452a..761ddf1975d15 100644 --- a/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc +++ b/onnxruntime/test/providers/migraphx/migraphx_basic_test.cc @@ -188,6 +188,24 @@ TEST(MIGraphXExecutionProviderTest, canEvalArgument) { ASSERT_EQ(canEvalNodeArgument(gv, node2, {1}, input_nodes), true); } +static bool SessionHasEp(Ort::Session& session, const char* ep_name) { + // Access the underlying InferenceSession. + const OrtSession* ort_session = session; + const InferenceSession* s = reinterpret_cast(ort_session); + bool has_ep = false; + + for (const auto& provider : s->GetRegisteredProviderTypes()) { + if (provider == ep_name) { + has_ep = true; + break; + } + } + return has_ep; +} + +#if defined(WIN32) +// Tests autoEP feature to automatically select an EP that supports the GPU. +// Currently only works on Windows. TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("migraphx_basic_test.onnx"); @@ -212,6 +230,7 @@ TEST(MIGraphXExecutionProviderTest, AutoEp_PreferGpu) { env.UnregisterExecutionProviderLibrary(kMIGraphXExecutionProvider); } +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/einsum_op_test.cc b/onnxruntime/test/providers/qnn/einsum_op_test.cc index a2a0ce485bb35..11a3d5a083aab 100644 --- a/onnxruntime/test/providers/qnn/einsum_op_test.cc +++ b/onnxruntime/test/providers/qnn/einsum_op_test.cc @@ -202,6 +202,32 @@ TEST_F(QnnCPUBackendTests, EinsumRank4MatMulTransposeAll2) { /*tolerance=*/1e-4f); } +TEST_F(QnnCPUBackendTests, EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + +TEST_F(QnnCPUBackendTests, EinsumReduceSumMulBroadcastX) { + const std::vector shape0{2, 3, 4, 5}; + const std::vector shape1{4, 6, 5}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeCpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/1e-4f); +} + // // QNN HTP F16 // @@ -273,6 +299,32 @@ TEST_F(QnnHTPBackendTests, EinsumF16Rank4MatMulTransposeAll2) { /*tolerance=*/1e-2f); } +TEST_F(QnnHTPBackendTests, EinsumF16MatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-2f); +} + +TEST_F(QnnHTPBackendTests, EinsumF16ReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeHtp, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/1e-2f); +} + // // QNN HTP QDQ // @@ -337,6 +389,31 @@ TEST_F(QnnHTPBackendTests, EinsumQdqRank4MatMulTransposeAll2) { /*tolerance=*/QDQTolerance()); } +TEST_F(QnnHTPBackendTests, EinsumQdqMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/QDQTolerance()); +} + +// TODO: Re-enable. QAIRT 3.36.1: failed to finalize QNN graph 1002. +TEST_F(QnnHTPBackendTests, DISABLED_EinsumQdqReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnHtpQdqEinsum( + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/QDQTolerance()); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #if defined(_M_ARM64) @@ -422,6 +499,34 @@ TEST_F(QnnGPUBackendTests, EinsumRank4MatMulTransposeAll2) { /*tolerance=*/1e-4f); } +// Numeric instability in GPU backend, see also MatMul tests. +TEST_F(QnnGPUBackendTests, DISABLED_EinsumMatMulBroadcastTransposeY) { + const std::vector shape0{2, 3, 3, 4}; + const std::vector shape1{3, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,hkc->bhwk", + /*tolerance=*/1e-4f); +} + +// TODO: Re-enable. Failed on QAIRT 3.36.1. +TEST_F(QnnGPUBackendTests, DISABLED_EinsumReduceSumMulBroadcastX) { + const std::vector shape0{1, 3, 2, 4}; + const std::vector shape1{2, 3, 4}; + const std::vector data0 = GetSequentialFloatData(shape0, /*start=*/-0.1f, /*step=*/0.05f); + const std::vector data1 = GetSequentialFloatData(shape1, /*start=*/-0.1f, /*step=*/0.05f); + RunQnnEinsum( + /*backend=*/kQnnBackendTypeGpu, + /*in0=*/TestInputDef(shape0, /*is_initializer=*/false, std::move(data0)), + /*in1=*/TestInputDef(shape1, /*is_initializer=*/false, std::move(data1)), + /*equation=*/"bhwc,wkc->bhwk", + /*tolerance=*/1e-4f); +} + #endif // defined(_M_ARM64) GPU tests } // namespace test diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index f127e14bde635..4383faab369ab 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -54,18 +54,17 @@ TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. } -// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). -// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` -TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { +// Test Gemm with 2D bias is supported. +TEST_F(QnnCPUBackendTests, Gemm_2D_Bias) { std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); - // 2D matrix mul with bias not supported. + // 2D matrix mul with bias is supported. RunGemmTest({TestInputDef({2, 3}, false, input_a_data), TestInputDef({3, 4}, false, input_b_data), TestInputDef({2, 4}, false, -1.0f, 1.0f)}, {}, - ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. RunGemmTest({TestInputDef({2, 3}, false, input_a_data), @@ -525,15 +524,19 @@ TEST_F(QnnGPUBackendTests, Gemm_AlphaBetaUnsupported) { "gpu"); } -// Gemm with matrix bias ie 2D (M, N) is NOT supported. (Note: vector bias is supported ie when M == 1). +// Gemm with matrix bias ie 2D (M, N) is supported. +// When vector bias ie M == 1 // QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` -TEST_F(QnnGPUBackendTests, Gemm_2DBiasUnsupported) { - // 2D matrix mul with 2D bias not supported. +// When 2D bias i.e. M != 1, N != 1. +// When 2D bias i.e. M != 1, N != 1. +// QNN's Gemm will be split in to FullyConnected and ElementwiseAdd. +TEST_F(QnnGPUBackendTests, Gemm_2D_Bias) { + // 2D matrix mul with 2D bias is supported when Gemm is not a QDQ node. RunGemmTest({TestInputDef({2, 3}, false, -10.0f, 10.0f), TestInputDef({3, 4}, false, -10.0f, 10.0f), TestInputDef({2, 4}, false, -1.0f, 1.0f)}, {}, - ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + ExpectedEPNodeAssignment::All, // Should be assigned to QNN EP. "gpu"); } diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc new file mode 100644 index 0000000000000..773e6146f102f --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqgemm_fusion_test.cc @@ -0,0 +1,126 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQGemmTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 15); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -1.0f, 1.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // Gemm + NodeArg* gemm_bias = builder.MakeInitializer({output_channels}, -1.0f, 1.0f); + NodeArg* gemm_output = builder.MakeIntermediate(); + builder.AddNode("Gemm", {act_dql_output, w_dql_output, gemm_bias}, {gemm_output}); + + // QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {gemm_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +TEST_F(QnnHTPBackendTests, LPBQGemmFusion) { + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQGemmTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc new file mode 100644 index 0000000000000..92c0895cd691e --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/lpbqmatmul_fusion_test.cc @@ -0,0 +1,125 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/optimizer/qdq_test_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) + +namespace { + +GetQDQTestCaseFn BuildLPBQMatMulTestCase() { + return [](ModelTestBuilder& builder) -> void { + // Define the test case for LPBQGemm fusion here + const int64_t input_channels = 16; + const int64_t output_channels = 16; + const int64_t blocks_per_axis = 4; + const std::vector input_shape{1, input_channels}; + auto input_def = TestInputDef(input_shape, false, -0.5f, 0.5f); + NodeArg* input = MakeTestInput(builder, input_def); + + // QuantizeLinear for Activation + NodeArg* act_ql_output = builder.MakeIntermediate(); + NodeArg* act_ql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_ql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("QuantizeLinear", {input, act_ql_scale, act_ql_zero_point}, {act_ql_output}); + + // DequantizeLinear for Activation + NodeArg* act_dql_output = builder.MakeIntermediate(); + NodeArg* act_dql_scale = builder.MakeScalarInitializer(0.00005509183756657876f); + NodeArg* act_dql_zero_point = builder.MakeScalarInitializer(23715); + builder.AddNode("DequantizeLinear", {act_ql_output, act_dql_scale, act_dql_zero_point}, {act_dql_output}); + + // DequantizeLinear for Scale + NodeArg* scale_dql_input = builder.MakeInitializer({blocks_per_axis, output_channels}, 1, 16); + NodeArg* scale_dql_scale = builder.MakeInitializer({output_channels}, 0.01f, 0.02f); + std::vector dql_zero_points_data = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + NodeArg* scale_dql_zero_point = builder.Make1DInitializer(dql_zero_points_data); + NodeArg* scale_dql_output = builder.MakeIntermediate(); + Node& scale_dql = builder.AddNode("DequantizeLinear", {scale_dql_input, scale_dql_scale, scale_dql_zero_point}, {scale_dql_output}); + scale_dql.AddAttribute("axis", static_cast(1)); + + // QuantizeLinear for Weight + NodeArg* w_ql_input = builder.MakeInitializer({input_channels, output_channels}, -2.0f, 2.0f); + std::vector zero_points_data; + size_t num_storage_elems = blocks_per_axis * output_channels; + zero_points_data.resize(Int4x2::CalcNumInt4Pairs(num_storage_elems)); + for (size_t i = 0; i < num_storage_elems; ++i) { + size_t r = i >> 1; + size_t c = i & 0x1; + zero_points_data[r].SetElem(c, 0); + } + NodeArg* w_ql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_ql_output = builder.MakeIntermediate(); + Node& w_ql = builder.AddNode("QuantizeLinear", {w_ql_input, scale_dql_output, w_ql_zero_point}, {w_ql_output}); + w_ql.AddAttribute("axis", static_cast(0)); + w_ql.AddAttribute("block_size", static_cast(4)); + + // DequantizeLinear for Weight + NodeArg* w_dql_zero_point = builder.MakeInitializer({blocks_per_axis, output_channels}, zero_points_data); + NodeArg* w_dql_output = builder.MakeIntermediate(); + Node& w_dql = builder.AddNode("DequantizeLinear", {w_ql_output, scale_dql_output, w_dql_zero_point}, {w_dql_output}); + w_dql.AddAttribute("axis", static_cast(0)); + w_dql.AddAttribute("block_size", static_cast(4)); + + // MatMul + NodeArg* matmul_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {act_dql_output, w_dql_output}, {matmul_output}); + + // QuantizeLinear for Output + NodeArg* output_ql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_ql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_ql_output = builder.MakeIntermediate(); + builder.AddNode("QuantizeLinear", {matmul_output, output_ql_scale, output_ql_zero_point}, {output_ql_output}); + + // DequantizeLinear for Output + NodeArg* output_dql_scale = builder.MakeScalarInitializer(0.00019595865160226822f); + NodeArg* output_dql_zero_point = builder.MakeScalarInitializer(31693); + NodeArg* output_dql_output = builder.MakeOutput(); + builder.AddNode("DequantizeLinear", {output_ql_output, output_dql_scale, output_dql_zero_point}, {output_dql_output}); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +TEST_F(QnnHTPBackendTests, LPBQMatMulFusion) { + ProviderOptions provider_options = GetProviderOptions(); + RunQnnModelTest(BuildLPBQMatMulTestCase(), + provider_options, + /*opset_version=*/21, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::Some, + /*fp32_abs_err=*/1e-2f, + /*log_severity =*/logging::Severity::kERROR, + /*verify_outputs=*/false); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/python/transformers/test_parity_moe.py b/onnxruntime/test/python/transformers/test_moe_cuda.py similarity index 53% rename from onnxruntime/test/python/transformers/test_parity_moe.py rename to onnxruntime/test/python/transformers/test_moe_cuda.py index 252d89a2257fc..9b69d63970311 100644 --- a/onnxruntime/test/python/transformers/test_parity_moe.py +++ b/onnxruntime/test/python/transformers/test_moe_cuda.py @@ -9,6 +9,8 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import itertools +import os import unittest from collections import OrderedDict @@ -21,28 +23,42 @@ import onnxruntime -torch.manual_seed(42) -numpy.random.seed(42) +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" -USE_QUANT = False -ORT_DTYPE = TensorProto.FLOAT16 if USE_QUANT else TensorProto.FLOAT -NP_TYPE = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 -THRESHOLD = 5e-1 if USE_QUANT else 1e-2 +onnxruntime.preload_dlls() +# Determine the execution provider and device based on CUDA availability. +use_cuda = "CUDAExecutionProvider" in onnxruntime.get_available_providers() and torch.cuda.is_available() +device = torch.device("cuda:0" if use_cuda else "cpu") +ort_provider = ["CUDAExecutionProvider"] if use_cuda else ["CPUExecutionProvider"] -def value_string_of(numpy_array): - arr = numpy_array.flatten() - lines = ["f, ".join([str(v) for v in arr[i : min(i + 8, arr.size)]]) for i in range(0, arr.size, 8)] - return "{\n " + "f,\n ".join(lines) + "f}" +torch.manual_seed(42) +numpy.random.seed(42) +onnx_to_torch_type_map = { + TensorProto.FLOAT16: torch.float16, + TensorProto.FLOAT: torch.float, + TensorProto.BFLOAT16: torch.bfloat16, + TensorProto.UINT8: torch.uint8, +} + +ort_to_numpy_type_map = { + TensorProto.FLOAT16: numpy.float16, + TensorProto.FLOAT: numpy.float32, + TensorProto.UINT8: numpy.uint8, +} -def print_tensor(name, numpy_array): - print(f"const std::vector {name} = {value_string_of(numpy_array)};") +ort_dtype_name_map = { + TensorProto.FLOAT16: "FP16", + TensorProto.FLOAT: "FP32", + TensorProto.BFLOAT16: "BF16", +} -def quant_dequant(weights, quant_mode: bool = True): +def quant_dequant(weights, is_4_bit_quantization: bool = True): # use the test version `_symmetric_...` to get the non-interleaved weights - type = torch.quint4x2 if quant_mode else torch.int8 + type = torch.quint4x2 if is_4_bit_quantization else torch.int8 # This import is needed to use torch.ops.trtllm._symmetric_quantize_last_axis_of_batched_matrix() # Comment out this line for passing the lintrunner check in the CI. # import tensorrt_llm @@ -52,7 +68,7 @@ def quant_dequant(weights, quant_mode: bool = True): ) # Unpack the int4s int int8s - if quant_mode: + if is_4_bit_quantization: upper = quant_weights >> 4 lower = (quant_weights << 4) >> 4 # Arithmetic right shift sign extends quant_weights = torch.stack((lower, upper), dim=2).view(weights.T.shape) @@ -71,6 +87,7 @@ def create_moe_onnx_graph( fc1_experts_bias, fc2_experts_weights, fc2_experts_bias, + onnx_dtype, ): nodes = [ helper.make_node( @@ -94,21 +111,21 @@ def create_moe_onnx_graph( fc1_shape = [num_experts, hidden_size, inter_size] fc2_shape = [num_experts, inter_size, hidden_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] @@ -119,35 +136,35 @@ def create_moe_onnx_graph( [ helper.make_tensor( "fc1_experts_bias", - ORT_DTYPE, + onnx_dtype, fc1_bias_shape, - fc1_experts_bias.to(torch_type).flatten().tolist(), + fc1_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_bias", - ORT_DTYPE, + onnx_dtype, fc2_bias_shape, - fc2_experts_bias.to(torch_type).flatten().tolist(), + fc2_experts_bias.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -171,6 +188,7 @@ def create_mixtral_moe_onnx_graph( fc2_experts_weights, fc3_experts_weights, topk, + onnx_dtype, ): nodes = [ helper.make_node( @@ -197,46 +215,46 @@ def create_mixtral_moe_onnx_graph( fc2_shape = [num_experts, inter_size, hidden_size] fc3_shape = [num_experts, hidden_size, inter_size] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE, + onnx_dtype, fc1_shape, - fc1_experts_weights.to(torch_type).flatten().tolist(), + fc1_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE, + onnx_dtype, fc2_shape, - fc2_experts_weights.to(torch_type).flatten().tolist(), + fc2_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE, + onnx_dtype, fc3_shape, - fc3_experts_weights.to(torch_type).flatten().tolist(), + fc3_experts_weights.to(torch_dtype).flatten().tolist(), raw=False, ), ] graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -259,12 +277,14 @@ def create_phi_moe_onnx_graph( fc1_experts_weights, fc2_experts_weights, fc3_experts_weights, - fc1_scales, - fc2_scales, - fc3_scales, topk, + onnx_dtype, + quant_bits=0, + fc1_scales=None, + fc2_scales=None, + fc3_scales=None, ): - use_quant = USE_QUANT + use_quant = quant_bits > 0 if use_quant: assert fc1_experts_weights.dtype == torch.int8 assert fc2_experts_weights.dtype == torch.int8 @@ -276,34 +296,37 @@ def create_phi_moe_onnx_graph( assert fc2_scales.dtype == torch.float16 assert fc3_scales.dtype == torch.float16 + op_name = "QMoE" if use_quant else "MoE" + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_scales", + "", + "fc2_experts_weights", + "fc2_scales", + "", + "fc3_experts_weights", + "fc3_scales", + "", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "", + "fc2_experts_weights", + "", + "fc3_experts_weights", + ] + ) + nodes = [ helper.make_node( - "MoE" if not use_quant else "QMoE", - ( - [ - "input", - "router_probs", - "fc1_experts_weights", - "", - "fc2_experts_weights", - "", - "fc3_experts_weights", - ] - if not use_quant - else [ - "input", - "router_probs", - "fc1_experts_weights", - "fc1_scales", - "", - "fc2_experts_weights", - "fc2_scales", - "", - "fc3_experts_weights", - "fc3_scales", - "", - ] - ), + op_name, + inputs, ["output"], "MoE_0", k=topk, @@ -315,37 +338,38 @@ def create_phi_moe_onnx_graph( ] if use_quant: - nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", 8)]) + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) - fc1_shape = [num_experts, hidden_size, inter_size] - fc2_shape = [num_experts, inter_size, hidden_size] - fc3_shape = [num_experts, hidden_size, inter_size] + components = 2 if quant_bits == 4 else 1 + fc1_shape = [num_experts, hidden_size, inter_size // components] + fc2_shape = [num_experts, inter_size, hidden_size // components] + fc3_shape = [num_experts, hidden_size, inter_size // components] - torch_type = torch.float16 if ORT_DTYPE == TensorProto.FLOAT16 else torch.float32 - numpy_type = numpy.float16 if ORT_DTYPE == TensorProto.FLOAT16 else numpy.float32 - if use_quant: - numpy_type = numpy.uint8 + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + + weight_numpy_type = numpy.uint8 if use_quant else ort_to_numpy_type_map[onnx_dtype] + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype initializers = [ helper.make_tensor( "fc1_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc1_shape, - fc1_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc1_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc2_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc2_shape, - fc2_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc2_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), helper.make_tensor( "fc3_experts_weights", - ORT_DTYPE if not use_quant else TensorProto.UINT8, + weight_onnx_type, fc3_shape, - fc3_experts_weights.flatten().detach().numpy().astype(numpy_type).tolist(), + fc3_experts_weights.flatten().detach().cpu().numpy().astype(weight_numpy_type).tolist(), raw=False, ), ] @@ -358,42 +382,42 @@ def create_phi_moe_onnx_graph( [ helper.make_tensor( "fc1_scales", - ORT_DTYPE, + onnx_dtype, fc1_scale_shape, - fc1_scales.to(torch_type).flatten().tolist(), + fc1_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc2_scales", - ORT_DTYPE, + onnx_dtype, fc2_scale_shape, - fc2_scales.to(torch_type).flatten().tolist(), + fc2_scales.to(torch_dtype).flatten().tolist(), raw=False, ), helper.make_tensor( "fc3_scales", - ORT_DTYPE, + onnx_dtype, fc3_scale_shape, - fc3_scales.to(torch_type).flatten().tolist(), + fc3_scales.to(torch_dtype).flatten().tolist(), raw=False, ), ] ) graph_inputs = [ - helper.make_tensor_value_info("input", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] graph_inputs.append( helper.make_tensor_value_info( "router_probs", - ORT_DTYPE, + onnx_dtype, [sequence_length, num_experts], ) ) graph_outputs = [ - helper.make_tensor_value_info("output", ORT_DTYPE, [sequence_length, hidden_size]), + helper.make_tensor_value_info("output", onnx_dtype, [sequence_length, hidden_size]), ] graph = helper.make_graph( @@ -546,126 +570,127 @@ def __init__(self, config: PhiMoEConfig): class SparseMoeBlockORTHelper(nn.Module): - def __init__(self): + def __init__(self, quant_bits=0, onnx_dtype=None): super().__init__() + self.quant_bits = quant_bits + if onnx_dtype is None: + self.onnx_dtype = TensorProto.FLOAT16 if self.quant_bits > 0 else TensorProto.FLOAT + else: + self.onnx_dtype = onnx_dtype + self.np_type = numpy.float16 if self.onnx_dtype == TensorProto.FLOAT16 else numpy.float32 def create_ort_session(self, moe_onnx_graph): from onnxruntime import InferenceSession, SessionOptions # noqa: PLC0415 sess_options = SessionOptions() + sess_options.log_severity_level = 2 - cuda_providers = ["CUDAExecutionProvider"] - if cuda_providers[0] not in onnxruntime.get_available_providers(): + try: + ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=ort_provider) + except Exception as e: + print(f"Failed to create ONNX Runtime session with provider {ort_provider}: {e}") + print("Skipping ONNX Runtime execution for this test case.") return None - sess_options.log_severity_level = 2 - ort_session = InferenceSession(moe_onnx_graph, sess_options, providers=["CUDAExecutionProvider"]) - return ort_session def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: pass - def ort_forward(self, hidden_states: torch.Tensor, iobinding=False) -> torch.Tensor: + def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False) -> torch.Tensor: + if self.ort_sess is None: + return None + batch_size, sequence_length, hidden_dim = hidden_states.shape - hidden_states = hidden_states.view(-1, hidden_dim) + hidden_states_flat = hidden_states.view(-1, hidden_dim) # router_logits: (batch * sequence_length, n_experts) - router_logits = self.gate(hidden_states) - - ort_inputs = { - "input": numpy.ascontiguousarray(hidden_states.detach().numpy().astype(NP_TYPE)), - "router_probs": numpy.ascontiguousarray(router_logits.detach().numpy().astype(NP_TYPE)), - } - - ort_output = None - if self.ort_sess is not None: - if not iobinding: - ort_output = self.ort_sess.run(None, ort_inputs) - return torch.tensor(ort_output).reshape(batch_size, sequence_length, -1) # , router_logits - else: - self.ort_run_with_iobinding(ort_inputs) - return None + router_logits = self.gate(hidden_states_flat) - # print_tensor("input", ort_inputs["input"]) - # print_tensor("router_probs", ort_inputs["router_probs"]) - # print_tensor("fc1_experts_weights", self.moe_experts_weight1.detach().numpy()) - # print_tensor("fc2_experts_weights", self.moe_experts_weight2.detach().numpy()) - # print_tensor("fc3_experts_weights", self.moe_experts_weight3.detach().numpy()) - # print_tensor("output", ort_output[0]) + # Determine the correct torch dtype from the onnx_dtype + torch_dtype = onnx_to_torch_type_map[self.onnx_dtype] - return None + # Prepare tensors on the correct device for ORT inference with the CORRECT dtype + tensors = { + "input": hidden_states_flat.clone().to(device=device, dtype=torch_dtype), + "router_probs": router_logits.clone().to(device=device, dtype=torch_dtype), + "output": torch.zeros_like(hidden_states_flat, device=device, dtype=torch_dtype), + } - def ort_run_with_iobinding(self, ort_inputs, repeat=1000): + # Bind inputs and outputs to torch tensors directly. iobinding = self.ort_sess.io_binding() - device_id = torch.cuda.current_device() - - iobinding.bind_input( - name="input", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy(ort_inputs["input"], "cuda", device_id).data_ptr(), - ) - - iobinding.bind_input( - name="router_probs", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["router_probs"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - ort_inputs["router_probs"], "cuda", device_id - ).data_ptr(), - ) - - iobinding.bind_output( - name="output", - device_type="cuda", - device_id=device_id, - element_type=NP_TYPE, - shape=ort_inputs["input"].shape, - buffer_ptr=onnxruntime.OrtValue.ortvalue_from_numpy( - numpy.zeros(ort_inputs["input"].shape), "cuda", device_id - ).data_ptr(), - ) - - # warm up - for _ in range(5): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - import time # noqa: PLC0415 - - s = time.time() - for _ in range(repeat): - iobinding.synchronize_inputs() - self.ort_sess.run_with_iobinding(iobinding) - iobinding.synchronize_outputs() - e = time.time() - print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + for name, tensor in tensors.items(): + # Ensure tensor is on the globally defined device + if name == "output": + iobinding.bind_output( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + else: + iobinding.bind_input( + name=name, + device_type=tensor.device.type, + device_id=tensor.device.index or 0, + element_type=self.onnx_dtype, + shape=tensor.shape, + buffer_ptr=tensor.data_ptr(), + ) + + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + + if enable_performance_test: + import time # noqa: PLC0415 + + repeat = 1000 + s = time.time() + for _ in range(repeat): + iobinding.synchronize_inputs() + self.ort_sess.run_with_iobinding(iobinding) + iobinding.synchronize_outputs() + e = time.time() + print(f"MoE cuda kernel time: {(e - s) / repeat * 1000} ms") + + # The output tensor is on `device`. Reshape and return it. + return tensors["output"].reshape(batch_size, sequence_length, hidden_dim) def parity_check(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) torch_output = self.forward(hidden_state) ort_output = self.ort_forward(hidden_state) + + dtype_str = ort_dtype_name_map[self.onnx_dtype] + + # Maps "ort_type:quant_bits" to (atol, rtol) + ort_dtype_quant_bits_tolerance_map = { + "FP32:0": (5e-3, 1e-3), + "FP16:0": (5e-2, 1e-3), + "FP16:4": (3.0, 1e-2), + "FP16:8": (2.0, 1e-2), + "BF16:0": (1.0, 1e-2), + "BF16:4": (30.0, 1e-1), + "BF16:8": (20.0, 1e-1), + } + + atol, rtol = ort_dtype_quant_bits_tolerance_map[f"{dtype_str}:{self.quant_bits}"] if ort_output is not None: print( - "name:", - self.__class__.__name__, - " batch_size:", - self.batch_size, - " sequence_length:", - self.sequence_length, - " max_diff:", - (torch_output - ort_output).abs().max(), + f"name: {self.__class__.__name__}, quant_bits: {self.quant_bits}, dtype: {dtype_str}," + f" batch: {self.batch_size}, seq_len: {self.sequence_length}," + f" max_diff: {(torch_output.cpu() - ort_output.cpu()).abs().max()}" + ) + torch.testing.assert_close( + ort_output.cpu().to(torch.float32), torch_output.cpu().to(torch.float32), rtol=rtol, atol=atol ) - torch.testing.assert_close(ort_output.to(torch.float32), torch_output, rtol=THRESHOLD, atol=THRESHOLD) def benchmark_ort(self): - hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim) - self.ort_forward(hidden_state, iobinding=True) + hidden_state = torch.randn(self.batch_size, self.sequence_length, self.hidden_dim).to(device) + self.ort_forward(hidden_state, enable_performance_test=True) class SwitchMoE(SparseMoeBlockORTHelper): @@ -680,7 +705,7 @@ def __init__( eval_capacity=-1, activation="gelu", ): - super().__init__() + super().__init__(quant_bits=0) # SwitchMoE is not quantized self.batch_size = batch_size self.sequence_length = sequence_length self.num_experts = num_experts @@ -709,6 +734,7 @@ def __init__( self.moe_experts.bias1, self.moe_experts.weight2.transpose(1, 2), self.moe_experts.bias2, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -744,7 +770,7 @@ class MixtralSparseMoeBlock(SparseMoeBlockORTHelper): """ def __init__(self, config, batch_size, sequence_length): - super().__init__() + super().__init__(quant_bits=0) # Mixtral test is not quantized self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts @@ -778,6 +804,7 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2, self.moe_experts_weight3, self.top_k, + self.onnx_dtype, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -874,40 +901,41 @@ class PhiMoESparseMoeBlock(SparseMoeBlockORTHelper): and memory on padding. """ - def __init__(self, config, batch_size, sequence_length): - super().__init__() + def __init__(self, config, batch_size, sequence_length, quant_bits=0, onnx_dtype=None): + super().__init__(quant_bits, onnx_dtype) self.hidden_dim = config.hidden_size self.ffn_dim = config.intermediate_size self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise + use_quant = self.quant_bits > 0 # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) self.experts = nn.ModuleList([PhiMoEBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) - w1_list = [] - w2_list = [] - w3_list = [] - w1_scale_list = [] - w2_scale_list = [] - w3_scale_list = [] - if not USE_QUANT: + w1_list, w2_list, w3_list = [], [], [] + w1_scale_list, w2_scale_list, w3_scale_list = [], [], [] + + if not use_quant: for i in range(self.num_experts): w1_list.append(self.experts[i].w1.weight) w2_list.append(self.experts[i].w2.weight) w3_list.append(self.experts[i].w3.weight) else: + is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, False) - w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, False) - w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, False) + # Corrected quantization logic for per-output-channel quantization + w1_scale, pre_qweight1, w1_qdq = quant_dequant(self.experts[i].w1.weight, is_4_bit) + w2_scale, pre_qweight2, w2_qdq = quant_dequant(self.experts[i].w2.weight, is_4_bit) + w3_scale, pre_qweight3, w3_qdq = quant_dequant(self.experts[i].w3.weight, is_4_bit) self.experts[i].w1.weight.data = w1_qdq self.experts[i].w2.weight.data = w2_qdq self.experts[i].w3.weight.data = w3_qdq + # Transpose quantized weights to match the expected ONNX layout w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) w3_list.append(pre_qweight3) @@ -919,9 +947,9 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight2 = torch.stack(w2_list, dim=0) self.moe_experts_weight3 = torch.stack(w3_list, dim=0) - moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if USE_QUANT else None - moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if USE_QUANT else None + moe_experts_weight_scale1 = torch.stack(w1_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(w2_scale_list, dim=0) if use_quant else None + moe_experts_weight_scale3 = torch.stack(w3_scale_list, dim=0) if use_quant else None self.batch_size = batch_size self.sequence_length = sequence_length @@ -933,10 +961,12 @@ def __init__(self, config, batch_size, sequence_length): self.moe_experts_weight1, self.moe_experts_weight2, self.moe_experts_weight3, + self.top_k, + self.onnx_dtype, + self.quant_bits, moe_experts_weight_scale1, moe_experts_weight_scale2, moe_experts_weight_scale3, - self.top_k, ) self.ort_sess = self.create_ort_session(self.moe_onnx_graph) @@ -995,18 +1025,10 @@ def small_test_cases(): yield batch_size, sequence_length -def phi3_test_cases(): - # TODO: phi3 moe failed in long sequence lengths (max diff 0.22 > threshold 0.01), need investigation. - for batch_size in [1, 4, 16]: - for sequence_length in [128]: - yield batch_size, sequence_length - - +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestSwitchMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_switch_moe_parity(self, batch_size, sequence_length): - # if platform.system() == "Windows": - # pytest.skip("Skip on Windows") switch_moe = SwitchMoE( batch_size=batch_size, sequence_length=sequence_length, @@ -1015,26 +1037,411 @@ def test_switch_moe_parity(self, batch_size, sequence_length): hidden_features=1024, out_features=256, ) + switch_moe.to(device) switch_moe.parity_check() - # switch_moe.benchmark_ort() +# quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) +# since qMoE test requires tensorrt_llm for quant_dequant. We disable it in CI pipeline to avoid extra dependency. +quant_bits_list = [0] if pipeline_mode else [0, 8, 4] + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestMixtralMoE(unittest.TestCase): @parameterized.expand(small_test_cases()) def test_mixtral_moe_parity(self, batch_size, sequence_length): config = MixtralConfig(hidden_size=256, intermediate_size=1024) mixtral_moe = MixtralSparseMoeBlock(config, batch_size, sequence_length) + mixtral_moe.to(device) mixtral_moe.parity_check() - # mixtral_moe.benchmark_ort() +phi3_test_cases = list( + itertools.product( + [1, 4], # batch_size + [1, 32], # sequence_length + quant_bits_list, + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") class TestPhiMoE(unittest.TestCase): - @parameterized.expand(phi3_test_cases()) - def test_phi3_moe_parity(self, batch_size, sequence_length): + @parameterized.expand(phi3_test_cases) + def test_phi3_moe_parity(self, batch_size, sequence_length, quant_bits): config = PhiMoEConfig(hidden_size=256, intermediate_size=1024) - phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length) + phi3_moe = PhiMoESparseMoeBlock(config, batch_size, sequence_length, quant_bits) + phi3_moe.to(device) phi3_moe.parity_check() - # phi3_moe.benchmark_ort() + + +# --------------------------------------------- +# The following test are for swiglu activation +# --------------------------------------------- +class SwigluMoeConfig: + def __init__( + self, + hidden_size=2048, + intermediate_size=2048, + num_experts_per_token=2, + num_local_experts=8, + ): + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_experts_per_token = num_experts_per_token + self.num_local_experts = num_local_experts + + +def swiglu(x: torch.Tensor): + dim = x.shape[-1] + x = x.view(-1, dim // 2, 2) + x_glu, x_linear = x[..., 0], x[..., 1] + y = x_glu * torch.sigmoid(1.702 * x_glu) * (x_linear + 1) + return y + + +class SwigluMlp(nn.Module): + def __init__(self, config): + super().__init__() + self.intermediate_size = config.intermediate_size + self.hidden_dim = config.hidden_size + self.w1 = nn.Linear(self.hidden_dim, 2 * self.intermediate_size, bias=True) + self.w2 = nn.Linear(self.intermediate_size, self.hidden_dim, bias=True) + + def forward(self, x): + x1 = self.w1(x) + y = swiglu(x1) + y = self.w2(y) + return y + + +# Note that the shape might not match the tensor shape. See Attention note in this file. +def make_onnx_intializer(name: str, tensor: torch.Tensor, shape, onnx_dtype): + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + if torch_dtype == torch.bfloat16: + numpy_vals_uint16 = tensor.to(torch.bfloat16).cpu().view(torch.uint16).numpy() + initializer = helper.make_tensor( + name=name, + data_type=TensorProto.BFLOAT16, + dims=shape, + vals=numpy_vals_uint16.tobytes(), + raw=True, + ) + else: + initializer = helper.make_tensor( + name=name, + data_type=onnx_dtype, + dims=shape, + vals=tensor.flatten().detach().cpu().numpy().astype(numpy.uint8).tolist() + if onnx_dtype == TensorProto.UINT8 + else tensor.detach().to(torch_dtype).flatten().tolist(), + raw=False, + ) + return initializer + + +def create_swiglu_moe_onnx_graph( + num_tokens: int, + num_experts: int, + hidden_size: int, + inter_size: int, + topk: int, + onnx_dtype: int, + quant_bits: int, + fc1_experts_weights: torch.Tensor, + fc1_experts_bias: torch.Tensor, + fc2_experts_weights: torch.Tensor, + fc2_experts_bias: torch.Tensor, + fc1_experts_weight_scale: torch.Tensor = None, + fc2_experts_weight_scale: torch.Tensor = None, +): + use_quant = quant_bits > 0 + op_name = "QMoE" if use_quant else "MoE" + + inputs = ( + [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_weight_scale", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_weight_scale", + "fc2_experts_bias", + ] + if use_quant + else [ + "input", + "router_probs", + "fc1_experts_weights", + "fc1_experts_bias", + "fc2_experts_weights", + "fc2_experts_bias", + ] + ) + + nodes = [ + helper.make_node( + op_name, + inputs, + ["output"], + "MoE_0", + k=topk, + normalize_routing_weights=1, + activation_type="swiglu", + domain="com.microsoft", + ), + ] + + if use_quant: + nodes[0].attribute.extend([helper.make_attribute("expert_weight_bits", quant_bits)]) + + components = 2 if quant_bits == 4 else 1 + # ATTENTION: Actual weight layout is like [num_experts, 2 * inter_size, hidden_size // components] + # Here we claim a different shape for the initializer to match the operator spec for weight tensor! + fc1_weight_shape = [num_experts, hidden_size, 2 * inter_size // components] + fc1_bias_shape = [num_experts, 2 * inter_size] + fc1_experts_weight_scale_shape = [num_experts, 2 * inter_size] + + fc2_weight_shape = [num_experts, inter_size, hidden_size // components] + fc2_bias_shape = [num_experts, hidden_size] + fc2_experts_weight_scale_shape = [num_experts, hidden_size] + + weight_onnx_type = TensorProto.UINT8 if use_quant else onnx_dtype + + torch_dtype = onnx_to_torch_type_map[onnx_dtype] + weight_torch_dtype = onnx_to_torch_type_map[weight_onnx_type] + + initializers = [ + make_onnx_intializer( + "fc1_experts_weights", fc1_experts_weights.to(weight_torch_dtype), fc1_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc1_experts_bias", fc1_experts_bias.to(torch_dtype), fc1_bias_shape, onnx_dtype), + make_onnx_intializer( + "fc2_experts_weights", fc2_experts_weights.to(weight_torch_dtype), fc2_weight_shape, weight_onnx_type + ), + make_onnx_intializer("fc2_experts_bias", fc2_experts_bias.to(torch_dtype), fc2_bias_shape, onnx_dtype), + ] + + if use_quant: + initializers.extend( + [ + make_onnx_intializer( + "fc1_experts_weight_scale", + fc1_experts_weight_scale.to(torch_dtype), + fc1_experts_weight_scale_shape, + onnx_dtype, + ), + make_onnx_intializer( + "fc2_experts_weight_scale", + fc2_experts_weight_scale.to(torch_dtype), + fc2_experts_weight_scale_shape, + onnx_dtype, + ), + ] + ) + + graph_inputs = [ + helper.make_tensor_value_info("input", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph_inputs.append( + helper.make_tensor_value_info( + "router_probs", + onnx_dtype, + [num_tokens, num_experts], + ) + ) + + graph_outputs = [ + helper.make_tensor_value_info("output", onnx_dtype, [num_tokens, hidden_size]), + ] + + graph = helper.make_graph( + nodes, + "MoE_Graph", + graph_inputs, + graph_outputs, + initializers, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +class SwigluMoEBlock(SparseMoeBlockORTHelper): + def __init__( + self, config: SwigluMoeConfig, batch_size: int, sequence_length: int, quant_bits: int = 0, onnx_dtype=None + ): + super().__init__(quant_bits, onnx_dtype=onnx_dtype) + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_token + use_quant = self.quant_bits > 0 + + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=True) + + self.experts = nn.ModuleList([SwigluMlp(config) for _ in range(self.num_experts)]) + + # For the ONNX MoE operator, weights must be transposed to [In, Out] format. + # Biases do not require transposition. + fc1_w_list, fc2_w_list = [], [] + fc1_b_list, fc2_b_list = [], [] + scale_1_list, scale_2_list = [], [] + + for expert in self.experts: + fc1_b_list.append(expert.w1.bias) + fc2_b_list.append(expert.w2.bias) + if not use_quant: + # ATTENTION: Weight tensor for CUDA shall have [E, out, in] memory layout just like Linear. + # But the initializer shape shall be [E, in, out] to match op spec. + fc1_w_list.append(expert.w1.weight) + fc2_w_list.append(expert.w2.weight) + else: + is_4_bit = self.quant_bits == 4 + + # quant_dequant expects [Out, In] format, matching nn.Linear.weight + scale1, pre_qweight1, w1_qdq = quant_dequant(expert.w1.weight, is_4_bit) + scale2, pre_qweight2, w2_qdq = quant_dequant(expert.w2.weight, is_4_bit) + + # Update the expert's weight with the dequantized version for the PyTorch reference. + expert.w1.weight.data = w1_qdq + expert.w2.weight.data = w2_qdq + + fc1_w_list.append(pre_qweight1) + fc2_w_list.append(pre_qweight2) + scale_1_list.append(scale1) + scale_2_list.append(scale2) + + # Stack the prepared tensors for the graph builder + fc1_experts_weights = torch.stack(fc1_w_list, dim=0) + fc2_experts_weights = torch.stack(fc2_w_list, dim=0) + fc1_experts_bias = torch.stack(fc1_b_list, dim=0) + fc2_experts_bias = torch.stack(fc2_b_list, dim=0) + + moe_experts_weight_scale1 = torch.stack(scale_1_list, dim=0) if use_quant else None + moe_experts_weight_scale2 = torch.stack(scale_2_list, dim=0) if use_quant else None + + self.batch_size = batch_size + self.sequence_length = sequence_length + + # Build the ONNX graph with the correctly shaped tensors + self.moe_onnx_graph = create_swiglu_moe_onnx_graph( + num_tokens=self.batch_size * self.sequence_length, + num_experts=self.num_experts, + hidden_size=self.hidden_dim, + inter_size=self.ffn_dim, + topk=self.top_k, + onnx_dtype=self.onnx_dtype, + quant_bits=self.quant_bits, + fc1_experts_weights=fc1_experts_weights, + fc1_experts_bias=fc1_experts_bias, + fc2_experts_weights=fc2_experts_weights, + fc2_experts_bias=fc2_experts_bias, + fc1_experts_weight_scale=moe_experts_weight_scale1, + fc2_experts_weight_scale=moe_experts_weight_scale2, + ) + + self.ort_sess = self.create_ort_session(self.moe_onnx_graph) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + This is the robust PyTorch reference implementation. It directly uses the + nn.Module experts, which is cleaner and less error-prone than manual matmul. + """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + router_logits = self.gate(hidden_states) + routing_weights, selected_experts = torch.topk(router_logits, self.top_k, dim=-1) + routing_weights = F.softmax(routing_weights, dim=1, dtype=torch.float) + + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None] + + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states + + +swiglu_test_cases = list( + itertools.product( + [1, 2], # batch_size + [1, 3], # sequence_length + quant_bits_list, # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(not use_cuda, "skipping moe test since it requires cuda environment.") +class TestSwigluMoE(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=256, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.parity_check() + + +def has_bf16_moe(): + if "CUDAExecutionProvider" not in onnxruntime.get_available_providers() or not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 8 + + +@unittest.skipIf(not has_bf16_moe(), "skipping bf16 moe tests.") +class TestSwigluMoeBf16(unittest.TestCase): + @parameterized.expand(swiglu_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + config = SwigluMoeConfig(hidden_size=64, intermediate_size=128, num_experts_per_token=2, num_local_experts=4) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits, onnx_dtype=TensorProto.BFLOAT16) + moe.to(device) + moe.parity_check() + + +perf_test_cases = list( + itertools.product( + [1], # batch_size + [128, 512, 1024, 2048, 4096], # sequence_length + [0, 8, 4], # quant_bits (0 for fp32/fp32, 8 for int8/fp16, 4 for int4/fp16) + ) +) + + +@unittest.skipIf(pipeline_mode or not use_cuda, "skipping performance test in CI pipeline.") +class TestSwigluMoEPerf(unittest.TestCase): + @parameterized.expand(perf_test_cases) + def test_swiglu_moe_parity(self, batch_size, sequence_length, quant_bits): + hidden_size = 2880 + intermediate_size = 2880 + num_experts_per_token = 8 + num_local_experts = 128 + config = SwigluMoeConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_experts_per_token=num_experts_per_token, + num_local_experts=num_local_experts, + ) + moe = SwigluMoEBlock(config, batch_size, sequence_length, quant_bits) + moe.to(device) + moe.benchmark_ort() if __name__ == "__main__": diff --git a/onnxruntime/test/testdata/conv_default_attrs.onnx b/onnxruntime/test/testdata/conv_default_attrs.onnx new file mode 100644 index 0000000000000..fc7ee58dee15e Binary files /dev/null and b/onnxruntime/test/testdata/conv_default_attrs.onnx differ diff --git a/onnxruntime/test/testdata/make_conv_default_attrs.py b/onnxruntime/test/testdata/make_conv_default_attrs.py new file mode 100644 index 0000000000000..fc092bf8b25fb --- /dev/null +++ b/onnxruntime/test/testdata/make_conv_default_attrs.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import numpy as np +import onnx + + +def main(): + inp_shape = (1, 2, 8, 8) + input_0 = onnx.helper.make_tensor_value_info("input_0", onnx.TensorProto.FLOAT, inp_shape) + output_0 = onnx.helper.make_tensor_value_info("output_0", onnx.TensorProto.FLOAT, None) + + weight_data = [ + [[[-1.5, 0.0], [0.2, 1.5]], [[-1.5, 0.0], [0.2, 1.5]]], + [[[-1.0, 0.0], [0.1333, 1.0]], [[-1.0, 0.0], [0.1333, 1.0]]], + ] + weight = onnx.numpy_helper.from_array(np.array(weight_data, dtype=np.float32), "weight") + bias = onnx.numpy_helper.from_array(np.array([0.0, 0.0], dtype=np.float32), "bias") + conv_node = onnx.helper.make_node("Conv", ["input_0", "weight", "bias"], ["output_0"], name="Conv0") + graph = onnx.helper.make_graph( + [conv_node], + "Convf32", + [input_0], + [output_0], + initializer=[weight, bias], + ) + opset_imports = [onnx.helper.make_opsetid("", 21)] + model = onnx.helper.make_model(graph, opset_imports=opset_imports) + model = onnx.shape_inference.infer_shapes(model) + + onnx.checker.check_model(model, True) + onnx.save_model(model, "conv_default_attrs.onnx") + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/wasm/pre-async.js b/onnxruntime/wasm/pre-async.js index 8c75dc7c5cf1e..1f8f17535e7d4 100644 --- a/onnxruntime/wasm/pre-async.js +++ b/onnxruntime/wasm/pre-async.js @@ -15,78 +15,20 @@ let initAsyncImpl = () => { // It removes some overhead in cwarp() and ccall() that we don't need. // // Currently in ASYNCIFY build, we only use this for the following functions: + // - OrtAppendExecutionProvider() // - OrtCreateSession() // - OrtRun() // - OrtRunWithBinding() // - OrtBindInput() // - // Note: about parameters "getFunc" and "setFunc": - // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // We need to wrap these functions with an async wrapper so that they can be called in an async context. // - // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a - // wrapper for OrtRun() like this (minified): - // ``` - // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); - // ``` - // - // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates - // a wrapper for OrtRun() like this (minified): - // ``` - // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // - // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once - // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will - // reset d._OrtRun to J.ka when the first time it is called. - // - // The difference is important because we need to design the async wrapper in a way that it can handle both cases. - // - // Now, let's look at how the async wrapper is designed to work for both cases: - // - // - Debug build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // - Release build: - // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. - // 2. When the first time `Module["initAsync"]` is called, `Module["_OrtRun"]` is re-assigned to a new async - // wrapper function. - // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this - // function: - // ``` - // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); - // ``` - // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). - // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored - // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. - // Value of `Module["_OrtRun"]` will not be changed again. - // - // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release - // build. - // - // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an - // exported function and set the new value of an exported function. - // - const wrapAsync = (func, getFunc, setFunc) => { + const wrapAsync = (func) => { return (...args) => { // cache the async data before calling the function. const previousAsync = Asyncify.currData; - const previousFunc = getFunc?.(); const ret = func(...args); - const newFunc = getFunc?.(); - if (previousFunc !== newFunc) { - // The exported function has been updated. - // Set the sync function reference to the new function. - func = newFunc; - // Set the exported function back to the async wrapper. - setFunc(previousFunc); - // Remove getFunc and setFunc. They are no longer needed. - setFunc = null; - getFunc = null; - } // If the async data has been changed, it means that the function started an async operation. if (Asyncify.currData != previousAsync) { @@ -101,11 +43,7 @@ let initAsyncImpl = () => { // replace the original functions with asyncified versions const wrapAsyncAPIs = (funcNames) => { for (const funcName of funcNames) { - Module[funcName] = wrapAsync( - Module[funcName], - () => Module[funcName], - (v) => (Module[funcName] = v) - ); + Module[funcName] = wrapAsync(Module[funcName]); } }; diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 561a76be5fa89..56fd3f1323e92 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -620,6 +620,7 @@ def generate_build_tree( ) generate_vcpkg_triplets_for_emscripten( build_dir, + configs, emscripten_root_path, not args.disable_rtti, not args.disable_wasm_exception_catching, @@ -627,25 +628,21 @@ def generate_build_tree( args.enable_address_sanitizer, ) elif args.android: - generate_android_triplets(build_dir, args.android_cpp_shared, args.android_api) + generate_android_triplets(build_dir, configs, args.android_cpp_shared, args.android_api) elif is_windows(): - generate_windows_triplets(build_dir, args.msvc_toolset) + generate_windows_triplets(build_dir, configs, args.msvc_toolset) elif is_macOS(): osx_target = args.apple_deploy_target if args.apple_deploy_target is None: osx_target = os.environ.get("MACOSX_DEPLOYMENT_TARGET") if osx_target is not None: log.info(f"Setting VCPKG_OSX_DEPLOYMENT_TARGET to {osx_target}") - generate_macos_triplets(build_dir, osx_target) + generate_macos_triplets(build_dir, configs, osx_target) else: # Linux, *BSD, AIX or other platforms - generate_linux_triplets(build_dir) + generate_linux_triplets(build_dir, configs) add_default_definition(cmake_extra_defines, "CMAKE_TOOLCHAIN_FILE", str(vcpkg_toolchain_path)) - vcpkg_install_options = generate_vcpkg_install_options(build_dir, args) - # VCPKG_INSTALL_OPTIONS is a CMake list. It must be joined by semicolons - # Therefore, if any of the option string contains a semicolon, it must be escaped - add_default_definition(cmake_extra_defines, "VCPKG_INSTALL_OPTIONS", ";".join(vcpkg_install_options)) # Choose the cmake triplet triplet = None if args.build_wasm: @@ -1251,6 +1248,16 @@ def generate_build_tree( ] env = {} if args.use_vcpkg: + # append VCPKG_INSTALL_OPTIONS + # + # VCPKG_INSTALL_OPTIONS is a CMake list. It must be joined by semicolons + # Therefore, if any of the option string contains a semicolon, it must be escaped + temp_cmake_args += [ + "-DVCPKG_INSTALL_OPTIONS={}".format( + ";".join(generate_vcpkg_install_options(Path(build_dir) / config, args)) + ) + ] + vcpkg_keep_env_vars = ["TRT_UPLOAD_AUTH_TOKEN"] if args.build_wasm: diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index c42f8e3219da4..82118148d35f9 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -342,7 +342,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.8", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.11", help="Specify version of emsdk.") parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index e656cedae5916..90947e534918d 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -71,11 +71,14 @@ def main(): log.info(f"Image: {full_image_name}") - dst_deps_file = Path(args.context) / "scripts" / "deps.txt" + dst_scripts_dir = Path(args.context) / "scripts" + dst_deps_file = dst_scripts_dir / "deps.txt" # The docker file may provide a special deps.txt in its docker context dir and uses that one. # Otherwise, copy a generic one from this repo's cmake dir. if not dst_deps_file.exists(): log.info(f"Copy deps.txt to : {dst_deps_file}") + if not dst_scripts_dir.exists(): + dst_scripts_dir.mkdir(parents=True, exist_ok=True) shutil.copyfile(Path(REPO_DIR) / "cmake" / "deps.txt", str(dst_deps_file)) if "manylinux" in args.dockerfile and args.multiple_repos: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 3772b5e9c4c20..a5eb2ad216998 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -121,109 +121,6 @@ stages: parameters: StageSuffix: 'macOS_CPU_x64' -- template: templates/final-jar-testing.yml - parameters: - OS: Windows - PoolName: 'onnxruntime-Win-CPU-2022' - -- template: templates/final-jar-testing.yml - parameters: - OS: Linux - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - -- template: templates/final-jar-testing.yml - parameters: - OS: MacOS - PoolName: 'macOS-14' - - -- stage: GPU_JAR_Testing - dependsOn: Setup - jobs: - - job: Final_Jar_Testing_Windows_GPU - workspace: - clean: all - pool: 'onnxruntime-Win2022-GPU-A10' - timeoutInMinutes: 60 - variables: - - name: runCodesignValidationInjection - value: false - - steps: - - template: templates/set-version-number-variables-step.yml - - - template: templates/jobs/download_win_gpu_library.yml - parameters: - CudaVersion: 12.2 - DownloadCUDA: true - DownloadTRT: true - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)\final-jar' - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Jar Tools' - ArtifactName: onnxruntime-java-tools - TargetPath: '$(Build.BinariesDirectory)\final-jar' - - - task: CmdLine@2 - inputs: - script: | - mkdir test - pushd test - jar xf $(Build.BinariesDirectory)\final-jar\testing.jar - popd - java -DUSE_CUDA=1 -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime_gpu-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)\final-jar' - - - job: Final_Jar_Testing_Linux_GPU - workspace: - clean: all - pool: - name: 'Onnxruntime-Linux-GPU-A10' - os: linux - variables: - - name: runCodesignValidationInjection - value: false - - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 - timeoutInMinutes: 60 - steps: - - checkout: self - submodules: false - - template: templates/set-version-number-variables-step.yml - - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java-gpu - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - template: templates/get-docker-image-steps.yml - parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 - Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: " - --build-arg BUILD_UID=$( id -u ) - --build-arg BASEIMAGE=${{ variables.docker_base_image }} - --build-arg TRT_VERSION=${{ variables.linux_trt_version }} - " - Repository: onnxruntimeubi8packagestest - - - bash: | - docker run -e SYSTEM_COLLECTIONURI --rm \ - --gpus all \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Build.BinariesDirectory):/build \ - --volume /data/models:/build/models:ro \ - onnxruntimeubi8packagestest \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) - displayName: 'Test' - - template: nuget/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win2022-GPU-A10' @@ -274,13 +171,14 @@ stages: clean: true submodules: none - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - ArtifactName: "Windows_Packaging_cuda_build_artifacts" - StepName: 'Download Pipeline Artifact - Windows GPU Packages Build' - TargetPath: '$(Build.BinariesDirectory)/RelWithDebInfo/' + - download: build + artifact: 'Windows_Packaging_cuda_build_artifacts' + displayName: 'Download Windows GPU Packages Build' - - template: templates/telemetry-steps.yml + - task: CmdLine@2 + inputs: + script: | + move $(Pipeline.Workspace)/build/Windows_Packaging_cuda_build_artifacts $(Build.BinariesDirectory)/RelWithDebInfo - template: templates/set-version-number-variables-step.yml @@ -323,7 +221,7 @@ stages: displayName: 'test' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' workingDirectory: '$(Build.BinariesDirectory)' # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - template: templates/make_java_win_binaries.yml @@ -345,13 +243,15 @@ stages: clean: true submodules: none - - template: templates/flex-downloadPipelineArtifact.yml - parameters: - ArtifactName: "Windows_Packaging_tensorrt_build_artifacts" - StepName: 'Download Pipeline Artifact - Windows GPU Packages Build' - TargetPath: '$(Build.BinariesDirectory)/RelWithDebInfo/' - - template: templates/telemetry-steps.yml + - download: build + artifact: 'Windows_Packaging_tensorrt_build_artifacts' + displayName: 'Download Windows GPU Packages Build' + + - task: CmdLine@2 + inputs: + script: | + move $(Pipeline.Workspace)/build/Windows_Packaging_tensorrt_build_artifacts $(Build.BinariesDirectory)/RelWithDebInfo - template: templates/set-version-number-variables-step.yml @@ -394,7 +294,7 @@ stages: displayName: 'test' inputs: scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests $(TelemetryOption) ' + arguments: '--config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --test --enable_onnx_tests' workingDirectory: '$(Build.BinariesDirectory)' # Previous stage only assembles the java binaries, testing will be done in this stage with GPU machine - template: templates/make_java_win_binaries.yml diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index 0e807b7beb6e9..52268af15e776 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -157,7 +157,7 @@ extends: targetType: 'inline' script: | mkdir -p $(Build.BinariesDirectory)/osx-x64 - Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-x86_64* -Destination $(Build.BinariesDirectory)/osx-x64 + Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-x86_64* -Destination $(Build.BinariesDirectory)/osx-x64 mkdir -p $(Build.BinariesDirectory)/osx-arm64 Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-arm64* -Destination $(Build.BinariesDirectory)/osx-arm64 diff --git a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml new file mode 100644 index 0000000000000..24e17de06e6bd --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml @@ -0,0 +1,162 @@ +resources: + pipelines: + - pipeline: build + source: 'Zip-Nuget-Java-Nodejs Packaging Pipeline' + trigger: true + branch: main + +variables: + mavenVersion: '3.9.8' + +stages: +- template: templates/final-jar-testing-win.yml + parameters: + PoolName: 'onnxruntime-Win-CPU-2022' + +- template: templates/final-jar-testing-linux.yml + parameters: + OS: Linux + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' + +- template: templates/final-jar-testing-linux.yml + parameters: + OS: MacOS + PoolName: 'macOS-14' + +- stage: GPU_JAR_Testing + dependsOn: [] + jobs: + - job: Final_Jar_Testing_Windows_GPU + workspace: + clean: all + pool: 'onnxruntime-Win2022-GPU-A10' + timeoutInMinutes: 60 + variables: + - name: runCodesignValidationInjection + value: false + + steps: + - template: templates/set-version-number-variables-step.yml + + - template: templates/jobs/download_win_gpu_library.yml + parameters: + CudaVersion: 12.2 + DownloadCUDA: true + DownloadTRT: true + + - template: templates/download_maven_for_tests.yml + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + - script: | + move $(Pipeline.Workspace)\build\onnxruntime-java-gpu\*.jar $(Pipeline.Workspace)\build\onnxruntime-java\ + + - task: PowerShell@2 + displayName: 'Run Java Tests with PowerShell' + inputs: + targetType: 'inline' + script: | + # Exit script on any error + $ErrorActionPreference = "Stop" + + cd $(Pipeline.Workspace)/build/onnxruntime-java + del *.asc + del *.sha256 + del *.sha512 + del *.pom + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + dir $(Pipeline.Workspace)/build/tests + Write-Host "Running JUnit Tests..." + & java -DUSE_CUDA=1 ` + -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` + --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true + + + - job: Final_Jar_Testing_Linux_GPU + workspace: + clean: all + pool: + name: 'Onnxruntime-Linux-GPU-A10' + variables: + - name: runCodesignValidationInjection + value: false + - name: docker_base_image + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250724.1 + timeoutInMinutes: 60 + steps: + - checkout: self + submodules: false + + - template: templates/set-version-number-variables-step.yml + + - bash: | + sudo apt-get install -y msopenjdk-17 + dpkg -l msopenjdk-17 + + - bash: | + echo "Downloading and installing Maven $(mavenVersion) for Linux..." + MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + + # Extract to the temp directory + mkdir -p ${MAVEN_DIR} + tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]${MAVEN_DIR}/bin" + displayName: 'Install Maven (Linux)' + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + + # Rename the downloaded folder + - script: | + mv $(Pipeline.Workspace)/build/onnxruntime-java-gpu $(Pipeline.Workspace)/build/onnxruntime-java + + - task: Maven@4 + displayName: 'Download Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'Path' + jdkDirectory: '/usr/lib/jvm/msopenjdk-17-amd64' + jdkVersionOption: 'Default' + mavenVersionOption: 'Default' + + # Now all the jars are in the $(Pipeline.Workspace)/build folder + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ variables.docker_base_image }} --build-arg TRT_VERSION=${{ variables.linux_trt_version }}" + Repository: onnxruntimeubi8packagestest + + - bash: | + docker run --network=none --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Pipeline.Workspace)/build:/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) + displayName: 'Test' diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index b304ccdb4c533..0410001d77d13 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -65,6 +65,12 @@ jobs: clean: true submodules: none + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + architecture: 'x64' + - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml index 6c998f9c3da13..ae595bbf0c96b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml @@ -4,8 +4,13 @@ steps: displayName: 'Download NPM_packages' artifact: 'NPM_packages' -- script: | - mv $(Pipeline.Workspace)/build/NPM_packages '$(Build.BinariesDirectory)/nodejs-artifact' + +- task: PowerShell@2 + displayName: 'Move Artifact Directory' + inputs: + targetType: 'inline' + script: | + Move-Item -Path "$(Pipeline.Workspace)/build/NPM_packages" -Destination "$(Build.BinariesDirectory)/nodejs-artifact" - script: mkdir e2e_test workingDirectory: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index bb4f600395ac9..4dd19ce2c250c 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -11,9 +11,7 @@ stages: clean: all timeoutInMinutes: 120 pool: - name: 'Azure Pipelines' - image: 'macOS-15' - os: 'macOS' + vmImage: 'macOS-15' variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml index 3615f9f7c0960..7fdddfa0d03f5 100644 --- a/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/npm-packaging-pipeline.yml @@ -50,6 +50,11 @@ extends: sourceAnalysisPool: name: onnxruntime-Win-CPU-2022 os: windows + codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL causes the React Native Android tests to fail when trying to load Linux x64 .so' + stages: - template: templates/web-ci.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index c1b83c5e579dc..f4a62208059c8 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -316,7 +316,21 @@ stages: MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true + PYTHON_VERSION: '3.11' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.12' + + - template: ../templates/py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: ${{ parameters.qnn_sdk_version }} + BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} + PYTHON_VERSION: '3.13' - ${{ if eq(parameters.enable_windows_arm64ec_qnn, true) }}: - stage: Python_Packaging_Windows_arm64ec_QNN @@ -327,7 +341,6 @@ stages: MACHINE_POOL: 'Onnxruntime-QNNEP-Windows-2022-CPU' QNN_SDK: ${{ parameters.qnn_sdk_version }} BUILD_PY_PARAMETERS: ${{ parameters.build_py_parameters }} - is1ES: true - ${{ if eq(parameters.enable_windows_x64_qnn, true) }}: - stage: Python_Packaging_Windows_x64_QNN diff --git a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml index 6e6fb98e6e68c..7370910eb1e28 100644 --- a/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/android-java-api-aar-test.yml @@ -58,7 +58,7 @@ jobs: mkdir -p android_test/android/app/libs cd android_test/android cp -av $(Build.SourcesDirectory)/java/src/test/android/* ./ - cp $(Pipeline.Workspace)/build/onnxruntime-android-full-aar/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar + cp $(Pipeline.Workspace)/build/${{parameters.artifactName}}/${{parameters.packageName}}-$(OnnxRuntimeVersion)${{parameters.ReleaseVersionSuffix}}.aar app/libs/${{parameters.packageName}}.aar displayName: Copy Android test files and AAR to android_test directory workingDirectory: $(Build.BinariesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml new file mode 100644 index 0000000000000..e53544458d494 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/download_maven_for_tests.yml @@ -0,0 +1,28 @@ +steps: +- pwsh: | + echo "Downloading and installing Maven $(mavenVersion) for Windows..." + $MAVEN_DIR = "$(Agent.TempDirectory)\apache-maven-$(mavenVersion)" + # Download Maven binary + Invoke-WebRequest -Uri "https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.zip" -OutFile "$(Agent.TempDirectory)\maven.zip" + + # Extract to the temp directory + Expand-Archive -Path "$(Agent.TempDirectory)\maven.zip" -DestinationPath "$(Agent.TempDirectory)" + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]$MAVEN_DIR\bin" + + +- script: | + echo "Maven is now on the PATH." + mvn --version + +- task: Maven@4 + displayName: 'Download Java Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' \ No newline at end of file diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml new file mode 100644 index 0000000000000..9d1cb58aee8bb --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml @@ -0,0 +1,107 @@ +# Run Java tests on CPU machines with the CPU java package + +parameters: +- name: OS + displayName: Operating System + type: string + +- name: PoolName + type: string + +stages: +- stage: Final_Jar_Testing_${{parameters.OS}} + dependsOn: [] + jobs: + - job: Final_Jar_Testing_${{parameters.OS}} + workspace: + clean: all + ${{ if eq(parameters.OS, 'MacOS') }}: + pool: + vmImage: 'macOS-15' + ${{ if eq(parameters.OS, 'Linux') }}: + pool: + name: ${{ parameters.PoolName }} + variables: + - name: runCodesignValidationInjection + value: false + timeoutInMinutes: 60 + steps: + - template: set-version-number-variables-step.yml + + - bash: | + echo "Downloading and installing Maven $(mavenVersion) for Linux..." + MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + + # Extract to the temp directory + mkdir -p ${MAVEN_DIR} + tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]${MAVEN_DIR}/bin" + displayName: 'Install Maven (Linux)' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux')) + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java' + displayName: 'Download Final Jar' + + - task: Maven@4 + displayName: 'Download Dependencies into App Folder' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + + - task: Bash@3 + displayName: 'Run Java Tests on Linux/macOS' + condition: and(succeeded(), in(variables['Agent.OS'], 'Linux', 'Darwin')) + inputs: + targetType: 'inline' + script: | + set -e -x + cd $(Pipeline.Workspace)/build/onnxruntime-java + rm -f *.asc + rm -f *.sha256 + rm -f *.sha512 + rm -f *.pom + ls + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + rm -f $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + ls $(Pipeline.Workspace)/build/tests + echo "Java Version" + java -version + + # Set the correct library path based on the OS + os_name=$(uname) + if [[ "$os_name" == "Linux" ]]; then + echo "Platform: Linux. Setting LD_LIBRARY_PATH." + export LD_LIBRARY_PATH="$(pwd):$LD_LIBRARY_PATH" + java -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + elif [[ "$os_name" == "Darwin" ]]; then + echo "Platform: macOS. Setting DYLD_LIBRARY_PATH." + export DYLD_LIBRARY_PATH="$(pwd):$DYLD_LIBRARY_PATH" + java -DUSE_WEBGPU=1 -DUSE_COREML=1 -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + fi + + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml new file mode 100644 index 0000000000000..af7fa176d2ac0 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml @@ -0,0 +1,82 @@ +parameters: +- name: PoolName + type: string + +stages: +- stage: Final_Jar_Testing_Windows + dependsOn: [] + jobs: + - job: Final_Jar_Testing_Windows + workspace: + clean: all + pool: + name: ${{ parameters.PoolName }} + variables: + - name: runCodesignValidationInjection + value: false + timeoutInMinutes: 60 + steps: + - template: set-version-number-variables-step.yml + + - pwsh: | + echo "Downloading and installing Maven $(mavenVersion) for Windows..." + $MAVEN_DIR = "$(Agent.TempDirectory)\apache-maven-$(mavenVersion)" + # Download Maven binary + Invoke-WebRequest -Uri "https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.zip" -OutFile "$(Agent.TempDirectory)\maven.zip" + + # Extract to the temp directory + Expand-Archive -Path "$(Agent.TempDirectory)\maven.zip" -DestinationPath "$(Agent.TempDirectory)" + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]$MAVEN_DIR\bin" + displayName: 'Install Maven (Windows)' + + + - script: | + echo "Maven is now on the PATH." + mvn --version + + - download: build + artifact: 'onnxruntime-java' + displayName: 'Download Final Jar' + + - task: Maven@4 + displayName: 'Download Dependencies into App Folder' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + + - task: PowerShell@2 + displayName: 'Run Java Tests on Windows' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) + inputs: + targetType: 'inline' + script: | + $ErrorActionPreference = "Stop" + cd $(Pipeline.Workspace)/build/onnxruntime-java + del *.asc + del *.sha256 + del *.sha512 + del *.pom + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + dir $(Pipeline.Workspace)/build/tests + Write-Host "Running JUnit Tests..." + & java ` + -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` + --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml deleted file mode 100644 index bc40de130740a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml +++ /dev/null @@ -1,80 +0,0 @@ -parameters: -- name: OS - displayName: Opserating System - type: string - -- name: PoolName - type: string - -stages: -- stage: Final_Jar_Testing_${{parameters.OS}} - jobs: - - job: Final_Jar_Testing_${{parameters.OS}} - workspace: - clean: all - ${{ if eq(parameters.OS, 'MacOS') }}: - pool: - name: 'Azure Pipelines' - image: macOS-14 - os: macOS - ${{ if eq(parameters.OS, 'Linux') }}: - pool: - name: ${{ parameters.PoolName }} - os: linux - ${{ if eq(parameters.OS, 'Windows') }}: - pool: - name: ${{ parameters.PoolName }} - os: windows - variables: - - name: runCodesignValidationInjection - value: false - timeoutInMinutes: 60 - - steps: - - template: set-version-number-variables-step.yml - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Final Jar' - ArtifactName: onnxruntime-java - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - template: flex-downloadPipelineArtifact.yml - parameters: - StepName: 'Download Jar Tools' - ArtifactName: onnxruntime-java-tools - TargetPath: '$(Build.BinariesDirectory)/final-jar' - - - ${{ if eq(parameters.OS, 'Windows') }}: - - task: CmdLine@2 - inputs: - script: | - mkdir test - pushd test - jar xf $(Build.BinariesDirectory)\final-jar\testing.jar - popd - java -jar junit-platform-console-standalone-1.6.2.jar -cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)\final-jar' - - ${{ else }}: - - task: Bash@3 - inputs: - targetType: 'inline' - script: | - set -e -x - echo "Java Version" - java -version - mkdir test - pushd test - jar xf '$(Build.BinariesDirectory)/final-jar/testing.jar' - popd - # if you want to run the tests in the power shell, you need to replace ':' to ';', that is, "-cp .;.\test;protobuf-java-3.25.5.jar;onnxruntime-$(OnnxRuntimeVersion).jar" - java -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.25.5.jar:./onnxruntime-$(OnnxRuntimeVersion).jar --scan-class-path --fail-if-no-tests --disable-banner - workingDirectory: '$(Build.BinariesDirectory)/final-jar' - env: - ${{ if eq(parameters.OS, 'MacOS') }}: - DYLD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(DYLD_LIBRARY_PATH)' - ${{ if eq(parameters.OS, 'Linux') }}: - LD_LIBRARY_PATH: '$(Build.BinariesDirectory)/final-jar/test:$(LD_LIBRARY_PATH)' - - - ${{ if eq(parameters['OS'], 'MacOS') }}: - - template: use-xcode-version.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml deleted file mode 100644 index 8f6434f7ac40d..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_training_test_data.yml +++ /dev/null @@ -1,8 +0,0 @@ -steps: - - script: | - azcopy cp --recursive https://lotusscus.blob.core.windows.net/orttrainingtestdatascus/mnist/ $(Agent.TempDirectory) - displayName: 'Download Training Test Data MNIST' - - - script: | - ls -al $(Agent.TempDirectory)/mnist - displayName: 'Print contents of Training Test Data MNIST' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index 674f16d8e9332..681138a5ab3d1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -20,9 +20,16 @@ parameters: steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: - - powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + - task: AzureCLI@2 displayName: 'Download CUDA SDK v${{ parameters.CudaVersion }}' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }} $(Agent.TempDirectory) + - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\bin;$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}\extras\CUPTI\lib64" displayName: 'Append CUDA SDK Directory to PATH' @@ -31,6 +38,7 @@ steps: inputs: script: | echo %PATH% + dir $(Agent.TempDirectory) displayName: 'Print PATH after download CUDA SDK' - ${{ if eq(parameters.DownloadTRT, true) }}: @@ -51,9 +59,16 @@ steps: echo $(trtCudaVersion) && echo TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) displayName: Get trtCudaVersion and Directory Name - - powershell: | - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) + + - task: AzureCLI@2 displayName: 'Download TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)' + inputs: + azureSubscription: AIInfraBuildOnnxRuntimeOSS + scriptType: 'batch' + scriptLocation: 'inlineScript' + inlineScript: | + set AZCOPY_AUTO_LOGIN_TYPE=AZCLI + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) - powershell: | Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)\lib" diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index ef0f4c6e0883c..586f7a2496eb4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,20 +83,23 @@ jobs: versionSpec: '3.12' addToPath: true architecture: $(buildArch) + - task: NodeTool@0 + inputs: + versionSpec: '22.x' - ${{if eq(parameters.WithCache, true)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 ccache-git-emscripten-64bit - ./emsdk activate 4.0.8 ccache-git-emscripten-64bit + ./emsdk install 4.0.11 ccache-git-emscripten-64bit + ./emsdk activate 4.0.11 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if eq(parameters.WithCache, false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.8 - ./emsdk activate 4.0.8 + ./emsdk install 4.0.11 + ./emsdk activate 4.0.11 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 761c551e9f4d9..3c2ef4741f049 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -4,6 +4,10 @@ parameters: type: string default: 'onnxruntime-qnn-windows-vs-2022-arm64' +- name: PYTHON_VERSION + type: string + default: '3.11' + - name: QNN_SDK displayName: QNN SDK Version type: string @@ -19,13 +23,8 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: -- job: Win_py_arm64_qnn_Wheels +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 210 workspace: clean: all @@ -48,41 +47,21 @@ jobs: outputs: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime_qnn_arm64_$(PythonVersion) - - strategy: - matrix: - Python311_arm64: - PythonVersion: '3.11.0' - LocalPythonDir: 'C:\Python\Python311' - Python312_arm64: - PythonVersion: '3.12.6' - LocalPythonDir: 'C:\Python\Python312' - Python313_arm64: - PythonVersion: '3.13.2' - LocalPythonDir: 'C:\Python\Python313' + artifactName: onnxruntime_qnn_arm64_${{ parameters.PYTHON_VERSION }} + variables: GRADLE_OPTS: '-Dorg.gradle.daemon=false' VSGenerator: 'Visual Studio 17 2022' steps: - checkout: self clean: true - submodules: recursive + submodules: none - template: telemetry-steps.yml - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - XCOPY /s /y /h /e /c /q "$(LocalPythonDir)\*.*" $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion) - DIR $(Agent.ToolsDirectory)\Python\$(PythonVersion)\arm64 - displayName: Copy python $(PythonVersion) version to agent tools directory - - task: UsePythonVersion@0 inputs: - versionSpec: $(PythonVersion) + versionSpec: ${{ parameters.PYTHON_VERSION }} addToPath: true architecture: 'arm64' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml index 74cae38393ea6..c8d37457a1034 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64ec-qnn.yml @@ -19,11 +19,6 @@ parameters: type: string default: '' -- name: is1ES - displayName: 'Whether the pipeline is running in 1ES' - type: boolean - default: false - jobs: - job: Win_py_x64_qnn_Wheels timeoutInMinutes: 210 diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 7ebf5394e4530..66d1cd1687d99 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -61,16 +61,6 @@ jobs: # because the python bindings also use the USE__PROVIDER_INTERFACE preprocessor macros. ExtraQnnBuildArgs: '--enable_generic_interface --build_wheel' steps: - - - script: | - MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ - COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete - DIR $(Agent.ToolsDirectory)\Python - DIR $(Agent.ToolsDirectory)\Python\3.11.0 - DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 - displayName: Copy python 3.11.0 version to agent tools directory - - task: UsePythonVersion@0 inputs: versionSpec: '3.x' diff --git a/tools/ci_build/github/linux/java_linux_final_test.sh b/tools/ci_build/github/linux/java_linux_final_test.sh index 2699d488acbb8..71eb24dc7a1e2 100755 --- a/tools/ci_build/github/linux/java_linux_final_test.sh +++ b/tools/ci_build/github/linux/java_linux_final_test.sh @@ -20,23 +20,22 @@ EXIT_CODE=1 uname -a -cd "$BINARY_DIR/final-jar" - -mkdir test +cd "$BINARY_DIR/onnxruntime-java" +rm -f *.asc +rm -f *.sha256 +rm -f *.sha512 +rm -f *.pom +ls +cd .. +mkdir tests +cd tests +jar xf ../onnxruntime-java/testing.jar +rm -f ../onnxruntime-java/testing.jar +echo "Java Version" +java -version echo "Directories created" echo "Library path:" "$LD_LIBRARY_PATH" -pushd test -jar xf "$BINARY_DIR/final-jar/testing.jar" -popd - -curl -O -sSL https://oss.sonatype.org/service/local/repositories/releases/content/org/junit/platform/junit-platform-console-standalone/1.6.2/junit-platform-console-standalone-1.6.2.jar -curl -O -sSL https://oss.sonatype.org/service/local/repositories/releases/content/com/google/protobuf/protobuf-java/3.25.5/protobuf-java-3.25.5.jar -java -DUSE_CUDA=1 -jar ./junit-platform-console-standalone-1.6.2.jar -cp .:./test:./protobuf-java-3.25.5.jar:./onnxruntime_gpu-"${VERSION_NUMBER}".jar --scan-class-path --fail-if-no-tests --disable-banner - - -EXIT_CODE=$? - -set -e -exit $EXIT_CODE +java -DUSE_CUDA=1 -cp "$BINARY_DIR/tests:$BINARY_DIR/onnxruntime-java/*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$BINARY_DIR/tests \ + --fail-if-no-tests --disable-banner diff --git a/tools/ci_build/java/pom.xml b/tools/ci_build/java/pom.xml new file mode 100644 index 0000000000000..d0856db7cee70 --- /dev/null +++ b/tools/ci_build/java/pom.xml @@ -0,0 +1,32 @@ + + 4.0.0 + + com.example + dependency-downloader + 1.0.0 + pom + + + UTF-8 + 3.25.5 + 1.10.2 + + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + + + + org.junit.platform + junit-platform-console-standalone + ${junit.platform.version} + test + + + + \ No newline at end of file diff --git a/tools/python/run_packaging_pipelines.py b/tools/python/run_packaging_pipelines.py index 10764d0b33871..2ae66332781b9 100644 --- a/tools/python/run_packaging_pipelines.py +++ b/tools/python/run_packaging_pipelines.py @@ -53,6 +53,7 @@ class EvaluationResult(TypedDict): pipeline: PipelineDetails packaging_type: Literal["nightly", "release", "none"] + has_pre_release_params: bool # --- Configuration --- @@ -121,6 +122,7 @@ def evaluate_single_pipeline( headers = {"Authorization": f"Bearer {token}"} thirty_days_ago = datetime.utcnow() - timedelta(days=DAYS_SINCE_LAST_RUN) target_repo_path = urlparse(TARGET_REPO_URL).path.strip("/") + has_pre_release_params = False try: detail_url = pipeline_summary.get("url") @@ -230,8 +232,24 @@ def evaluate_single_pipeline( ) return None + # Check for pre-release parameters if it's a release packaging pipeline + if packaging_type == "release": + yaml_params = data.get("parameters", []) + if isinstance(yaml_params, list): + param_names = {p.get("name") for p in yaml_params if isinstance(p, dict)} + if ( + "PreReleaseVersionSuffixString" in param_names + and "PreReleaseVersionSuffixNumber" in param_names + ): + has_pre_release_params = True + print(" - OK: Found pre-release versioning parameters.") + print(f" - MATCH: '{pipeline_name}' matches all criteria.") - return {"pipeline": pipeline_details, "packaging_type": packaging_type} + return { + "pipeline": pipeline_details, + "packaging_type": packaging_type, + "has_pre_release_params": has_pre_release_params, + } except KeyError as e: print(f" - SKIPPING '{pipeline_name}': Missing expected key {e} in pipeline details.") @@ -270,7 +288,6 @@ def cancel_running_builds(pipeline_id: int, branch: str, token: str, project: st ) headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"} - # The 'builds' API uses 'branchName' to filter. Based on the swagger and behavior, it expects the full ref name. builds_url = f"https://dev.azure.com/{ADO_ORGANIZATION}/{project}/_apis/build/builds?definitions={pipeline_id}&branchName={branch}&statusFilter=inProgress,notStarted&api-version=7.1" try: @@ -306,6 +323,9 @@ def trigger_pipeline( nightly_override: str | None, release_override: str | None, packaging_type: str, + has_pre_release_params: bool, + pre_release_suffix_string: str | None, + pre_release_suffix_number: int | None, ) -> int | None: """Triggers a pipeline and returns the new build ID.""" print(f"\n--- Triggering Pipeline ID: {pipeline_id} on branch '{branch}' in project '{project}' ---") @@ -313,13 +333,26 @@ def trigger_pipeline( run_url = f"https://dev.azure.com/{ADO_ORGANIZATION}/{project}/_apis/pipelines/{pipeline_id}/runs?api-version=7.1-preview.1" payload: dict[str, any] = {"resources": {"repositories": {"self": {"refName": branch}}}} + template_params: dict[str, any] = {} if nightly_override is not None and packaging_type == "nightly": print(f"Overriding NIGHTLY_BUILD variable to '{nightly_override}'.") payload["variables"] = {"NIGHTLY_BUILD": {"value": nightly_override}} elif release_override is not None and packaging_type == "release": print(f"Overriding IsReleaseBuild parameter to '{release_override}'.") - payload["templateParameters"] = {"IsReleaseBuild": release_override} + template_params["IsReleaseBuild"] = release_override + + # Add pre-release parameters if the pipeline supports them and user provided them + if has_pre_release_params: + if pre_release_suffix_string is not None: + print(f"Setting PreReleaseVersionSuffixString parameter to '{pre_release_suffix_string}'.") + template_params["PreReleaseVersionSuffixString"] = pre_release_suffix_string + if pre_release_suffix_number is not None: + print(f"Setting PreReleaseVersionSuffixNumber parameter to {pre_release_suffix_number}.") + template_params["PreReleaseVersionSuffixNumber"] = pre_release_suffix_number + + if template_params: + payload["templateParameters"] = template_params try: response = requests.post(run_url, headers=headers, data=json.dumps(payload)) @@ -327,12 +360,12 @@ def trigger_pipeline( build_info = response.json() build_id = build_info.get("id") print(f"Successfully triggered build. Build ID: {build_id}") - print(f" Web URL: {build_info.get('_links', {}).get('web', {}).get('href')}") + print(f" Web URL: {build_info.get('_links', {}).get('web', {}).get('href')}") return build_id except requests.exceptions.RequestException as e: print(f"ERROR triggering pipeline: {e}") if e.response is not None: - print(f" Response: {e.response.text}") + print(f" Response: {e.response.text}") return None @@ -351,6 +384,24 @@ def main(): choices=["nightly", "release"], help="Specify the build mode for packaging pipelines (nightly or release). This sets NIGHTLY_BUILD and IsReleaseBuild accordingly.", ) + parser.add_argument( + "--no-cancel", + action="store_true", + dest="no_cancel_builds", + help="Do not cancel existing running builds for the pipeline before triggering a new one.", + ) + # New arguments for pre-release versioning + parser.add_argument( + "--pre-release-suffix-string", + choices=["alpha", "beta", "rc", "none"], + help="Suffix for pre-release versions (e.g., 'rc'). Requires the pipeline to have the parameter.", + ) + parser.add_argument( + "--pre-release-suffix-number", + type=int, + help="Number for pre-release versions (e.g., '1'). Requires the pipeline to have the parameter.", + ) + args = parser.parse_args() project = "Lotus" @@ -358,6 +409,11 @@ def main(): branch_for_yaml_fetch = args.branch is_pr_build = False + if (args.pre_release_suffix_string and not args.pre_release_suffix_number) or ( + not args.pre_release_suffix_string and args.pre_release_suffix_number + ): + parser.error("--pre-release-suffix-string and --pre-release-suffix-number must be used together.") + if args.pr: print(f"--- Pull Request Mode Activated for PR #{args.pr} ---") project = "PublicPackages" @@ -400,9 +456,24 @@ def main(): for result in pipelines_to_trigger: pipeline = result["pipeline"] packaging_type = result["packaging_type"] - cancel_running_builds(pipeline["id"], branch_for_trigger, token, project) + has_pre_release_params = result["has_pre_release_params"] + + if not args.no_cancel_builds: + cancel_running_builds(pipeline["id"], branch_for_trigger, token, project) + else: + print(f"\nSkipping cancellation for Pipeline ID: {pipeline['id']} as per --no-cancel flag.") + trigger_pipeline( - pipeline["id"], token, branch_for_trigger, project, nightly_override, release_override, packaging_type + pipeline_id=pipeline["id"], + token=token, + branch=branch_for_trigger, + project=project, + nightly_override=nightly_override, + release_override=release_override, + packaging_type=packaging_type, + has_pre_release_params=has_pre_release_params, + pre_release_suffix_string=args.pre_release_suffix_string, + pre_release_suffix_number=args.pre_release_suffix_number, ) diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index 1fbb4a06b3c2b..a9f1420946354 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -93,8 +93,29 @@ def add_copyright_header(f) -> None: ) +def add_build_type(f, build_type: str) -> None: + """ + Add build type to the triplet file. + + Args: + f (file object): The file object to write the build type. + build_type (str): The build type to add. Must be one of "Debug", "Release", "RelWithDebInfo", or "MinSizeRel". + """ + if build_type not in ["Debug", "Release", "RelWithDebInfo", "MinSizeRel"]: + raise ValueError( + f"Invalid build type: {build_type}. Must be one of 'Debug', 'Release', 'RelWithDebInfo', or 'MinSizeRel'." + ) + + if build_type != "Debug": + f.write( + """set(VCPKG_BUILD_TYPE release) +""" + ) + + def generate_triplet_for_android( build_dir: str, + configs: set[str], target_abi: str, enable_rtti: bool, enable_exception: bool, @@ -130,100 +151,102 @@ def generate_triplet_for_android( file_name = f"{target_abi}-android.cmake" - dest_path = Path(build_dir) / folder_name / file_name + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + + os.makedirs(dest_path.parent, exist_ok=True) - os.makedirs(dest_path.parent, exist_ok=True) + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) + # Set target architecture for Android + if target_abi == "arm-neon": + f.write("set(VCPKG_TARGET_ARCHITECTURE arm)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=armv7a-linux-androideabi)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=armeabi-v7a -DANDROID_ARM_NEON=ON -DCMAKE_ANDROID_ARM_NEON=ON)\n" + ) + elif target_abi == "arm64": + f.write("set(VCPKG_TARGET_ARCHITECTURE arm64)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=aarch64-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=arm64-v8a)\n") + elif target_abi == "x64": + f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=x86_64-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86_64)\n") + elif target_abi == "x86": + f.write("set(VCPKG_TARGET_ARCHITECTURE x86)\n") + f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=i686-linux-android)\n") + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86)\n") - # Set target architecture for Android - if target_abi == "arm-neon": - f.write("set(VCPKG_TARGET_ARCHITECTURE arm)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=armv7a-linux-androideabi)\n") f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=armeabi-v7a -DANDROID_ARM_NEON=ON -DCMAKE_ANDROID_ARM_NEON=ON)\n" + f"list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_USE_LEGACY_TOOLCHAIN_FILE=false -DANDROID_PLATFORM=android-{android_api_level} -DANDROID_MIN_SDK={android_api_level})\n" ) - elif target_abi == "arm64": - f.write("set(VCPKG_TARGET_ARCHITECTURE arm64)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=aarch64-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=arm64-v8a)\n") - elif target_abi == "x64": - f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=x86_64-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86_64)\n") - elif target_abi == "x86": - f.write("set(VCPKG_TARGET_ARCHITECTURE x86)\n") - f.write("set(VCPKG_MAKE_BUILD_TRIPLET --host=i686-linux-android)\n") - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_ABI=x86)\n") - f.write( - f"list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DANDROID_USE_LEGACY_TOOLCHAIN_FILE=false -DANDROID_PLATFORM=android-{android_api_level} -DANDROID_MIN_SDK={android_api_level})\n" - ) - - # Set CRT linkage - # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). - # Valid options are dynamic and static. - crt_linkage = "static" - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + # Set CRT linkage + # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). + # Valid options are dynamic and static. + crt_linkage = "static" + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - # Set library linkage - # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. - # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. In our case, we prefer to use static libs. - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - if not enable_rtti: - f.write("set(CMAKE_ANDROID_RTTI OFF)\n") - if not enable_exception: - f.write("set(CMAKE_ANDROID_EXCEPTIONS OFF)\n") - if use_cpp_shared: - f.write("set(ANDROID_STL c++_shared)\n") + # Set library linkage + # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. + # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. In our case, we prefer to use static libs. + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + if not enable_rtti: + f.write("set(CMAKE_ANDROID_RTTI OFF)\n") + if not enable_exception: + f.write("set(CMAKE_ANDROID_EXCEPTIONS OFF)\n") + if use_cpp_shared: + f.write("set(ANDROID_STL c++_shared)\n") - ldflags = [] + ldflags = [] - cflags = ["-g", "-ffunction-sections", "-fdata-sections"] - cflags_release = ["-DNDEBUG", "-O3"] + cflags = ["-g", "-ffunction-sections", "-fdata-sections"] + cflags_release = ["-DNDEBUG", "-O3"] - if enable_asan: - cflags += ["-fsanitize=address"] - ldflags += ["-fsanitize=address"] + if enable_asan: + cflags += ["-fsanitize=address"] + ldflags += ["-fsanitize=address"] - ldflags.append("-g") + ldflags.append("-g") - cxxflags = cflags.copy() + cxxflags = cflags.copy() - if not enable_rtti: - cxxflags.append("-fno-rtti") + if not enable_rtti: + cxxflags.append("-fno-rtti") - if not enable_exception: - cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] + if not enable_exception: + cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - if cflags_release: - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + if cflags_release: + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - # Set target platform - # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Android)\n") - f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" - ) + # Set target platform + # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Android)\n") + f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" + ) - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") - add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") + add_build_type(f, config) + add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build -def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_level: int) -> None: +def generate_android_triplets(build_dir: str, configs: set[str], use_cpp_shared: bool, android_api_level: int) -> None: """ Generate triplet files for POSIX platforms (Linux, macOS, Android). @@ -240,6 +263,7 @@ def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_ for target_abi in target_abis: generate_triplet_for_android( build_dir, + configs, target_abi, enable_rtti, enable_exception, @@ -252,6 +276,7 @@ def generate_android_triplets(build_dir: str, use_cpp_shared: bool, android_api_ def generate_triplet_for_posix_platform( build_dir: str, + configs: set[str], os_name: str, enable_rtti: bool, enable_exception: bool, @@ -293,115 +318,118 @@ def generate_triplet_for_posix_platform( file_name = f"{target_abi}-{os_name}.cmake" - dest_path = Path(build_dir) / folder_name / file_name - - os.makedirs(dest_path.parent, exist_ok=True) - - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - - # Set target architecture based on `os_name` and `target_abi`. - # - # In most cases VCPKG itself can help automatically detect the target architecture, but sometimes it is not as what we want. The following code process the special cases. - if target_abi == "universal2": - # Assume the host machine is Intel based - f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") - else: - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - - # Set CRT linkage - # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). - # Valid options are dynamic and static. - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - - # Set library linkage - # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. - # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - - ldflags = [] - - if enable_binskim and os_name == "linux": - # BinSkim rule 3005: Enable stack clash protection - # This check ensures that stack clash protection is enabled. Each program running on a computer uses a special memory region called the stack. - # This memory region is special because it grows automatically when the program needs more stack memory. But if it grows too much and gets too close to another memory region, - # the program may confuse the stack with the other memory region. An attacker can exploit this confusion to overwrite the stack with the other memory region, or the other way around. - # Use the compiler flags '-fstack-clash-protection' to enable this. - # BinSkim rule BA3011: Enable BIND_NOW - # This check ensures that some relocation data is marked as read-only after the executable is loaded, and moved below the '.data' section in memory. - # This prevents them from being overwritten, which can redirect control flow. Use the compiler flags '-Wl,-z,now' to enable this. - ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now", "-Wl,-z,noexecstack"] - - cflags = ["-g", "-ffunction-sections", "-fdata-sections"] - cflags_release = ["-DNDEBUG", "-O3"] - - if enable_binskim: - cflags_release += ["-Wp,-D_FORTIFY_SOURCE=2", "-Wp,-D_GLIBCXX_ASSERTIONS", "-fstack-protector-strong"] - if target_abi == "x64": - cflags_release += ["-fstack-clash-protection", "-fcf-protection"] - - elif enable_asan: - cflags += ["-fsanitize=address"] - ldflags += ["-fsanitize=address"] - - ldflags.append("-g") - - if not enable_rtti: - cflags.append("-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") - - cxxflags = cflags.copy() - if os_name == "osx": - cxxflags += ["-fvisibility=hidden", "-fvisibility-inlines-hidden"] - if not enable_rtti: - cxxflags.append("-fno-rtti") - - if not enable_exception: - cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] - - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - - if cflags_release: - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') - - # Set target platform - # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. - if os_name == "linux": - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Linux)\n") - else: - f.write("set(VCPKG_CMAKE_SYSTEM_NAME Darwin)\n") - osx_abi = None - if target_abi == "x64": - osx_abi = "x86_64" - elif target_abi == "universal2": - osx_abi = "x86_64;arm64" + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + + os.makedirs(dest_path.parent, exist_ok=True) + + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + + # Set target architecture based on `os_name` and `target_abi`. + # + # In most cases VCPKG itself can help automatically detect the target architecture, but sometimes it is not as what we want. The following code process the special cases. + if target_abi == "universal2": + # Assume the host machine is Intel based + f.write("set(VCPKG_TARGET_ARCHITECTURE x64)\n") else: - osx_abi = target_abi - f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') - if osx_deployment_target: - f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') - f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" - ) + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - if os_name == "osx": - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=20)\n") - else: - f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") - add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build + # Set CRT linkage + # VCPKG_CRT_LINKAGE specifies the desired CRT linkage (for MSVC). + # Valid options are dynamic and static. + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + + # Set library linkage + # VCPKG_LIBRARY_LINKAGE specifies the preferred library linkage. + # Valid options are dynamic and static. Libraries can ignore this setting if they do not support the preferred linkage type. + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + + ldflags = [] + + if enable_binskim and os_name == "linux": + # BinSkim rule 3005: Enable stack clash protection + # This check ensures that stack clash protection is enabled. Each program running on a computer uses a special memory region called the stack. + # This memory region is special because it grows automatically when the program needs more stack memory. But if it grows too much and gets too close to another memory region, + # the program may confuse the stack with the other memory region. An attacker can exploit this confusion to overwrite the stack with the other memory region, or the other way around. + # Use the compiler flags '-fstack-clash-protection' to enable this. + # BinSkim rule BA3011: Enable BIND_NOW + # This check ensures that some relocation data is marked as read-only after the executable is loaded, and moved below the '.data' section in memory. + # This prevents them from being overwritten, which can redirect control flow. Use the compiler flags '-Wl,-z,now' to enable this. + ldflags = ["-Wl,-Bsymbolic-functions", "-Wl,-z,relro", "-Wl,-z,now", "-Wl,-z,noexecstack"] + + cflags = ["-g", "-ffunction-sections", "-fdata-sections"] + cflags_release = ["-DNDEBUG", "-O3"] + + if enable_binskim: + cflags_release += ["-Wp,-D_FORTIFY_SOURCE=2", "-Wp,-D_GLIBCXX_ASSERTIONS", "-fstack-protector-strong"] + if target_abi == "x64": + cflags_release += ["-fstack-clash-protection", "-fcf-protection"] + + elif enable_asan: + cflags += ["-fsanitize=address"] + ldflags += ["-fsanitize=address"] + + ldflags.append("-g") + + if not enable_rtti: + cflags.append("-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0") + + cxxflags = cflags.copy() + if os_name == "osx": + cxxflags += ["-fvisibility=hidden", "-fvisibility-inlines-hidden"] + if not enable_rtti: + cxxflags.append("-fno-rtti") + + if not enable_exception: + cxxflags += ["-fno-exceptions", "-fno-unwind-tables", "-fno-asynchronous-unwind-tables"] + + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + + if cflags_release: + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_C_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELWITHDEBINFO "{" ".join(cflags_release)}")\n') + + # Set target platform + # VCPKG_CMAKE_SYSTEM_NAME specifies the target platform. + if os_name == "linux": + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Linux)\n") + else: + f.write("set(VCPKG_CMAKE_SYSTEM_NAME Darwin)\n") + osx_abi = None + if target_abi == "x64": + osx_abi = "x86_64" + elif target_abi == "universal2": + osx_abi = "x86_64;arm64" + else: + osx_abi = target_abi + f.write(f'set(VCPKG_OSX_ARCHITECTURES "{osx_abi}")\n') + if osx_deployment_target: + f.write(f'set(VCPKG_OSX_DEPLOYMENT_TARGET "{osx_deployment_target}")\n') + f.write("set(CMAKE_POSITION_INDEPENDENT_CODE ON)\n") + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DBENCHMARK_ENABLE_WERROR=OFF)\n" + ) + + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + if os_name == "osx": + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=20)\n") + else: + f.write("list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS -DCMAKE_CXX_STANDARD=17)\n") + add_build_type(f, config) + add_port_configs(f, enable_exception, False, enable_minimal_build) # Pass enable_minimal_build def generate_vcpkg_triplets_for_emscripten( build_dir: str, + configs: set[str], emscripten_root: str, # Parameters defining the specific build configuration enable_rtti: bool, @@ -449,105 +477,108 @@ def generate_vcpkg_triplets_for_emscripten( for target_abi in ["wasm32", "wasm64"]: os_name = "emscripten" file_name = f"{target_abi}-{os_name}.cmake" - dest_path = Path(build_dir) / folder_name / file_name - os.makedirs(dest_path.parent, exist_ok=True) + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + os.makedirs(dest_path.parent, exist_ok=True) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - f.write(r""" + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + f.write(r""" set(VCPKG_CRT_LINKAGE dynamic) set(VCPKG_LIBRARY_LINKAGE static) set(VCPKG_CMAKE_SYSTEM_NAME Emscripten) """) - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - emscripten_root_path_cmake_path = emscripten_root.replace("\\", "/") - f.write(f'set(EMSCRIPTEN_ROOT_PATH "{emscripten_root_path_cmake_path}")\n') - - # Define the path to the intermediate toolchain file used by vcpkg for wasm - vcpkg_toolchain_file = (Path(build_dir) / "emsdk_vcpkg_toolchain.cmake").absolute() - vcpkg_toolchain_file_cmake_path = str(vcpkg_toolchain_file).replace("\\", "/") - f.write(f'set(VCPKG_CHAINLOAD_TOOLCHAIN_FILE "{vcpkg_toolchain_file_cmake_path}")\n') - - # --- Configure Flags based on Parameters --- - cflags_release = ["-DNDEBUG", "-O3", "-flto"] - ldflags = [] # Initialize linker flags list - # Base flags applicable to both C and C++ - base_flags = [ - "-ffunction-sections", - "-fdata-sections", - "-msimd128", - "-pthread", - "-Wno-pthreads-mem-growth", - ] - - # ASan (apply to Base, Linker) - if enable_asan: - asan_flag = "-fsanitize=address" - base_flags.append(asan_flag) - ldflags.append(asan_flag) # Add to linker flags - - # Wasm Exception Catching Runtime (-s flag, apply to Base and Linker flags) - exception_catching_flag = "" - if enable_wasm_exception_catching: - exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=0" - else: - exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=1" - - base_flags.append(exception_catching_flag) # Add to base C/C++ flags - ldflags.append(exception_catching_flag) # Add to linker flags - - # Wasm64 Memory (apply to Base, Linker) - if target_abi == "wasm64": - memory_flag = "-sMEMORY64" - base_flags.append(memory_flag) - ldflags.append(memory_flag) # Add to linker flags - - # --- C Flags --- - # VCPKG_C_FLAGS applies only base flags - f.write(f'set(VCPKG_C_FLAGS "{" ".join(base_flags)}")\n') - - # --- CXX Flags --- - # Start with base flags - cxxflags = list(base_flags) # Create a copy - - # C++ RTTI Compiler Flag - if not enable_rtti: - cxxflags.append("-fno-rtti") - - # C++ Exceptions Compiler Flag (Derived from enable_minimal_onnx_build) - if not cpp_exceptions_enabled: # i.e., if enable_minimal_onnx_build is True - cxxflags.append("-fno-exceptions") - # If cpp_exceptions_enabled=True, we assume -fexceptions is the default - # or handled by the Emscripten toolchain/CMake settings elsewhere. - - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - - # --- Linker Flags --- - # Apply Linker flags (now includes exception and memory flags explicitly) - if len(ldflags) >= 1: - f.write('set(VCPKG_LINKER_FLAGS "{}")\n'.format(" ".join(ldflags))) - - # --- Release / RelWithDebInfo Flags --- - # Combine base flags with release-specific flags - c_combined_release_flags = cflags_release + base_flags - cxx_combined_release_flags = cflags_release + cxxflags # Use the derived cxxflags - - f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(c_combined_release_flags)}")\n') - f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cxx_combined_release_flags)}")\n') - - f.write("set(VCPKG_LINKER_FLAGS_RELEASE -flto)\n") - - # --- Add Port Specific Configs --- - # Pass the derived C++ exception status and the original minimal build flag - add_port_configs( - f, - has_exception=cpp_exceptions_enabled, # Derived value - is_emscripten=True, - enable_minimal_build=enable_minimal_onnx_build, - ) # Original parameter - - -def generate_windows_triplets(build_dir: str, toolset_version: str) -> None: + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") + emscripten_root_path_cmake_path = emscripten_root.replace("\\", "/") + f.write(f'set(EMSCRIPTEN_ROOT_PATH "{emscripten_root_path_cmake_path}")\n') + + # Define the path to the intermediate toolchain file used by vcpkg for wasm + vcpkg_toolchain_file = (Path(build_dir) / "emsdk_vcpkg_toolchain.cmake").absolute() + vcpkg_toolchain_file_cmake_path = str(vcpkg_toolchain_file).replace("\\", "/") + f.write(f'set(VCPKG_CHAINLOAD_TOOLCHAIN_FILE "{vcpkg_toolchain_file_cmake_path}")\n') + + # --- Configure Flags based on Parameters --- + cflags_release = ["-DNDEBUG", "-O3", "-flto"] + ldflags = [] # Initialize linker flags list + # Base flags applicable to both C and C++ + base_flags = [ + "-ffunction-sections", + "-fdata-sections", + "-msimd128", + "-pthread", + "-Wno-pthreads-mem-growth", + ] + + # ASan (apply to Base, Linker) + if enable_asan: + asan_flag = "-fsanitize=address" + base_flags.append(asan_flag) + ldflags.append(asan_flag) # Add to linker flags + + # Wasm Exception Catching Runtime (-s flag, apply to Base and Linker flags) + exception_catching_flag = "" + if enable_wasm_exception_catching: + exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=0" + else: + exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=1" + + base_flags.append(exception_catching_flag) # Add to base C/C++ flags + ldflags.append(exception_catching_flag) # Add to linker flags + + # Wasm64 Memory (apply to Base, Linker) + if target_abi == "wasm64": + memory_flag = "-sMEMORY64" + base_flags.append(memory_flag) + ldflags.append(memory_flag) # Add to linker flags + + # --- C Flags --- + # VCPKG_C_FLAGS applies only base flags + f.write(f'set(VCPKG_C_FLAGS "{" ".join(base_flags)}")\n') + + # --- CXX Flags --- + # Start with base flags + cxxflags = list(base_flags) # Create a copy + + # C++ RTTI Compiler Flag + if not enable_rtti: + cxxflags.append("-fno-rtti") + + # C++ Exceptions Compiler Flag (Derived from enable_minimal_onnx_build) + if not cpp_exceptions_enabled: # i.e., if enable_minimal_onnx_build is True + cxxflags.append("-fno-exceptions") + # If cpp_exceptions_enabled=True, we assume -fexceptions is the default + # or handled by the Emscripten toolchain/CMake settings elsewhere. + + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + + # --- Linker Flags --- + # Apply Linker flags (now includes exception and memory flags explicitly) + if len(ldflags) >= 1: + f.write('set(VCPKG_LINKER_FLAGS "{}")\n'.format(" ".join(ldflags))) + + # --- Release / RelWithDebInfo Flags --- + # Combine base flags with release-specific flags + c_combined_release_flags = cflags_release + base_flags + cxx_combined_release_flags = cflags_release + cxxflags # Use the derived cxxflags + + f.write(f'set(VCPKG_C_FLAGS_RELEASE "{" ".join(c_combined_release_flags)}")\n') + f.write(f'set(VCPKG_CXX_FLAGS_RELEASE "{" ".join(cxx_combined_release_flags)}")\n') + + f.write("set(VCPKG_LINKER_FLAGS_RELEASE -flto)\n") + + add_build_type(f, config) + + # --- Add Port Specific Configs --- + # Pass the derived C++ exception status and the original minimal build flag + add_port_configs( + f, + has_exception=cpp_exceptions_enabled, # Derived value + is_emscripten=True, + enable_minimal_build=enable_minimal_onnx_build, + ) # Original parameter + + +def generate_windows_triplets(build_dir: str, configs: set[str], toolset_version: str) -> None: """ Generate triplet files for Windows platforms. @@ -593,54 +624,56 @@ def generate_windows_triplets(build_dir: str, toolset_version: str) -> None: if crt_linkage == "dynamic": file_name_parts.append("md") file_name = "-".join(file_name_parts) + ".cmake" - dest_path = Path(build_dir) / folder_name / file_name - os.makedirs(dest_path.parent, exist_ok=True) - with open(dest_path, "w", encoding="utf-8") as f: - add_copyright_header(f) - f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") - f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") - f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") - if toolset_version: - f.write(f"set(VCPKG_PLATFORM_TOOLSET_VERSION {toolset_version})\n") - cflags = ["/MP", "/DWIN32", "/D_WINDOWS"] - if enable_binskim: - cflags += [ - "/DWINAPI_FAMILY=100", - "/DWINVER=0x0A00", - "/D_WIN32_WINNT=0x0A00", - "/DNTDDI_VERSION=0x0A000000", - ] - ldflags = [] - if enable_binskim: - cflags += ["/guard:cf", "/Qspectre", "/W3"] - ldflags = ["/profile", "/DYNAMICBASE"] - elif enable_asan: - cflags.append("/fsanitize=address") - cxxflags = cflags.copy() - cxxflags.append("/Zc:__cplusplus") - if enable_exception: - cxxflags.append("/EHsc") - # MSVC doesn't have a specific flag to disable exceptions like /EHs-c- - # but relies on _HAS_EXCEPTIONS=0 and potentially other flags managed by ORT's main CMake. - # Vcpkg doesn't directly control this via a simple triplet flag AFAIK. - # ORT's CMake handles this via CMAKE_CXX_FLAGS adjustment. - if not enable_rtti: - cxxflags += ["/GR-", "/we4541"] - if cflags: - f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') - if cxxflags: - f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') - f.write( - "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DCMAKE_CXX_STANDARD=17)\n" - ) - if ldflags: - f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') - add_port_configs( - f, enable_exception, False, enable_minimal_build - ) # Pass enable_minimal_build - - -def generate_linux_triplets(build_dir: str) -> None: + for config in configs: + dest_path = Path(build_dir) / config / folder_name / file_name + os.makedirs(dest_path.parent, exist_ok=True) + with open(dest_path, "w", encoding="utf-8") as f: + add_copyright_header(f) + f.write(f"set(VCPKG_TARGET_ARCHITECTURE {target_abi})\n") + f.write(f"set(VCPKG_CRT_LINKAGE {crt_linkage})\n") + f.write("set(VCPKG_LIBRARY_LINKAGE static)\n") + if toolset_version: + f.write(f"set(VCPKG_PLATFORM_TOOLSET_VERSION {toolset_version})\n") + cflags = ["/MP", "/DWIN32", "/D_WINDOWS"] + if enable_binskim: + cflags += [ + "/DWINAPI_FAMILY=100", + "/DWINVER=0x0A00", + "/D_WIN32_WINNT=0x0A00", + "/DNTDDI_VERSION=0x0A000000", + ] + ldflags = [] + if enable_binskim: + cflags += ["/guard:cf", "/Qspectre", "/W3"] + ldflags = ["/profile", "/DYNAMICBASE"] + elif enable_asan: + cflags.append("/fsanitize=address") + cxxflags = cflags.copy() + cxxflags.append("/Zc:__cplusplus") + if enable_exception: + cxxflags.append("/EHsc") + # MSVC doesn't have a specific flag to disable exceptions like /EHs-c- + # but relies on _HAS_EXCEPTIONS=0 and potentially other flags managed by ORT's main CMake. + # Vcpkg doesn't directly control this via a simple triplet flag AFAIK. + # ORT's CMake handles this via CMAKE_CXX_FLAGS adjustment. + if not enable_rtti: + cxxflags += ["/GR-", "/we4541"] + if cflags: + f.write(f'set(VCPKG_C_FLAGS "{" ".join(cflags)}")\n') + if cxxflags: + f.write(f'set(VCPKG_CXX_FLAGS "{" ".join(cxxflags)}")\n') + f.write( + "list(APPEND VCPKG_CMAKE_CONFIGURE_OPTIONS --compile-no-warning-as-error -DCMAKE_CXX_STANDARD=17)\n" + ) + if ldflags: + f.write(f'set(VCPKG_LINKER_FLAGS "{" ".join(ldflags)}")\n') + add_build_type(f, config) + add_port_configs( + f, enable_exception, False, enable_minimal_build + ) # Pass enable_minimal_build + + +def generate_linux_triplets(build_dir: str, configs: set[str]) -> None: """ Generate triplet files for Linux platforms. @@ -660,6 +693,7 @@ def generate_linux_triplets(build_dir: str) -> None: for target_abi in target_abis: generate_triplet_for_posix_platform( build_dir, + configs, "linux", enable_rtti, enable_exception, @@ -672,7 +706,7 @@ def generate_linux_triplets(build_dir: str) -> None: ) -def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: +def generate_macos_triplets(build_dir: str, configs: set[str], osx_deployment_target: str) -> None: """ Generate triplet files for macOS platforms. @@ -694,6 +728,7 @@ def generate_macos_triplets(build_dir: str, osx_deployment_target: str) -> None: for target_abi in target_abis: generate_triplet_for_posix_platform( build_dir, + configs, "osx", enable_rtti, enable_exception,