diff --git a/.github/actions/macos-ci-setup/action.yml b/.github/actions/macos-ci-setup/action.yml index 0d60eeae8aee3..054676d301820 100644 --- a/.github/actions/macos-ci-setup/action.yml +++ b/.github/actions/macos-ci-setup/action.yml @@ -8,7 +8,7 @@ inputs: python_version: required: false type: string - default: "3.11" + default: "3.14" node_version: required: false type: string diff --git a/.github/actions/setup-android-ndk/action.yml b/.github/actions/setup-android-ndk/action.yml index fea9745396e81..4eefc3642cd92 100644 --- a/.github/actions/setup-android-ndk/action.yml +++ b/.github/actions/setup-android-ndk/action.yml @@ -89,10 +89,10 @@ runs: set -e -x python3 tools/python/run_android_emulator.py \ --android-sdk-root "${ANDROID_SDK_ROOT}" \ - --start --emulator-extra-args="-partition-size 2047" \ + --start --emulator-extra-args="-partition-size 2047 -memory 5120" \ --emulator-pid-file ./emulator.pid echo "Emulator PID: `cat ./emulator.pid`" - name: View Android ENVs shell: bash - run: env | grep ANDROID \ No newline at end of file + run: env | grep ANDROID diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 9aa8418c55a40..dd8cbfdc71a9c 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -68,6 +68,21 @@ jobs: secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + build-linux-x64-release-py314: + name: Build Linux x64 Release (Python 3.14) + uses: ./.github/workflows/reusable_linux_build.yml + with: + pool_name: "onnxruntime-github-Ubuntu2204-AMD-CPU" + build_config: Release + architecture: x64 + dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu + docker_image_repo: onnxruntimecpubuildpythonx64 + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --build_nuget --enable_transformers_tool_test --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON' + python_path_prefix: 'PATH=/opt/python/cp314-cp314/bin:$PATH' # $ needs escaping in single quotes + 
job_identifier: build-linux-x64-release-py314 + secrets: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + orttraining-linux-ci-pipeline: name: Build Linux x64 Release with training uses: ./.github/workflows/reusable_linux_build.yml @@ -109,7 +124,7 @@ jobs: dockerfile_path: tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile docker_image_repo: onnxruntimecpubuildpythonaarch64 extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON' - python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' # $ needs escaping in single quotes + python_path_prefix: 'PATH=/opt/python/cp314-cp314/bin:$PATH' # $ needs escaping in single quotes job_identifier: build-linux-arm64-release secrets: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index 5423145132639..0f8b4a42f48ae 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -16,7 +16,7 @@ concurrency: cancel-in-progress: true env: - python_version: 3.11 + python_version: "3.14" jobs: cpu: @@ -28,6 +28,7 @@ jobs: {"machine": "arm64", "target": "arm64", "build_config": "Debug"}, {"machine": "arm64", "target": "arm64", "build_config": "Release"} ] + python_version: "3.14" coreml: uses: ./.github/workflows/macos-ci-build-and-test-workflow.yml @@ -39,6 +40,7 @@ jobs: {"machine": "arm64", "target": "arm64", "build_config": "Debug"}, {"machine": "arm64", "target": "arm64", "build_config": "Release"} ] + python_version: "3.14" xnnpack: uses: ./.github/workflows/macos-ci-build-and-test-workflow.yml @@ -49,6 +51,7 @@ jobs: [ {"machine": "arm64", "target": "arm64", "build_config": "Debug"} ] + python_version: "3.14" webgpu: uses: ./.github/workflows/macos-ci-build-and-test-workflow.yml @@ -60,6 +63,7 @@ jobs: {"machine": "arm64", "target": "arm64", "build_config": "Debug"}, {"machine": "arm64", "target": "arm64", "build_config": "Release"} ] + python_version: "3.14" 
iphone_simulator: runs-on: macos-15 @@ -72,7 +76,7 @@ jobs: matrix: target_arch: [x86_64, arm64] - timeout-minutes: 90 + timeout-minutes: 120 steps: - name: Checkout code diff --git a/.github/workflows/macos-ci-build-and-test-workflow.yml b/.github/workflows/macos-ci-build-and-test-workflow.yml index 75002fdf12c00..76198c7f5c1ce 100644 --- a/.github/workflows/macos-ci-build-and-test-workflow.yml +++ b/.github/workflows/macos-ci-build-and-test-workflow.yml @@ -19,7 +19,7 @@ on: python_version: required: false type: string - default: "3.11" + default: "3.14" matrix_include: required: false type: string diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 89ae03981ecef..7b93086fbb77d 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -32,7 +32,7 @@ jobs: - uses: actions/setup-python@v6 with: - python-version: '3.12' + python-version: '3.14' architecture: x64 - name: Locate vcvarsall and Setup Env @@ -173,7 +173,7 @@ jobs: - uses: actions/setup-python@v6 with: - python-version: '3.12' + python-version: '3.14' architecture: x64 - uses: actions/setup-node@v6 diff --git a/.gitmodules b/.gitmodules index 37455d1bb64c2..2864085bf85bf 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,4 +7,4 @@ [submodule "cmake/external/emsdk"] path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git - branch = 4.0.21 + branch = 4.0.23 diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 53cc1a6f9292c..ae96cc7310aaa 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.24.0 +1.24.3 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index cd939acc5aeae..6d0d39556e1c0 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -91,6 +91,7 @@ option(onnxruntime_USE_SVE "Build with SVE support in MLAS" OFF) option(onnxruntime_USE_ARM_NEON_NCHWC "Build with ARM Neon NCHWc kernels in MLAS" OFF) option(onnxruntime_USE_KLEIDIAI "Build with KleidiAI integration in MLAS" OFF) 
+option(onnxruntime_USE_QMX_KLEIDIAI_COEXIST "Build with QMX and Arm KLEIDIAI libraries" OFF) option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON) option(onnxruntime_BUILD_CSHARP "Build C# library" OFF) option(onnxruntime_BUILD_OBJC "Build Objective-C library" OFF) diff --git a/cmake/deps.txt b/cmake/deps.txt index 9b8c716f12236..65c74060e8deb 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -46,7 +46,7 @@ protoc_linux_aarch64;https://github.com/protocolbuffers/protobuf/releases/downlo protoc_mac_universal;https://github.com/protocolbuffers/protobuf/releases/download/v21.12/protoc-21.12-osx-universal_binary.zip;23710c3d1c2036d8d65a6a22234372fa2d7af9ef psimd;https://github.com/Maratyszcza/psimd/archive/072586a71b55b7f8c584153d223e95687148a900.zip;1f5454b01f06f9656b77e4a5e2e31d7422487013 pthreadpool;https://github.com/google/pthreadpool/archive/dcc9f28589066af0dbd4555579281230abbf74dd.zip;533a77943203ef15ca608bcd9dbe2c94da7451d2 -pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v2.13.6.zip;f780292da9db273c8ef06ccf5fd4b623624143e9 +pybind11;https://github.com/pybind/pybind11/archive/refs/tags/v3.0.2.zip;a064e663b4d7a337ac291d1bef7337ef4e60a1ae pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/403d652dca4c1046e8145950b1c0997a9f748b57.zip;30b2a07fe4bae8574f89176e56274cacdd6d135b re2;https://github.com/google/re2/archive/refs/tags/2024-07-02.zip;646e1728269cde7fcef990bf4a8e87b047882e88 safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac @@ -56,5 +56,8 @@ extensions;https://github.com/microsoft/onnxruntime-extensions/archive/c24b7bab0 directx_headers;https://github.com/microsoft/DirectX-Headers/archive/refs/tags/v1.613.1.zip;47653509a3371eabb156360f42faf582f314bf2e cudnn_frontend;https://github.com/NVIDIA/cudnn-frontend/archive/refs/tags/v1.12.0.zip;7e733cfdc410d777b76122d64232499205589a96 
dawn;https://github.com/google/dawn/archive/13c1635a14574ebb7116b56a69f5519301417fda.zip;0aadd28fc385cf7d657d5fc70a352372d2d3c76a -kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.15.0.tar.gz;62ccd24ab60bcef68766440fb42d79071ac2a5d2 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.20.0.tar.gz;6895e72b3d5cf1173358164cb3d64c9d7d33cc84 +# kleidiai-qmx is pinned to a specific commit as there are no tagged releases. When an appropriate tagged release becomes available, +# this entry will be updated to use refs/tags/ instead of the raw commit hash. +kleidiai-qmx;https://github.com/qualcomm/kleidiai/archive/2f10c9a8d32f81ffeeb6d4885a29cc35d2b0da87.zip;5e855730a2d69057a569f43dd7532db3b2d2a05c duktape;https://github.com/svaarala/duktape/releases/download/v2.7.0/duktape-2.7.0.tar.xz;8200c8e417dbab7adcc12c4dbdef7651cfc55794 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 6405236da1734..6c5464851937c 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -20,8 +20,13 @@ else() endif() endif() -if(Patch_FOUND AND WIN32) - set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch) +if(Patch_FOUND) + if (WIN32) + set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_windows.patch && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch) + else() + set(ABSL_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/abseil/absl_cuda_warnings.patch) + endif() else() set(ABSL_PATCH_COMMAND "") endif() diff --git a/cmake/external/cuda_configuration.cmake b/cmake/external/cuda_configuration.cmake index be6a5febf3e14..00f7d81eda53d 100644 --- a/cmake/external/cuda_configuration.cmake +++ b/cmake/external/cuda_configuration.cmake @@ 
-85,6 +85,11 @@ macro(setup_cuda_architectures) # * Always use accelerated (`-a` suffix) target for supported real architectures. # cmake-format: on + # Allow override via CUDAARCHS environment variable (standard CMake variable) + if(NOT CMAKE_CUDA_ARCHITECTURES AND DEFINED ENV{CUDAARCHS}) + set(CMAKE_CUDA_ARCHITECTURES "$ENV{CUDAARCHS}") + endif() + if(CMAKE_CUDA_ARCHITECTURES STREQUAL "native") # Detect highest available compute capability set(OUTPUTFILE ${PROJECT_BINARY_DIR}/detect_cuda_arch) @@ -139,12 +144,12 @@ macro(setup_cuda_architectures) continue() endif() - if(CUDA_ARCH MATCHES "^([1-9])([0-9])+a?-virtual$") + if(CUDA_ARCH MATCHES "^([1-9])([0-9])+[af]?-virtual$") set(CMAKE_CUDA_ARCHITECTURES_LAST_VIRTUAL ${CUDA_ARCH}) - elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?-real$") - list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) - elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)a?$") + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)[af]?-real$") list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}) + elseif(CUDA_ARCH MATCHES "^(([1-9])([0-9])+)([af]?)$") + list(APPEND CMAKE_CUDA_ARCHITECTURES_CLEAN ${CMAKE_MATCH_1}${CMAKE_MATCH_4}) else() message(FATAL_ERROR "Unrecognized CUDA architecture: ${CUDA_ARCH}") endif() @@ -156,7 +161,7 @@ macro(setup_cuda_architectures) set(CMAKE_CUDA_ARCHITECTURES_ORIG "${CMAKE_CUDA_ARCHITECTURES}") message(STATUS "GPU architectures: ${CMAKE_CUDA_ARCHITECTURES_ORIG}") - set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "120") + set(ARCHITECTURES_WITH_KERNELS "80" "86" "89" "90" "100" "110" "120") foreach(CUDA_ARCH IN LISTS ARCHITECTURES_WITH_KERNELS) if(NOT "${CUDA_ARCH}" IN_LIST CMAKE_CUDA_ARCHITECTURES_ORIG) add_definitions("-DEXCLUDE_SM_${CUDA_ARCH}") @@ -165,10 +170,13 @@ macro(setup_cuda_architectures) endforeach() # Enable accelerated features (like WGMMA, TMA and setmaxnreg) for SM >= 90. 
- set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "120") + set(ARCHITECTURES_WITH_ACCEL "90" "100" "101" "110" "120") unset(CMAKE_CUDA_ARCHITECTURES_NORMALIZED) foreach(CUDA_ARCH IN LISTS CMAKE_CUDA_ARCHITECTURES) - if("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL) + if(CUDA_ARCH MATCHES "^([0-9]+)f$") + # Family code, no -real suffix + list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}") + elseif("${CUDA_ARCH}" IN_LIST ARCHITECTURES_WITH_ACCEL) list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}a-real") else() list(APPEND CMAKE_CUDA_ARCHITECTURES_NORMALIZED "${CUDA_ARCH}-real") diff --git a/cmake/external/emsdk b/cmake/external/emsdk index b2436aafa7351..c0bb220cb6e6f 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit b2436aafa7351ee1b581f15841f1b45ed716a279 +Subproject commit c0bb220cb6e6f4e0fabb6f6db9efd53390ef5e56 diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index 3c616684fb296..9feb7772d1e88 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -845,6 +845,12 @@ if(onnxruntime_USE_KLEIDIAI) onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) onnxruntime_fetchcontent_makeavailable(kleidiai) + # Fetch Qualcomm's kleidiai library + if(onnxruntime_USE_QMX_KLEIDIAI_COEXIST) + onnxruntime_fetchcontent_declare(kleidiai-qmx URL ${DEP_URL_kleidiai-qmx} URL_HASH SHA1=${DEP_SHA1_kleidiai-qmx} + EXCLUDE_FROM_ALL) + onnxruntime_fetchcontent_makeavailable(kleidiai-qmx) + endif() endif() set(onnxruntime_LINK_DIRS) diff --git a/cmake/external/pybind11.cmake b/cmake/external/pybind11.cmake index 79280c97a899e..ba14667bc3c88 100644 --- a/cmake/external/pybind11.cmake +++ b/cmake/external/pybind11.cmake @@ -6,7 +6,6 @@ onnxruntime_fetchcontent_declare( URL ${DEP_URL_pybind11} URL_HASH SHA1=${DEP_SHA1_pybind11} EXCLUDE_FROM_ALL - 
FIND_PACKAGE_ARGS 2.13 NAMES pybind11 + FIND_PACKAGE_ARGS 3.0 NAMES pybind11 ) onnxruntime_fetchcontent_makeavailable(pybind11_project) - diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index c0ab948b41fff..d7dcde945e6d7 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -45,6 +45,8 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp ${MLAS_SRC_DIR}/qnbitgemm.h ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/qlutgemm.h + ${MLAS_SRC_DIR}/qlutgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp ${MLAS_SRC_DIR}/cast.cpp @@ -113,6 +115,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp ) set(mlas_platform_preprocess_srcs @@ -209,6 +212,8 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp @@ -284,6 +289,11 @@ function(setup_kleidiai) ) target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai) + if(onnxruntime_USE_QMX_KLEIDIAI_COEXIST) + target_link_libraries(onnxruntime_mlas PRIVATE kleidiai-qmx) + target_compile_definitions(onnxruntime_mlas PRIVATE ENABLE_QMX_KERNELS=1) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai-qmx) + endif() set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE) # If KLEIDIAI_DEBUG is enabled that implies both DEBUG and KERNEL messages. 
@@ -302,13 +312,21 @@ function(setup_kleidiai) RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() + + if(onnxruntime_USE_QMX_KLEIDIAI_COEXIST) + install(TARGETS kleidiai-qmx EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() endfunction() function (setup_arm_neon_nchwc) target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/sconv.h - ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp - ${MLAS_SRC_DIR}/spool_kernel_neon.cpp + ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.h + ${MLAS_SRC_DIR}/sconv_nchwc_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_nchwc_kernel_neon.cpp ) list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) @@ -460,6 +478,7 @@ else() ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + ${MLAS_SRC_DIR}/sconv_nchw_kernel_neon.cpp ) # Conditionally add the SVE implementation if compiler supports it @@ -496,6 +515,7 @@ else() ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp ${MLAS_SRC_DIR}/cast_kernel_neon.cpp ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp @@ -511,6 +531,7 @@ else() set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbconv_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " 
-march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") @@ -693,6 +714,8 @@ else() ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.h + ${MLAS_SRC_DIR}/sqnbitgemm_lut_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index 02d923b9cbc10..686b6b6ba1228 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1212,6 +1212,13 @@ block() ${TEST_SRC_DIR}/common/tensor_op_test_utils.h ) + if (onnxruntime_USE_DNNL) + list(APPEND supporting_test_srcs + ${TEST_SRC_DIR}/common/dnnl_op_test_utils.cc + ${TEST_SRC_DIR}/common/dnnl_op_test_utils.h + ) + endif() + list(APPEND onnxruntime_provider_test_srcs ${supporting_test_srcs} ${onnxruntime_unittest_main_src} @@ -1553,8 +1560,13 @@ endif() onnxruntime_common ${CMAKE_DL_LIBS}) set_target_properties(onnxruntime_runtime_path_test_shared_library PROPERTIES AIX_SHARED_LIBRARY_ARCHIVE OFF) else() - target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE - onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + if (CPUINFO_SUPPORTED) + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common cpuinfo ${CMAKE_DL_LIBS}) + else() + target_link_libraries(onnxruntime_runtime_path_test_shared_library PRIVATE + onnxruntime_common ${CMAKE_DL_LIBS}) + endif() endif() 
target_include_directories(onnxruntime_runtime_path_test_shared_library PRIVATE ${ONNXRUNTIME_ROOT}) diff --git a/cmake/patches/abseil/absl_cuda_warnings.patch b/cmake/patches/abseil/absl_cuda_warnings.patch new file mode 100644 index 0000000000000..144b9f904bf0f --- /dev/null +++ b/cmake/patches/abseil/absl_cuda_warnings.patch @@ -0,0 +1,40 @@ +diff --git a/absl/hash/internal/hash.h b/absl/hash/internal/hash.h +index 1234567..abcdefg 100644 +--- a/absl/hash/internal/hash.h ++++ b/absl/hash/internal/hash.h +@@ -477,7 +477,7 @@ H AbslHashValue(H hash_state, T (&)[N]) { + template + H AbslHashValue(H hash_state, T (&)[N]) { + static_assert( +- sizeof(T) == -1, ++ sizeof(T) == size_t(-1), + "Hashing C arrays is not allowed. For string literals, wrap the literal " + "in absl::string_view(). To hash the array contents, use " + "absl::MakeSpan() or make the array an std::array. To hash the array " +diff --git a/absl/hash/hash.h b/absl/hash/hash.h +index 1234567..abcdefg 100644 +--- a/absl/hash/hash.h ++++ b/absl/hash/hash.h +@@ -333,7 +333,8 @@ class HashState : public hash_internal::HashStateBase { + absl::enable_if_t< + std::is_base_of, T>::value, int> = 0> + static HashState Create(T* state) { +- HashState s; ++ HashState s = {}; ++ (void)s; + s.Init(state); + return s; + } +diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h +index 1234567..abcdefg 100644 +--- a/absl/container/internal/raw_hash_set.h ++++ b/absl/container/internal/raw_hash_set.h +@@ -464,7 +464,7 @@ inline uint16_t NextSeed() { + inline uint16_t NextSeed() { + static_assert(PerTableSeed::kBitCount == 16); + thread_local uint16_t seed = +- static_cast(reinterpret_cast(&seed)); ++ static_cast(reinterpret_cast(&seed) & 0xFFFFu); + seed += uint16_t{0xad53}; + return seed; + } diff --git a/cmake/vcpkg-ports/abseil/absl_cuda_warnings.patch b/cmake/vcpkg-ports/abseil/absl_cuda_warnings.patch new file mode 100644 index 0000000000000..144b9f904bf0f --- /dev/null +++ 
b/cmake/vcpkg-ports/abseil/absl_cuda_warnings.patch @@ -0,0 +1,40 @@ +diff --git a/absl/hash/internal/hash.h b/absl/hash/internal/hash.h +index 1234567..abcdefg 100644 +--- a/absl/hash/internal/hash.h ++++ b/absl/hash/internal/hash.h +@@ -477,7 +477,7 @@ H AbslHashValue(H hash_state, T (&)[N]) { + template + H AbslHashValue(H hash_state, T (&)[N]) { + static_assert( +- sizeof(T) == -1, ++ sizeof(T) == size_t(-1), + "Hashing C arrays is not allowed. For string literals, wrap the literal " + "in absl::string_view(). To hash the array contents, use " + "absl::MakeSpan() or make the array an std::array. To hash the array " +diff --git a/absl/hash/hash.h b/absl/hash/hash.h +index 1234567..abcdefg 100644 +--- a/absl/hash/hash.h ++++ b/absl/hash/hash.h +@@ -333,7 +333,8 @@ class HashState : public hash_internal::HashStateBase { + absl::enable_if_t< + std::is_base_of, T>::value, int> = 0> + static HashState Create(T* state) { +- HashState s; ++ HashState s = {}; ++ (void)s; + s.Init(state); + return s; + } +diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h +index 1234567..abcdefg 100644 +--- a/absl/container/internal/raw_hash_set.h ++++ b/absl/container/internal/raw_hash_set.h +@@ -464,7 +464,7 @@ inline uint16_t NextSeed() { + inline uint16_t NextSeed() { + static_assert(PerTableSeed::kBitCount == 16); + thread_local uint16_t seed = +- static_cast(reinterpret_cast(&seed)); ++ static_cast(reinterpret_cast(&seed) & 0xFFFFu); + seed += uint16_t{0xad53}; + return seed; + } diff --git a/cmake/vcpkg-ports/abseil/portfile.cmake b/cmake/vcpkg-ports/abseil/portfile.cmake index 3cdedca7265ef..1e9c48ea834b2 100644 --- a/cmake/vcpkg-ports/abseil/portfile.cmake +++ b/cmake/vcpkg-ports/abseil/portfile.cmake @@ -9,6 +9,7 @@ vcpkg_from_github( SHA512 4ee1a217203933382e728d354a149253a517150eee7580a0abecc69584b2eb200d91933ef424487e3a3fe0e8ab5e77b0288485cac982171b3585314a4417e7d4 HEAD_REF master PATCHES absl_windows.patch + absl_cuda_warnings.patch 
) diff --git a/cmake/vcpkg-ports/pybind11/portfile.cmake b/cmake/vcpkg-ports/pybind11/portfile.cmake index 2c63582d1ee15..4e4cd30a26df1 100644 --- a/cmake/vcpkg-ports/pybind11/portfile.cmake +++ b/cmake/vcpkg-ports/pybind11/portfile.cmake @@ -2,7 +2,8 @@ vcpkg_from_github( OUT_SOURCE_PATH SOURCE_PATH REPO pybind/pybind11 REF "v${VERSION}" - SHA512 497c25b33b09a9c42f67131ab82e35d689e8ce089dd7639be997305ff9a6d502447b79c824508c455d559e61f0186335b54dd2771d903a7c1621833930622d1a + # SHA512 for the zip (not tar.gz) file. + SHA512 786b1bf534ac67a8d5669f8babf67bb13e48b3a3da1b6344e43ae10a84b80bbc8fea5f12a65fd18739c341fefef5622c5dc096db964dff33cc62ea4259b2e2c1 HEAD_REF master ) diff --git a/cmake/vcpkg-ports/pybind11/vcpkg.json b/cmake/vcpkg-ports/pybind11/vcpkg.json index a730d32017885..058e2235fea08 100644 --- a/cmake/vcpkg-ports/pybind11/vcpkg.json +++ b/cmake/vcpkg-ports/pybind11/vcpkg.json @@ -1,6 +1,6 @@ { "name": "pybind11", - "version": "2.13.6", + "version": "3.0.2", "description": "pybind11 is a lightweight header-only library that exposes C++ types in Python and vice versa, mainly to create Python bindings of existing C++ code", "homepage": "https://github.com/pybind/pybind11", "license": "BSD-3-Clause", diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs index 1ae7b5c9eb991..a6b267c6802cf 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/NativeMethods.shared.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using System; +using System.Reflection; using System.Runtime.InteropServices; using static Microsoft.ML.OnnxRuntime.NativeMethods; @@ -474,6 +475,12 @@ internal static class NativeMethods static NativeMethods() { +#if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__ + // Register a custom DllImportResolver to handle platform-specific library loading. 
+ // Replaces default resolution specifically on Windows for case-sensitivity. + NativeLibrary.SetDllImportResolver(typeof(NativeMethods).Assembly, DllImportResolver); +#endif + #if NETSTANDARD2_0 IntPtr ortApiBasePtr = OrtGetApiBase(); OrtApiBase ortApiBase = (OrtApiBase)Marshal.PtrToStructure(ortApiBasePtr, typeof(OrtApiBase)); @@ -847,7 +854,7 @@ static NativeMethods() api_.CreateSyncStreamForEpDevice, typeof(DOrtCreateSyncStreamForEpDevice)); - OrtSyncStream_GetHandle = + OrtSyncStream_GetHandle = (DOrtSyncStream_GetHandle)Marshal.GetDelegateForFunctionPointer( api_.SyncStream_GetHandle, typeof(DOrtSyncStream_GetHandle)); @@ -872,10 +879,142 @@ internal class NativeLib // Define the library name required for iOS internal const string DllName = "__Internal"; #else - // Note: the file name in ONNX Runtime nuget package must be onnxruntime.dll instead of onnxruntime.DLL(Windows filesystem can be case sensitive) - internal const string DllName = "onnxruntime.dll"; + // For desktop platforms (including .NET Standard 2.0), we use the simple name + // to allow .NET's automatic platform-specific resolution (lib*.so, lib*.dylib, *.dll). + // For .NET Core 3.0+, case-sensitivity on Windows is handled by DllImportResolver. + internal const string DllName = "onnxruntime"; +#endif + } + +#if !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__ + /// + /// Custom DllImportResolver to handle platform-specific library loading. + /// On Windows, it explicitly loads the library with a lowercase .dll extension to handle + /// case-sensitive filesystems. + /// +#if NET5_0_OR_GREATER + [System.Diagnostics.CodeAnalysis.UnconditionalSuppressMessage("SingleFile", "IL3000:Avoid accessing Assembly file path when publishing as a single file", Justification = "We also check AppContext.BaseDirectory as a fallback")] #endif + private static IntPtr DllImportResolver(string libraryName, Assembly assembly, DllImportSearchPath? 
searchPath) + { + try + { + if (libraryName == NativeLib.DllName || libraryName == OrtExtensionsNativeMethods.ExtensionsDllName) + { + string mappedName = null; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + // Explicitly load with .dll extension to avoid issues where the OS might try .DLL + mappedName = libraryName + ".dll"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + // Explicitly load with .so extension and lib prefix + mappedName = "lib" + libraryName + ".so"; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + // Explicitly load with .dylib extension and lib prefix + mappedName = "lib" + libraryName + ".dylib"; + } + + if (mappedName != null) + { + // 1. Try default loading (name only) + if (NativeLibrary.TryLoad(mappedName, assembly, searchPath, out IntPtr handle)) + { + return handle; + } + + // 2. Try relative to assembly location (look into runtimes subfolders) + string assemblyLocation = null; + try { assemblyLocation = assembly.Location; } catch { } + if (!string.IsNullOrEmpty(assemblyLocation)) + { + string assemblyDir = System.IO.Path.GetDirectoryName(assemblyLocation); + string rid = RuntimeInformation.RuntimeIdentifier; + + // Probe the specific RID first, then common fallbacks for the current OS + string[] ridsToTry; + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + ridsToTry = new[] { rid, "win-x64", "win-arm64" }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) + { + ridsToTry = new[] { rid, "linux-x64", "linux-arm64" }; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + // We no longer provide osx-x64 in official package since 1.24. + // However, we keep it in the list for build-from-source users. 
+ ridsToTry = new[] { rid, "osx-arm64", "osx-x64" }; + } + else + { + ridsToTry = new[] { rid }; + } + + foreach (var tryRid in ridsToTry) + { + string probePath = System.IO.Path.Combine(assemblyDir, "runtimes", tryRid, "native", mappedName); + if (System.IO.File.Exists(probePath) && NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + } + } + + // 3. Try AppContext.BaseDirectory as a fallback + try + { + string baseDir = AppContext.BaseDirectory; + if (!string.IsNullOrEmpty(baseDir)) + { + string probePath = System.IO.Path.Combine(baseDir, mappedName); + if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + + string rid = RuntimeInformation.RuntimeIdentifier; + probePath = System.IO.Path.Combine(baseDir, "runtimes", rid, "native", mappedName); + if (NativeLibrary.TryLoad(probePath, assembly, searchPath, out handle)) + { + LogLibLoad($"[DllImportResolver] Loaded {mappedName} from: {probePath}"); + return handle; + } + } + } + catch { } // Ignore AppDomainUnloadedException or similar from AppContext.BaseDirectory + + LogLibLoad($"[DllImportResolver] Failed loading {mappedName} (RID: {RuntimeInformation.RuntimeIdentifier}, Assembly: {assemblyLocation})"); + + } + } + } + catch (Exception ex) + { + // Unhandled exceptions inside DllImportResolver can result in TypeInitializationException. + // Log and swallow the error, returning IntPtr.Zero to fall back to default CLR logic. 
+ try { System.Diagnostics.Trace.WriteLine($"[DllImportResolver] Exception during resolution: {ex}"); } catch { } + } + + // Fall back to default resolution + return IntPtr.Zero; + } + + private static void LogLibLoad(string message) + { + System.Diagnostics.Trace.WriteLine(message); + if (!string.IsNullOrEmpty(Environment.GetEnvironmentVariable("ORT_LOADER_VERBOSITY"))) + { + Console.WriteLine(message); + } } +#endif [DllImport(NativeLib.DllName, CharSet = CharSet.Ansi)] #if NETSTANDARD2_0 @@ -2644,7 +2783,7 @@ public delegate void DOrtAddKeyValuePair(IntPtr /* OrtKeyValuePairs* */ kvps, byte[] /* const char* */ value); /// - /// Get the value for the provided key. + /// Get the value for the provided key. /// /// Value. Returns IntPtr.Zero if key was not found. [UnmanagedFunctionPointer(CallingConvention.Winapi)] @@ -2767,7 +2906,7 @@ out IntPtr /* OrtSyncStream** */ stream // Auto Selection EP registration and selection customization /// - /// Register an execution provider library. + /// Register an execution provider library. /// The library must implement CreateEpFactories and ReleaseEpFactory. /// /// Environment to add the EP library to. @@ -2952,9 +3091,10 @@ internal static class OrtExtensionsNativeMethods #elif __IOS__ internal const string ExtensionsDllName = "__Internal"; #else - // For desktop platforms, explicitly specify the DLL name with extension to avoid - // issues on case-sensitive filesystems. See NativeLib.DllName for detailed explanation. - internal const string ExtensionsDllName = "ortextensions.dll"; + // For desktop platforms, use the simple name to allow .NET's + // automatic platform-specific resolution (lib*.so, lib*.dylib, *.dll). + // Case-sensitivity on Windows is handled by DllImportResolver. 
+ internal const string ExtensionsDllName = "ortextensions"; #endif [DllImport(ExtensionsDllName, CharSet = CharSet.Ansi, diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props.xml b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props.xml index efe5c659f250a..c3cd38c9cd56b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props.xml +++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props.xml @@ -28,14 +28,7 @@ - - - $(MSBuildThisFileDirectory)../../runtimes/win-x86/native/onnxruntime.lib;%(AdditionalDependencies) - - - - x86 arm64 arm $(Platform) @@ -120,7 +113,8 @@ + Condition="'$(PlatformTarget)' == 'ARM64' AND + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll')"> onnxruntime.dll PreserveNewest false @@ -135,7 +129,8 @@ + Condition="'$(PlatformTarget)' == 'ARM' AND + Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime.dll')"> onnxruntime.dll PreserveNewest false @@ -147,34 +142,5 @@ PreserveNewest false - - - - onnxruntime.dll - PreserveNewest - false - - - dnnl.dll - PreserveNewest - false - - - mklml.dll - PreserveNewest - false - - - libiomp5md.dll - PreserveNewest - false - diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml index 83ffb22ccf6b2..c1ad99a778a67 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml +++ b/csharp/src/Microsoft.ML.OnnxRuntime/targets/netstandard/props_qnn.xml @@ -28,14 +28,7 @@ - - - $(MSBuildThisFileDirectory)../../runtimes/win-x86/native/onnxruntime.lib;%(AdditionalDependencies) - - - - x86 arm64 arm $(Platform) @@ -91,13 +84,5 @@ PreserveNewest false - - - - onnxruntime.dll - PreserveNewest - false - diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs index 94f8e927c1331..aa1b683acd668 
100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/OrtEnvTests.cs @@ -489,4 +489,47 @@ void TestCopyTensors() } } } + + [Collection("Ort Inference Tests")] + public class OrtEnvDllImportResolverTest + { + [Fact(DisplayName = "TestDllImportResolverDoesNotThrow")] + public void TestDllImportResolverDoesNotThrow() + { + // The DllImportResolver is a private static method in NativeMethods. + var nativeMethodsType = typeof(OrtEnv).Assembly.GetType("Microsoft.ML.OnnxRuntime.NativeMethods"); + Assert.NotNull(nativeMethodsType); + + // It might not be defined on all platforms (defined when !NETSTANDARD2_0 && !__ANDROID__ && !__IOS__). + var resolverMethod = nativeMethodsType.GetMethod("DllImportResolver", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static); + + if (resolverMethod != null) + { + try + { + // Invoke with null assembly to force it into edge cases where assembly.Location would throw NullReferenceException. + // It should catch the exception and return IntPtr.Zero gracefully rather than throwing. + var result = resolverMethod.Invoke(null, new object[] { "onnxruntime", null, null }); + + // If it reaches here without throwing TargetInvocationException, the try-catch in DllImportResolver works. + Assert.True(result is IntPtr); + } + catch (System.Reflection.TargetInvocationException ex) + { + // If NativeMethods..cctor() threw because the native library is missing, + // we will get a TypeInitializationException wrapping a DllNotFoundException (or DllImportException). + // This is acceptable locally. What we want to avoid is NullReferenceException from DllImportResolver. 
+                    if (ex.InnerException is TypeInitializationException typeInitEx)
+                    {
+                        Assert.IsNotType<NullReferenceException>(typeInitEx.InnerException);
+                    }
+                    else
+                    {
+                        Assert.IsNotType<NullReferenceException>(ex.InnerException);
+                        throw;
+                    }
+                }
+            }
+        }
+    }
 }
diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
index f0d1313783643..c0475bb6102c1 100644
--- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
+++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.NetCoreApp/InferenceTest.netcore.cs
@@ -601,6 +601,29 @@ private static Dictionary<string, string> GetSkippedModels(DirectoryInfo modelsD
                 skipModels["VGG 16-fp32"] = "bad allocation";
             }
 
+            // The following models are from onnx repo and fail on MacOS nuget test pipeline.
+            if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
+            {
+                var macOSSkips = new[]
+                {
+                    "test_castlike_FLOAT_to_STRING_expanded",
+                    "test_castlike_FLOAT_to_BFLOAT16_expanded",
+                    "test_castlike_BFLOAT16_to_FLOAT",
+                    "test_cast_FLOAT_to_STRING",
+                    "test_castlike_FLOAT_to_BFLOAT16",
+                    "test_castlike_STRING_to_FLOAT_expanded",
+                    "test_castlike_STRING_to_FLOAT",
+                    "test_cast_STRING_to_FLOAT",
+                    "test_castlike_BFLOAT16_to_FLOAT_expanded",
+                    "test_cast_BFLOAT16_to_FLOAT",
+                    "test_castlike_FLOAT_to_STRING"
+                };
+                foreach (var model in macOSSkips)
+                {
+                    skipModels[model] = "Skipped on macOS due to flakes or lack of support";
+                }
+            }
+
             return skipModels;
         }
 
@@ -934,6 +957,7 @@ public void TestPretrainedModelsWithOrtValue(string opsetDir, string modelName)
         [MemberData(nameof(GetSkippedModelForTest), Skip = "Skipped due to Error, please fix the error and enable the test")]
         private void TestPreTrainedModels(string opsetDir, string modelName, bool useOrtValueAPIs = false)
         {
+            var opsetDirInfo = new DirectoryInfo(opsetDir);
             var opset = opsetDirInfo.Name;
             string onnxModelFileName = null;
diff --git a/dockerfiles/Dockerfile.source b/dockerfiles/Dockerfile.source index 
ea28e144ee95a..51291e59aa0d5 100644 --- a/dockerfiles/Dockerfile.source +++ b/dockerfiles/Dockerfile.source @@ -16,4 +16,4 @@ RUN cd /code && /bin/bash ./build.sh --allow_running_as_root --skip_submodule_sy FROM mcr.microsoft.com/azurelinux/base/python:3 COPY --from=0 /code/build/Linux/Release/dist /root COPY --from=0 /code/dockerfiles/LICENSE-IMAGE.txt /code/LICENSE-IMAGE.txt -RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install coloredlogs humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl +RUN tdnf install -y ca-certificates python3-setuptools python3-wheel python3-pip python3-numpy python3-flatbuffers python3-packaging python3-protobuf python3-mpmath python3-sympy && python3 -m pip install humanfriendly && python3 -m pip install --no-index --find-links /root onnxruntime && rm -rf /root/*.whl diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index ebed2b1972ba9..44569dd7e5eff 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -60,13 +60,13 @@ Do not modify directly.* |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BlackmanWindow|*in* size:**T1**
*out* output:**T2**|17+|**T1** = tensor(int32), tensor(int64)
**T2** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| -|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|25+|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||24|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||23|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[21, 22]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[19, 20]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[13, 18]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|||[6, 12]|**T1** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)| |||[6, 12]|**T** = tensor(double), tensor(float)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float)| @@ -106,7 +106,7 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(uint8)| |||[1, 10]|**T** = tensor(double), tensor(float)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T3**|25+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T3**|25+|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint2), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||24|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||23|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| |||[21, 22]|**T1** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)
**T2** = tensor(float), tensor(float16)| @@ -313,9 +313,9 @@ Do not modify directly.* |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**TS**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**TS**
*in* b_zero_point:**T2**
*in* y_scale:**TS**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|21+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**TS** = tensor(float)| |||[10, 20]|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**T2**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|25+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| -|||24|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| -|||23|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**T2**
*in* y_zero_point:**T3**
*out* y:**T3**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|25+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**T3** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int4), tensor(int8), tensor(uint16), tensor(uint2), tensor(uint4), tensor(uint8)| +|||24|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**T3** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| +|||23|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float), tensor(float16)
**T3** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[21, 22]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)| |||[19, 20]|**T1** = tensor(float), tensor(float16)
**T2** = tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int8), tensor(uint8)| |||[13, 18]|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| @@ -503,7 +503,7 @@ Do not modify directly.* |||[11, 23]|**I** = tensor(int64)
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||10|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| |||[1, 9]|**I** = tensor(int64)
**T** = tensor(double), tensor(float)| -|Transpose|*in* data:**T**
*out* transposed:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| +|Transpose|*in* data:**T**
*out* transposed:**T**|25+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int2), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint2), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||24|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||23|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| |||[21, 22]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)| diff --git a/docs/python/README.rst b/docs/python/README.rst index f610b36958fe1..06f5b0ebf3094 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,10 +8,20 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime & AllTensorTypesIRv9(); static const std::vector& AllTensorTypesIRv10(); static const std::vector& AllTensorTypesIRv11(); + static const std::vector& AllTensorTypesIRv13(); static const std::vector& AllFixedSizeTensorTypes(); // up to IR4 (no float 8), deprecated static const std::vector& 
AllFixedSizeTensorTypesIRv4(); @@ -285,7 +287,7 @@ template struct IsTensorContainedType : public IsAnyOf struct IsSparseTensorContainedType : public IsAnyOf(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -171,6 +177,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -230,6 +242,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -287,6 +305,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -355,6 +379,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", 
tensor_type); \ } @@ -421,6 +451,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -477,6 +513,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } @@ -531,6 +573,12 @@ namespace utils { case ONNX_NAMESPACE::TensorProto_DataType_UINT4: \ retval = function(__VA_ARGS__); \ break; \ + case ONNX_NAMESPACE::TensorProto_DataType_INT2: \ + retval = function(__VA_ARGS__); \ + break; \ + case ONNX_NAMESPACE::TensorProto_DataType_UINT2: \ + retval = function(__VA_ARGS__); \ + break; \ default: \ ORT_ENFORCE(false, "Unknown tensor type of ", tensor_type); \ } diff --git a/include/onnxruntime/core/framework/int2.h b/include/onnxruntime/core/framework/int2.h new file mode 100644 index 0000000000000..0d406d6fcd8d3 --- /dev/null +++ b/include/onnxruntime/core/framework/int2.h @@ -0,0 +1,178 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+
+#include <cassert>
+#include <cstdint>
+#include "core/common/common.h"
+#include <gsl/gsl>
+
+namespace onnxruntime {
+
+template <bool Signed>
+struct Int2Traits;
+
+template <>
+struct Int2Traits<true> {
+  using UnpackedType = int8_t;
+  static constexpr int8_t min_val = -2;
+  static constexpr int8_t max_val = 1;
+};
+
+template <>
+struct Int2Traits<false> {
+  using UnpackedType = uint8_t;
+  static constexpr uint8_t min_val = 0;
+  static constexpr uint8_t max_val = 3;
+};
+
+/// <summary>
+/// Stores 4 packed 2-bit elements in 1 byte.
+/// Packing follows ONNX spec: x0 | (x1 << 2) | (x2 << 4) | (x3 << 6)
+/// </summary>
+/// <typeparam name="Signed">Set to true if signed int2, or false if unsigned uint2.</typeparam>
+template <bool Signed>
+struct Int2x4Base {
+  using UnpackedType = typename Int2Traits<Signed>::UnpackedType;
+  static constexpr UnpackedType min_val = Int2Traits<Signed>::min_val;
+  static constexpr UnpackedType max_val = Int2Traits<Signed>::max_val;
+
+  std::byte bits_{};
+
+  Int2x4Base() = default;
+
+  explicit Int2x4Base(std::byte bits) {
+    bits_ = bits;
+  }
+
+  Int2x4Base(UnpackedType val0, UnpackedType val1, UnpackedType val2, UnpackedType val3) {
+    bits_ = static_cast<std::byte>(
+        (val0 & 0x3) |
+        ((val1 & 0x3) << 2) |
+        ((val2 & 0x3) << 4) |
+        ((val3 & 0x3) << 6));
+  }
+
+  static inline int8_t SignExtendLower2Bits(std::byte bits) {
+    // Sign-extend lower 2-bits by left shifting and then doing an arithmetic right shift.
+    constexpr uint8_t shift = (sizeof(int32_t) * 8) - 2;
+    return static_cast<int8_t>((static_cast<int32_t>(bits) << shift) >> shift);
+  }
+
+  inline UnpackedType GetElem(size_t index) const {
+    assert(index <= 3);
+    const uint8_t shift = 2 * static_cast<uint8_t>(index);
+    const std::byte val = (bits_ >> shift) & std::byte{0x3};
+
+    if constexpr (Signed) {
+      return SignExtendLower2Bits(val);
+    } else {
+      return static_cast<UnpackedType>(val);
+    }
+  }
+
+  inline void SetElem(size_t index, UnpackedType val) {
+    assert(index <= 3);
+    const uint8_t shift = 2 * static_cast<uint8_t>(index);
+    const std::byte clear_mask = ~(std::byte{0x3} << shift);
+
+    bits_ &= clear_mask;                                    // Clear 2-bit element to 0
+    bits_ |= static_cast<std::byte>((val & 0x3) << shift);  // Set 2-bit element to val
+  }
+
+  inline std::byte ToBits() const {
+    return bits_;
+  }
+
+  /// <summary>
+  /// Calculates the number of packed byte units needed to store the given number of 2-bit elements.
+  /// Each byte stores 4 x 2-bit elements.
+  /// </summary>
+  static size_t CalcNumInt2Quads(size_t num_int2_elems) {
+    return (num_int2_elems + 3) / 4;
+  }
+
+  /// <summary>
+  /// Copy a source buffer of 2-bit elements (packed) into a destination buffer of 8-bit elements (unpacked).
+  /// </summary>
+  /// <param name="dst">Destination buffer to store unpacked 8-bit elements</param>
+  /// <param name="src">Source buffer with 2-bit elements</param>
+  /// <returns>True on success</returns>
+  static bool Unpack(gsl::span<UnpackedType> dst, gsl::span<const Int2x4Base<Signed>> src) {
+    if (CalcNumInt2Quads(dst.size()) != src.size()) {
+      return false;
+    }
+
+    if (src.empty()) {
+      return true;
+    }
+
+    for (size_t i = 0; i < dst.size(); i++) {
+      size_t byte_idx = i >> 2;   // i / 4
+      size_t elem_idx = i & 0x3;  // i % 4
+      dst[i] = src[byte_idx].GetElem(elem_idx);
+    }
+
+    return true;
+  }
+
+  /// <summary>
+  /// Copy a source buffer of 8-bit elements (unpacked) into a destination buffer of 2-bit elements (packed).
+  /// </summary>
+  /// <param name="dst">Destination buffer to store packed 2-bit elements</param>
+  /// <param name="src">Source buffer with 8-bit elements</param>
+  /// <returns>True on success</returns>
+  static bool Pack(gsl::span<Int2x4Base<Signed>> dst, gsl::span<const UnpackedType> src) {
+    if (CalcNumInt2Quads(src.size()) != dst.size()) {
+      return false;
+    }
+
+    if (src.empty()) {
+      return true;
+    }
+
+    size_t src_i = 0;
+    size_t dst_i = 0;
+    const size_t full_quads = src.size() / 4;
+
+    // Process complete groups of 4 elements
+    for (; dst_i < full_quads; dst_i++) {
+      dst[dst_i] = Int2x4Base(src[src_i], src[src_i + 1], src[src_i + 2], src[src_i + 3]);
+      src_i += 4;
+    }
+
+    // Handle remaining elements (1-3)
+    if (src_i < src.size()) {
+      UnpackedType vals[4] = {0, 0, 0, 0};
+      size_t remaining = src.size() - src_i;
+      for (size_t j = 0; j < remaining; j++) {
+        vals[j] = src[src_i + j];
+      }
+      dst[dst_i] = Int2x4Base(vals[0], vals[1], vals[2], vals[3]);
+    }
+
+    return true;
+  }
+
+  /// <summary>
+  /// Returns hierarchical indices for a packed int2 element from the given element index.
+  ///
+  /// Usage:
+  ///   Int2x4* data = ...;
+  ///   auto indices = GetTensorElemIndices(5);  // 6th int2 element
+  ///   int8_t elem = data[indices.first].GetElem(indices.second);
+  /// </summary>
+  /// <param name="index">Index of 2-bit element</param>
+  /// <returns>Pair of (byte_index, element_index_within_byte)</returns>
+  static inline std::pair<size_t, size_t> GetTensorElemIndices(size_t index) {
+    return {index >> 2, index & 0x3};
+  }
+};
+
+using Int2x4 = Int2x4Base<true>;
+using UInt2x4 = Int2x4Base<false>;
+static_assert(sizeof(Int2x4) == sizeof(std::byte));
+static_assert(sizeof(UInt2x4) == sizeof(std::byte));
+
+}  // namespace onnxruntime
diff --git a/include/onnxruntime/core/framework/run_options.h b/include/onnxruntime/core/framework/run_options.h
index e63ab044834f5..001fa158345ab 100644
--- a/include/onnxruntime/core/framework/run_options.h
+++ b/include/onnxruntime/core/framework/run_options.h
@@ -51,6 +51,11 @@ struct OrtRunOptions {
   onnxruntime::InlinedVector<const lora::LoraAdapter*> active_adapters;
 
+  // Optional sync stream for external resource import.
+ // When set, the EP uses this stream for execution, enabling proper + // synchronization with imported external semaphores. + OrtSyncStream* sync_stream = nullptr; + OrtRunOptions() = default; ~OrtRunOptions() = default; }; diff --git a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h index e1b5e614d095d..82aefe0165fcc 100644 --- a/include/onnxruntime/core/framework/to_tensor_proto_element_type.h +++ b/include/onnxruntime/core/framework/to_tensor_proto_element_type.h @@ -13,6 +13,7 @@ #include "core/framework/float4.h" #include "core/common/float8.h" #include "core/common/float16.h" +#include "core/framework/int2.h" #include "core/framework/int4.h" namespace onnxruntime { @@ -116,5 +117,14 @@ constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType return ONNX_NAMESPACE::TensorProto_DataType_UINT4; } +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_INT2; +} +template <> +constexpr ONNX_NAMESPACE::TensorProto_DataType ToTensorProtoElementType() { + return ONNX_NAMESPACE::TensorProto_DataType_UINT2; +} + } // namespace utils } // namespace onnxruntime diff --git a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h index 3df33df06acb0..2b61890225888 100644 --- a/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h +++ b/include/onnxruntime/core/providers/utils/ort_graph_to_proto.h @@ -122,6 +122,7 @@ #define INCLUDE_ONNXRUNTIME_CORE_PROVIDERS_UTILS_ORT_GRAPH_TO_PROTO_H_ #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "onnx/onnx_pb.h" @@ -317,9 +318,11 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph, // Don't add graph inputs or graph outputs to GraphProto's list of value_infos. 
// Do add initializers (constant and non-constant) to GraphProto's list of initializer tensors. - // For values defined in an outer scope, just add the value info but not the initializer. if (is_from_outer_scope) { value_infos.emplace(value_name, ort_value_info); + if (is_constant_initializer) { + initializer_value_infos.emplace(value_name, ort_value_info); + } } else if (is_optional_graph_input) { initializer_value_infos.emplace(value_name, ort_value_info); } else if (is_constant_initializer) { @@ -413,6 +416,16 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph, ORT_EP_UTILS_CXX_RETURN_IF_ERROR(OrtValueInfoToProto(value_info, *value_info_proto)); } + // There may be initializers in the original OrtGraph that have not been added yet. + // For example, an initializer may not be used by any node but is still a graph output. + // Iterating through all nodes to collect initializer value info is therefore not sufficient, + // initializers must also be obtained from ort_graph.GetInitializers(). + // Add those missing initializers and skip the ones that already in `initializer_value_infos` + std::vector ort_graph_initializers = ort_graph.GetInitializers(); + for (const auto& initializer : ort_graph_initializers) { + initializer_value_infos.emplace(initializer.GetName(), initializer); + } + // Add initializers to GraphProto as TensorProto objects. for (const auto& [initializer_name, initializer_value_info] : initializer_value_infos) { std::vector initializer_dims; @@ -490,10 +503,7 @@ Ort::Status OrtGraphToProto(const OrtGraph& graph, onnx::ModelProto& model_proto, HandleInitializerDataFunc handle_initializer_data_func) { try { - // Check that OrtGraph is a top-level graph (no parent node). Ort::ConstGraph ort_graph{&graph}; - Ort::ConstNode parent_node = ort_graph.GetParentNode(); - ORT_EP_UTILS_C_RETURN_IF(parent_node != nullptr, "Cannot serialize nested OrtGraph into a ModelProto"); // Set model description. 
model_proto.set_doc_string("Serialized from OrtGraph"); diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 410b63147a8fe..6ae1539d4c294 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -209,6 +209,9 @@ typedef enum ONNXTensorElementDataType { ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4, // maps to a pair of packed int4 values (size == 1 byte) // Float4 types were introduced in ONNX 1.18. See https://onnx.ai/onnx/technical/float4.html ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1, // maps to a pair of packed float4 values (size == 1 byte) + // Int2 types were introduced in ONNX 1.20. See https://onnx.ai/onnx/technical/int2.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2, // maps to 4 packed uint2 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2, // maps to 4 packed int2 values (size == 1 byte) } ONNXTensorElementDataType; // Synced with onnx TypeProto oneof @@ -334,6 +337,8 @@ ORT_RUNTIME_CLASS(ExternalResourceImporter); // Capability object for external ORT_RUNTIME_CLASS(ExternalMemoryHandle); // EP-imported view of shared external allocation ORT_RUNTIME_CLASS(ExternalSemaphoreHandle); // EP-imported view of shared external semaphore ORT_RUNTIME_CLASS(DeviceEpIncompatibilityDetails); +ORT_RUNTIME_CLASS(EpAssignedSubgraph); +ORT_RUNTIME_CLASS(EpAssignedNode); #ifdef _MSC_VER typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; @@ -965,10 +970,6 @@ typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options */ typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); -/** \addtogroup Global - * @{ - */ - /** \brief External memory handle type for importing GPU resources. 
* * \todo Add OPAQUE_WIN32 for Windows Vulkan-specific memory handles @@ -1036,8 +1037,6 @@ typedef struct OrtExternalTensorDescriptor { Enables multiple tensors from the same imported memory handle. */ } OrtExternalTensorDescriptor; -/// @} - /* * Public enum for compiled model compatibility across EPs. */ @@ -1063,8 +1062,8 @@ typedef struct OrtEnvCreationOptions { * \note Logging messages which are less severe than the `logging_severity_level` are not emitted. * * \note Serves as the default logging severity level for session creation and runs. - * Use ::SetSessionLogSeverityLevel() to set a logging severity level for the creation of specific session. - * Use ::RunOptionsSetRunLogSeverityLevel() to set a logging severity level for a specific session run. + * Use OrtApi::SetSessionLogSeverityLevel to set a logging severity level for the creation of specific session. + * Use OrtApi::RunOptionsSetRunLogSeverityLevel to set a logging severity level for a specific session run. * * \since Version 1.24. */ @@ -1117,7 +1116,7 @@ typedef struct OrtEnvCreationOptions { * \note Refer to onnxruntime_env_config_keys.h for common config entry keys and their supported values. * * \note An application provides environment-level configuration options for execution provider libraries by - * using keys with the prefix 'ep_factory..'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents + * using keys with the prefix 'ep_factory.\\.'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents * a key named 'some_ep_key' that is meant to be consumed by an execution provider named 'my_ep'. Refer to * the specific execution provider's documentation for valid keys and values. * @@ -6095,7 +6094,8 @@ struct OrtApi { /** \brief Returns an OrtGraph that contains a subset of nodes in the source OrtGraph. * * \note The lifetime of "dst_graph" is tied to that of "src_graph", as they both internally reference - * the same underlying graph. + * the same underlying graph. 
"dst_graph" preserves the input order of "src_graph", and + * its output order corresponds to the outputs produced by the nodes in "nodes" with the given order. * * \param[in] src_graph The source OrtGraph instance. * \param[in] nodes A subset of the nodes/OrtNodes in 'graph'. @@ -6384,6 +6384,9 @@ struct OrtApi { /** \brief Get the node's parent OrtGraph instance. * * Can return NULL if the OrtNode was created without an owning graph. + * In another case, this API may also return NULL if `node` is obtained by calling Graph_GetParentNode() + * on an OrtGraph that is a subgraph of a control-flow op, and the parent graph has not been created yet, + * for example during ORT's GetCapability() when processing the innermost subgraph. * * \param[in] node The OrtNode instance. * \param[out] graph Output parameter set to the node's OrtGraph. Can be set to NULL @@ -7000,6 +7003,77 @@ struct OrtApi { /// @} + /// \name Model Compatibility APIs + /// @{ + + /** \brief Extract EP compatibility info from a precompiled model file. + * + * Parses the model file to extract the compatibility info string for a specific execution provider + * from the model's metadata properties. This is only applicable to models that have been precompiled + * for an EP (e.g., via OrtCompileApi). Standard ONNX models do not contain this information. + * + * The compatibility info string must be valid UTF-8 without embedded NUL characters. + * + * \note This API performs standalone model parsing, separate from session creation. This means + * the protobuf parsing cost is incurred here and again during session creation. It is intended + * for scenarios where applications need to check compatibility before deciding whether to proceed + * with session creation, such as providing early user feedback. + * + * \note This operation parses the full ONNX ModelProto from disk. For very large models, consider + * using GetCompatibilityInfoFromModelBytes with a pre-loaded buffer if the model is already in memory. 
+ * + * The compatibility info can then be passed to GetModelCompatibilityForEpDevices to check if a + * precompiled model is compatible with the current system. + * + * \param[in] model_path Path to the ONNX model file. + * \param[in] ep_type The execution provider type string. Must be non-empty. + * Use OrtApi::EpDevice_EpName to get this value from an OrtEpDevice. + * \param[in] allocator Allocator to use for the output string. Use OrtApi::GetAllocatorWithDefaultOptions. + * \param[out] compatibility_info Output pointer to the compatibility info string. + * Returns nullptr if no compatibility info exists for the specified EP. + * Caller must free with OrtApi::AllocatorFree when non-null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetCompatibilityInfoFromModel, + _In_ const ORTCHAR_T* model_path, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info); + + /** \brief Extract EP compatibility info from precompiled model bytes in memory. + * + * Same as GetCompatibilityInfoFromModel but reads from a memory buffer instead of a file. + * Useful when precompiled models are loaded from encrypted storage, network, or other non-file sources. + * + * \note This API performs standalone model parsing, separate from session creation. This means + * the protobuf parsing cost is incurred here and again during session creation. It is intended + * for scenarios where applications need to check compatibility before deciding whether to proceed + * with session creation, such as providing early user feedback. + * + * \param[in] model_data Pointer to the model data in memory. + * \param[in] model_data_length Size of the model data in bytes. + * \param[in] ep_type The execution provider type string. Must be non-empty. + * \param[in] allocator Allocator to use for the output string. Use OrtApi::GetAllocatorWithDefaultOptions. 
+ * \param[out] compatibility_info Output pointer to the compatibility info string. + * Returns nullptr if no compatibility info exists for the specified EP. + * Caller must free with OrtApi::AllocatorFree when non-null. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetCompatibilityInfoFromModelBytes, + _In_reads_(model_data_length) const void* model_data, + _In_ size_t model_data_length, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info); + + /// @} + /** \brief Create an OrtEnv instance with the given options. * * \note Invoking this function will return the same instance of the environment as that returned by a previous call @@ -7013,6 +7087,139 @@ struct OrtApi { * \since Version 1.24 */ ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out); + + /** \brief Get information about the subgraphs assigned to each execution provider (EP) and the nodes within. + * + * Each returned OrtEpAssignedSubgraph instance contains details of the subgraph/nodes assigned to an execution + * provider, including the execution provider's name, and the name, domain, and operator type for every node. + * + * For compiling execution providers, a single OrtEpAssignedSubgraph instance contains information about the + * nodes that are fused and compiled within a single subgraph assigned to the execution provider. + * + * For execution providers that use kernel registration (e.g., CPU EP), each node with a registered kernel is + * contained in its own OrtEpAssignedSubgraph instance. + * + * \note The caller must enable the collection of this information by enabling the session + * configuration entry "session.record_ep_graph_assignment_info" during session creation. + * Refer to onnxruntime_session_options_config_keys.h. Otherwise, if not enabled, this function returns a + * status with error code ORT_FAIL. 
+ * + * \note The information reported by this function is obtained immediately after running basic optimizations on the + * original graph if the session optimization level is set to ORT_ENABLE_BASIC or higher. If the session + * optimization level is set to ORT_DISABLE_ALL, only minimal/required optimizations are run before + * the information is collected. + * + * \param[in] session The OrtSession instance. + * \param[out] ep_subgraphs Output parameter set to the array of OrtEpAssignedSubgraph instances. + * \param[out] num_ep_subgraphs Output parameter set to the number of elements in the `ep_subgraphs` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(Session_GetEpGraphAssignmentInfo, _In_ const OrtSession* session, + _Outptr_ const OrtEpAssignedSubgraph* const** ep_subgraphs, + _Out_ size_t* num_ep_subgraphs); + + /** \brief Get the name of the execution provider to which the subgraph was assigned. + * + * \param[in] ep_subgraph The OrtEpAssignedSubgraph instance. + * \param[out] out Output parameter set to the execution provider's name as a UTF-8 null-terminated string. + * Owned by the OrtEpAssignedSubgraph instance (do not free). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const char** out); + + /** \brief Get the nodes in a subgraph assigned to a specific execution provider. + * + * \param[in] ep_subgraph The OrtEpAssignedSubgraph instance. + * \param[out] ep_nodes Output parameter set to the array of OrtEpAssignedNode instances. + * \param[out] num_ep_nodes Output parameter set to the number of OrtEpAssignedNode instance returned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. 
+ */ + ORT_API2_STATUS(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const OrtEpAssignedNode* const** ep_nodes, _Out_ size_t* num_ep_nodes); + + /** \brief Get the name of the node assigned to an execution provider. + * + * \param[in] ep_node The OrtEpAssignedNode instance. + * \param[out] out Output parameter set to the node's name as a UTF-8 null-terminated string. + * Owned by the OrtEpAssignedNode instance (do not free). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + + /** \brief Get the domain of the node assigned to an execution provider. + * + * \param[in] ep_node The OrtEpAssignedNode instance. + * \param[out] out Output parameter set to the node's domain as a UTF-8 null-terminated string. + * Owned by the OrtEpAssignedNode instance (do not free). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + + /** \brief Get the operator type of the node assigned to an execution provider. + * + * \param[in] ep_node The OrtEpAssignedNode instance. + * \param[out] out Output parameter set to the node's operator type as a UTF-8 null-terminated string. + * Owned by the OrtEpAssignedNode instance (do not free). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + + /** \brief Sets OrtSyncStream for the run options + * + * OrtSyncStream is used to synchronize the execution of the model run for the device + * of the stream. It overrides the existing stream for the duration of the Run(). + * The stream instance must be alive for the duration of the Run() call. 
+ * + * \param[in] options + * \param[in] sync_stream The synchronization stream. Pass nullptr to clear previous setting. + * + * \since 1.24 + */ + ORT_API_T(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); + + /** \brief Get the element data type and shape for an OrtValue that represents a Tensor (scalar, dense, or sparse). + * + * \note This function is an alternative to ::GetTensorTypeAndShape() that does not allocate a new array for + * the shape data. The OrtValue instance's internal shape data is returned directly. + * + * \note Returns an error if the underlying OrtValue is not a Tensor. + * + * \param[in] value The OrtValue instance. + * \param[out] elem_type Output parameter set to the tensor element data type. + * \param[out] shape_data Output parameter set to the OrtValue instance's internal shape data array. + * For a scalar, `shape_data` is NULL and `shape_data_count` is 0. + * Must not be released as it is owned by the OrtValue instance. This pointer becomes invalid + * when the OrtValue is released or if the underlying shape data is updated or reallocated. + * \param[out] shape_data_count Output parameter set to the number of elements in `shape_data`. + * `shape_data_count` is 0 for a scalar. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count); }; /* @@ -7868,7 +8075,7 @@ struct OrtInteropApi { /** \brief Release an OrtExternalResourceImporter instance. * - * \param[in] importer The OrtExternalResourceImporter instance to release. May be nullptr. + * \param[in] input The OrtExternalResourceImporter instance to release. May be nullptr. * * \since Version 1.24. 
*/ @@ -7911,7 +8118,7 @@ struct OrtInteropApi { /** \brief Release an OrtExternalMemoryHandle instance. * - * \param[in] handle The OrtExternalMemoryHandle instance to release. May be nullptr. + * \param[in] input The OrtExternalMemoryHandle instance to release. May be nullptr. * * \since Version 1.24. */ @@ -7977,7 +8184,7 @@ struct OrtInteropApi { /** \brief Release an OrtExternalSemaphoreHandle instance. * - * \param[in] handle The OrtExternalSemaphoreHandle instance to release. May be nullptr. + * \param[in] input The OrtExternalSemaphoreHandle instance to release. May be nullptr. * * \since Version 1.24. */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 901f7f10f3754..5cf8cf88bb054 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1164,6 +1164,72 @@ OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( const std::vector& ep_devices, const char* compatibility_info); +/** \brief Extract EP compatibility info from a precompiled model file. + * + * Parses the model file to extract the compatibility info string for a specific execution provider + * from the model's metadata properties. This is only applicable to models that have been precompiled + * for an EP. Standard ONNX models do not contain this information. + * + * \note This operation parses the full ONNX ModelProto from disk. + * + * \param model_path Path to the ONNX model file. + * \param ep_type The execution provider type string. Must be non-empty. + * Use ConstEpDevice::EpName() to get this value. + * \param allocator Allocator to use for the output string. + * \return The compatibility info string, or nullptr if not found for this EP. Caller must free via allocator. + * \throws Ort::Exception on error. 
+ */ +AllocatedStringPtr GetCompatibilityInfoFromModelAllocated(const ORTCHAR_T* model_path, const char* ep_type, + OrtAllocator* allocator); + +/** \brief Extract EP compatibility info from precompiled model bytes in memory. + * + * Same as GetCompatibilityInfoFromModelAllocated but reads from a memory buffer. + * Useful when precompiled models are loaded from encrypted storage, network, or other non-file sources. + * + * \param model_data Pointer to the model data in memory. + * \param model_data_length Size of the model data in bytes. + * \param ep_type The execution provider type string. Must be non-empty. + * \param allocator Allocator to use for the output string. + * \return The compatibility info string, or nullptr if not found for this EP. Caller must free via allocator. + * \throws Ort::Exception on error. + */ +AllocatedStringPtr GetCompatibilityInfoFromModelBytesAllocated(const void* model_data, size_t model_data_length, + const char* ep_type, OrtAllocator* allocator); + +namespace detail { +template +struct EpAssignedNodeImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string GetName() const; + std::string GetDomain() const; + std::string GetOperatorType() const; +}; +} // namespace detail + +/** \brief Constant wrapper around ::OrtEpAssignedNode + * \remarks EpAssignedNode is always read-only for ORT API users. + */ +using ConstEpAssignedNode = detail::EpAssignedNodeImpl>; + +namespace detail { +template +struct EpAssignedSubgraphImpl : Ort::detail::Base { + using B = Ort::detail::Base; + using B::B; + + std::string GetEpName() const; + std::vector GetNodes() const; +}; +} // namespace detail + +/** \brief Constant wrapper around ::OrtEpAssignedSubgraph + * \remarks EpAssignedSubgraph is always read-only for ORT API users. + */ +using ConstEpAssignedSubgraph = detail::EpAssignedSubgraphImpl>; + /** \brief The Env (Environment) * * The Env holds the logging state used by all other objects. 
@@ -1307,6 +1373,15 @@ struct RunOptions : detail::Base { * \param adapter The LoraAdapter to be used as the active adapter */ RunOptions& AddActiveLoraAdapter(const LoraAdapter& adapter); + + /** \brief Associate a sync stream with the run options. + * + * When set, the EP uses this stream for execution, enabling proper + * synchronization with imported external semaphores. Wraps OrtApi::RunOptionsSetSyncStream. + * + * \param stream The OrtSyncStream to associate with these run options. May be nullptr to clear. + */ + RunOptions& SetSyncStream(OrtSyncStream* stream); }; namespace detail { @@ -1665,9 +1740,14 @@ struct ConstSessionImpl : Base { int GetOpset(const std::string& domain) const; ///< Wraps OrtApi::SessionGetOpsetForDomain - // Will move before checkin if that's the case. std::vector GetInputs() const; std::vector GetOutputs() const; + + /** \brief Returns information on the subgraph/nodes assigned to execution providers in the session. + * + * \return A list of ConstEpAssignedSubgraph instances. + */ + std::vector GetEpGraphAssignmentInfo() const; }; template @@ -2140,6 +2220,19 @@ struct ConstValueImpl : Base { const R* GetSparseTensorValues() const; #endif + + /// + /// Returns the tensor's element type and a reference to the tensor's internal shape data. The shape data is owned + /// by the Ort::Value and becomes invalid when the Ort::Value is destroyed or if the underlying shape data is + /// updated or reallocated. + /// + /// For a scalar, shape.shape is nullptr and shape.shape_len is 0. + /// + /// Wraps OrtApi::GetTensorElementTypeAndShapeDataReference. + /// + /// Output parameter set to the element's data type. + /// Output parameter set to the OrtValue instance's shape data and number of elements. 
+ void GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, Shape& shape) const; }; template diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 267838e41887e..1a3e49130a1d1 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -747,6 +747,61 @@ inline EpDevice::EpDevice(OrtEpFactory& ep_factory, ConstHardwareDevice& hardwar ThrowOnError(GetEpApi().CreateEpDevice(&ep_factory, hardware_device, ep_metadata, ep_options, &p_)); } +namespace detail { +template +inline std::string EpAssignedSubgraphImpl::GetEpName() const { + const char* ep_name = nullptr; + + // Returned null-terminated string will not be null if API function returns successfully. + ThrowOnError(GetApi().EpAssignedSubgraph_GetEpName(this->p_, &ep_name)); + return std::string(ep_name); +} + +template +inline std::vector EpAssignedSubgraphImpl::GetNodes() const { + size_t num_ep_nodes = 0; + const OrtEpAssignedNode* const* ep_node_ptrs = nullptr; + ThrowOnError(GetApi().EpAssignedSubgraph_GetNodes(this->p_, &ep_node_ptrs, &num_ep_nodes)); + + std::vector ep_nodes; + if (num_ep_nodes > 0) { + ep_nodes.reserve(num_ep_nodes); + for (size_t i = 0; i < num_ep_nodes; ++i) { + ep_nodes.emplace_back(ep_node_ptrs[i]); + } + } + + return ep_nodes; +} + +template +inline std::string EpAssignedNodeImpl::GetName() const { + const char* node_name = nullptr; + + // Returned null-terminated string will not be null if API function returns successfully. + ThrowOnError(GetApi().EpAssignedNode_GetName(this->p_, &node_name)); + return std::string(node_name); +} + +template +inline std::string EpAssignedNodeImpl::GetDomain() const { + const char* domain = nullptr; + + // Returned null-terminated string will not be null if API function returns successfully. 
+ ThrowOnError(GetApi().EpAssignedNode_GetDomain(this->p_, &domain)); + return std::string(domain); +} + +template +inline std::string EpAssignedNodeImpl::GetOperatorType() const { + const char* op_type = nullptr; + + // Returned null-terminated string will not be null if API function returns successfully. + ThrowOnError(GetApi().EpAssignedNode_GetOperatorType(this->p_, &op_type)); + return std::string(op_type); +} +} // namespace detail + inline Env::Env(OrtLoggingLevel logging_level, _In_ const char* logid) { ThrowOnError(GetApi().CreateEnv(logging_level, logid, &p_)); if (strcmp(logid, "onnxruntime-node") == 0) { @@ -928,6 +983,20 @@ inline OrtCompiledModelCompatibility GetModelCompatibilityForEpDevices( return status; } +inline AllocatedStringPtr GetCompatibilityInfoFromModelAllocated(const ORTCHAR_T* model_path, const char* ep_type, + OrtAllocator* allocator) { + char* compat_info = nullptr; + ThrowOnError(GetApi().GetCompatibilityInfoFromModel(model_path, ep_type, allocator, &compat_info)); + return AllocatedStringPtr(compat_info, detail::AllocatedFree(allocator)); +} + +inline AllocatedStringPtr GetCompatibilityInfoFromModelBytesAllocated(const void* model_data, size_t model_data_length, + const char* ep_type, OrtAllocator* allocator) { + char* compat_info = nullptr; + ThrowOnError(GetApi().GetCompatibilityInfoFromModelBytes(model_data, model_data_length, ep_type, allocator, &compat_info)); + return AllocatedStringPtr(compat_info, detail::AllocatedFree(allocator)); +} + inline LoraAdapter LoraAdapter::CreateLoraAdapter(const std::basic_string& adapter_path, OrtAllocator* allocator) { OrtLoraAdapter* p; @@ -1003,6 +1072,11 @@ inline RunOptions& RunOptions::AddActiveLoraAdapter(const LoraAdapter& adapter) return *this; } +inline RunOptions& RunOptions::SetSyncStream(OrtSyncStream* stream) { + GetApi().RunOptionsSetSyncStream(p_, stream); + return *this; +} + inline ModelCompilationOptions::ModelCompilationOptions(const Env& env, const SessionOptions& 
session_options) { ThrowOnError(GetCompileApi().CreateModelCompilationOptionsFromSessionOptions(env, session_options, &this->p_)); } @@ -1756,6 +1830,23 @@ std::vector ConstSessionImpl::GetOutputs() const { return outputs; } +template +inline std::vector ConstSessionImpl::GetEpGraphAssignmentInfo() const { + size_t num_ep_subgraphs = 0; + const OrtEpAssignedSubgraph* const* ep_subgraph_ptrs = nullptr; + ThrowOnError(GetApi().Session_GetEpGraphAssignmentInfo(this->p_, &ep_subgraph_ptrs, &num_ep_subgraphs)); + + std::vector ep_subgraphs; + if (num_ep_subgraphs > 0) { + ep_subgraphs.reserve(num_ep_subgraphs); + for (size_t i = 0; i < num_ep_subgraphs; ++i) { + ep_subgraphs.emplace_back(ep_subgraph_ptrs[i]); + } + } + + return ep_subgraphs; +} + template inline std::vector SessionImpl::Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, const char* const* output_names, size_t output_count) { @@ -2286,6 +2377,13 @@ inline const R* ConstValueImpl::GetSparseTensorValues() const { #endif +template +void ConstValueImpl::GetTensorElementTypeAndShapeDataReference(ONNXTensorElementDataType& elem_type, + Shape& shape) const { + ThrowOnError(GetApi().GetTensorElementTypeAndShapeDataReference(this->p_, &elem_type, &shape.shape, + &shape.shape_len)); +} + template void ValueImpl::FillStringTensor(const char* const* s, size_t s_len) { ThrowOnError(GetApi().FillStringTensor(this->p_, s, s_len)); diff --git a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h index b64e13531c260..b888d0d609e55 100644 --- a/include/onnxruntime/core/session/onnxruntime_ep_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_ep_c_api.h @@ -43,11 +43,9 @@ ORT_RUNTIME_CLASS(ExternalResourceImporterImpl); * \since Version 1.24. 
*/ struct OrtExternalMemoryHandle { - uint32_t version; ///< Must be ORT_API_VERSION - const OrtEpDevice* ep_device; ///< EP device that created this handle - OrtExternalMemoryHandleType handle_type; ///< Original handle type for tracking - size_t size_bytes; ///< Size of the imported memory - size_t offset_bytes; ///< Offset into the imported memory + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalMemoryDescriptor descriptor; ///< External memory descriptor /** \brief Release callback for this handle. EP sets this to its release function. * @@ -72,9 +70,9 @@ struct OrtExternalMemoryHandle { * \since Version 1.24. */ struct OrtExternalSemaphoreHandle { - uint32_t version; ///< Must be ORT_API_VERSION - const OrtEpDevice* ep_device; ///< EP device that created this handle - OrtExternalSemaphoreType type; ///< Original semaphore type + uint32_t version; ///< Must be ORT_API_VERSION + const OrtEpDevice* ep_device; ///< EP device that created this handle + OrtExternalSemaphoreDescriptor descriptor; ///< External semaphore descriptor /** \brief Release callback for this handle. EP sets this to its release function. * @@ -666,7 +664,7 @@ struct OrtLoopKernelHelper; typedef struct OrtLoopKernelHelper OrtLoopKernelHelper; /** - * \brief Contains helper functions for a Loop OrtKernelImpl created via ::CreateLoopKernel(). + * \brief Contains helper functions for a Loop OrtKernelImpl created via OrtEpApi::CreateLoopKernel. * \since Version 1.24. */ struct OrtLoopKernelHelper { @@ -709,7 +707,7 @@ struct OrtScanKernelHelper; typedef struct OrtScanKernelHelper OrtScanKernelHelper; /** - * \brief Contains helper functions for a Scan OrtKernelImpl created via ::CreateScanKernel(). + * \brief Contains helper functions for a Scan OrtKernelImpl created via OrtEpApi::CreateScanKernel. * \since Version 1.24. 
*/ struct OrtScanKernelHelper { @@ -1433,13 +1431,13 @@ struct OrtEpApi { /** \brief Gets a new OrtKeyValuePairs instance containing a copy of all configuration entries set on the environment. * * \note An application provides environment-level configuration options for execution provider libraries by - * using keys with the prefix 'ep_factory..'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents + * using keys with the prefix 'ep_factory.\\.'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents * a key named 'some_ep_key' that is meant to be consumed by an execution provider named 'my_ep'. Refer to * the specific execution provider's documentation for valid keys and values. * * \note Refer to onnxruntime_env_config_keys.h for common configuration entry keys and their supported values. * - * \param[out] out Output parameter set to the OrtKeyValuePairs instance containing all configuration entries. + * \param[out] config_entries Output parameter set to the OrtKeyValuePairs instance containing all configuration entries. * Must be released via OrtApi::ReleaseKeyValuePairs. * * \snippet{doc} snippets.dox OrtStatus Return Value @@ -2048,6 +2046,65 @@ struct OrtEpFactory { ORT_API2_STATUS(CreateExternalResourceImporterForDevice, _In_ OrtEpFactory* this_ptr, _In_ const OrtEpDevice* ep_device, _Outptr_result_maybenull_ OrtExternalResourceImporterImpl** out_importer); + + /** \brief Returns the number of OrtCustomOpDomains that this factory provides. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] num_domains Output parameter set to the number of provided OrtCustomOpDomain instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetNumCustomOpDomains, _In_ OrtEpFactory* this_ptr, _Out_ size_t* num_domains); + + /** \brief Gets the EP-specific OrtCustomOpDomains. + * + * This function is used when running inference on a model that contains EP-specific custom operations. 
+ * + * Workflow: + * 1. The EP factory implements this function to supply a list of OrtCustomOpDomain instances. + * 2. The application either 1) calls SessionOptionsAppendExecutionProvider_V2() with an OrtEpDevice containing + * the plugin EP's factory or 2) enables auto ep selection. + * 3. 1) SessionOptionsAppendExecutionProvider_V2() appends the provided OrtCustomOpDomains to the + * session options or 2) ORT registers the OrtCustomOpDomains provided by the EP devices + * that could be potentially selected. + * + * As a result, any session created from these session options will have these custom op domains registered + * in ORT, ensuring that the custom ops are properly recognized and validated when the model is loaded. + * + * Plugin EPs can provide two types of custom ops: + * 1. A full OrtCustomOp with a concrete kernel implementation + * - A Plugin EP can supply an OrtCustomOp and a corresponding CustomKernel::Compute() implementation. + * - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT + * that the custom node should NOT be fused or compiled. Instead, ORT should invoke + * the custom node's Compute() function at runtime. + * + * 2. A "placeholder" OrtCustomOp with an empty kernel implementation + * - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() + * does nothing. The purpose is to satisfy model validation during model loading by + * registering the custom op as a valid operator in the session. + * - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to + * notify ORT that this custom node should be fused and compiled by the EP. + * - In Compile(), the EP executes its compiled bits to perform inference for + * the fused custom node. + * + * Note: The OrtCustomOpDomain instances must be valid while any session is using them. + EP factory has the responsibility to release OrtCustomOpDomain instances it creates. 
It happens + * automatically if using the C++ Ort::CustomOpDomain class. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[out] domains Array of `num_domains` elements pre-allocated by ORT that should be filled with + OrtCustomOpDomain instances created by the EP. The `num_domains` is the value returned by + GetNumCustomOpDomains(). + * \param[in] num_domains The size of the `domains` array pre-allocated by ORT. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.24. + */ + ORT_API2_STATUS(GetCustomOpDomains, _In_ OrtEpFactory* this_ptr, + _Out_writes_all_(num_domains) OrtCustomOpDomain** domains, _In_ size_t num_domains); }; #ifdef __cplusplus diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 64a434e2fe301..1ea147a0079cc 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -374,6 +374,12 @@ static const char* const kOrtSessionOptionsEpContextModelExternalInitializersFil // - "1": Gemm FastMath mode is enabled. static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; +// Use LUT (Lookup Table) based GEMM for quantized models when available. +// Option values: +// - "0": Do not use LUT based GEMM. [DEFAULT] +// - "1": Use LUT based GEMM when available. +static const char* const kOrtSessionOptionsMlasLutGemm = "mlas.use_lut_gemm"; + // When converting DQ + MatMul -> MatMulNBits, the accuracy level of the MatMulNBits is controlled by this option. // Refer to MatMulNBits op schema for more details. // If not provided, default is 4. 
@@ -415,3 +421,11 @@ static const char* const kOrtSessionOptionsFailOnSuboptimalCompiledModel = // "high_power_saver", "low_balanced", "extreme_power_saver", "low_power_saver", "power_saver", // "sustained_high_performance". Default to "default". static const char* const kOrtEpDynamicOptionsQnnHtpPerformanceMode = "ep.dynamic.qnn_htp_performance_mode"; + +// Enables the session to record information about the subgraphs/nodes assigned to execution providers. +// When enabled, an application may call Session_GetEpGraphAssignmentInfo() to retrieve the information. +// +// Option values: +// - "0": Recording of EP graph assignment information is disabled. [DEFAULT] +// - "1": Recording of EP graph assignment information is enabled. +static const char* const kOrtSessionOptionsRecordEpGraphAssignmentInfo = "session.record_ep_graph_assignment_info"; diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 1bf7e3ff6b819..9e819d0a932ef 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.24.0'; +export const version = '1.24.3'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index fca226a2962a7..c647d7b75ca38 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "devDependencies": { "globby": "^15.0.0", diff --git a/js/common/package.json b/js/common/package.json index c96a750530d4a..17a651d6ac995 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 1bf7e3ff6b819..9e819d0a932ef 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.24.0'; +export const version = '1.24.3'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index 145d11ada7aa3..8d890063b532b 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.24.0", + "version": "1.24.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.24.0", + "version": "1.24.3", "hasInstallScript": true, "license": "MIT", "os": [ @@ -30,9 +30,10 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "devDependencies": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, @@ -2103,6 +2104,7 @@ "onnxruntime-common": { "version": "file:../common", "requires": { + "globby": "^15.0.0", "typedoc": "^0.25.7" } }, diff --git a/js/node/package.json b/js/node/package.json index 3490ae8cf0cce..96be689947352 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -11,7 +11,7 @@ 6 ] }, - "version": "1.24.0", + "version": "1.24.3", "dependencies": { "adm-zip": "^0.5.16", "global-agent": "^3.0.0", diff --git a/js/node/script/install-metadata-versions.js b/js/node/script/install-metadata-versions.js index f03a78878788b..1747fbc37259e 100644 --- a/js/node/script/install-metadata-versions.js +++ b/js/node/script/install-metadata-versions.js @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -module.exports = { nuget: [{ feed: 'nuget', version: '1.24.0' }] }; +module.exports = { nuget: [{ feed: 'nuget', version: '1.24.3' }] }; diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 1bf7e3ff6b819..9e819d0a932ef 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. 
-export const version = '1.24.0'; +export const version = '1.24.3'; diff --git a/js/react_native/package-lock.json b/js/react_native/package-lock.json index de8d631362db7..61f389ecbd618 100644 --- a/js/react_native/package-lock.json +++ b/js/react_native/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-react-native", - "version": "1.24.0", + "version": "1.24.3", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "onnxruntime-react-native", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "dependencies": { "onnxruntime-common": "file:../common" @@ -30,7 +30,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "devDependencies": { "globby": "^15.0.0", @@ -92,6 +92,7 @@ "integrity": "sha512-BBt3opiCOxUr9euZ5/ro/Xv8/V7yJ5bjYMqG/C1YAo8MIKAnumZalCN+msbci3Pigy4lIQfPUpfMM27HMGaYEA==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@ampproject/remapping": "^2.2.0", "@babel/code-frame": "^7.24.7", @@ -1942,6 +1943,7 @@ "integrity": "sha512-vX3qPGE8sEKEAZCWk05k3cpTAE3/nOYca++JA+Rd0z2NCNzabmYvEiSShKzm10zdquOIAVXsy2Ei/DTW34KlKQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@babel/compat-data": "^7.26.8", "@babel/helper-compilation-targets": "^7.26.5", @@ -3509,6 +3511,7 @@ } ], "license": "MIT", + "peer": true, "dependencies": { "caniuse-lite": "^1.0.30001688", "electron-to-chromium": "^1.5.73", @@ -7001,6 +7004,7 @@ "integrity": "sha512-/3IjMdb2L9QbBdWiW5e3P2/npwMBaU9mHCSCUzNln0ZCYbcfTsGbTJrU/kGemdH2IWmB2ioZ+zkxtmq6g09fGQ==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "loose-envify": "^1.1.0" }, @@ -7030,6 +7034,7 @@ "integrity": "sha512-yvQIX+ZXOHMFnhmwZ1fBpRI/53k+iLN8DxVf24Fx4ABU63RGAYfyCZC0/3W+5OUVx4KSIZUv4Tv+/NGIieBOwg==", "dev": true, "license": "MIT", + "peer": true, "dependencies": { "@jest/create-cache-key-function": "^29.6.3", "@react-native-community/cli": "12.3.7", diff --git 
a/js/react_native/package.json b/js/react_native/package.json index d5c1641d92a1e..7c846efae8b62 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -37,7 +37,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.24.0", + "version": "1.24.3", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 1bf7e3ff6b819..9e819d0a932ef 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.24.0'; +export const version = '1.24.3'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 18bf30a325d83..994aeb83a0ed5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -132,7 +132,7 @@ export const parseConvTransposeAttributes = (attributes: Record typeof attributes.autoPad == 'undefined' ? 0 : (attributes.autoPad as number) ]; const dilations = attributes.dilations as [number, number]; - const group = attributes.group as number; + const group = (attributes.group as number) ?? 1; // default to 1 per ONNX spec const kernelShape = attributes.kernelShape as [number, number]; const pads = attributes.pads as [number, number, number, number]; const strides = attributes.strides as [number, number]; diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index a9ef6c72314dd..ba4b9578207f0 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -194,6 +194,13 @@ export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promis if (wasmBinaryOverride) { // Set a custom buffer which contains the WebAssembly binary. 
This will skip the wasm file fetching. config.wasmBinary = wasmBinaryOverride; + + // Offer an implementation of locateFile() that returns the file name directly. This helps to avoid an error + // thrown later from the following code when `import.meta.url` is a blob URL: + // ``` + // return new URL("ort-wasm-simd-threaded.jsep.wasm", import.meta.url).href; + // ``` + config.locateFile = (fileName) => fileName; } else if (wasmPathOverride || wasmPrefixOverride) { // A callback function to locate the WebAssembly file. The function should return the full path of the file. // diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index e2e46bb37dcfc..6c899d1ae9cf5 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -272,7 +272,9 @@ export const importWasmModule = async ( } } else { // if the script source is available, we can check if it is from the same origin. - useEmbeddedModule = isSameOrigin(scriptSrc); + // Also use the embedded module when wasmBinary is provided and single-threaded (eg. Blob URL workers + // where isSameOrigin fails but no file resolution or worker spawning is needed). 
+ useEmbeddedModule = isSameOrigin(scriptSrc) || (isWasmOverridden && !isMultiThreaded); } } if (useEmbeddedModule) { diff --git a/js/web/package-lock.json b/js/web/package-lock.json index f2e3a7b8ed57c..01a6e497a9a75 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.24.0", + "version": "1.24.3", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "dependencies": { "flatbuffers": "^25.1.24", @@ -50,7 +50,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.24.0", + "version": "1.24.3", "license": "MIT", "devDependencies": { "globby": "^15.0.0", @@ -635,6 +635,7 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-4.3.7.tgz", "integrity": "sha512-HLnAzZ2iupm25PlN0xFreAlBA5zaBSv3og0DdeGA4Ar6h6rJ3A0rolRUKJhSF2V10GZKDgWF/VmAEsNWjCRB+A==", "dev": true, + "peer": true, "dependencies": { "assertion-error": "^1.1.0", "check-error": "^1.0.2", @@ -2081,6 +2082,7 @@ "resolved": "https://registry.npmjs.org/karma/-/karma-6.4.1.tgz", "integrity": "sha512-Cj57NKOskK7wtFWSlMvZf459iX+kpYIPXmkNUzP2WAFcA7nhr/ALn5R7sw3w+1udFDcpMx/tuB8d5amgm3ijaA==", "dev": true, + "peer": true, "dependencies": { "@colors/colors": "1.5.0", "body-parser": "^1.19.0", @@ -4201,6 +4203,7 @@ "resolved": "https://registry.npmjs.org/chai/-/chai-4.3.7.tgz", "integrity": "sha512-HLnAzZ2iupm25PlN0xFreAlBA5zaBSv3og0DdeGA4Ar6h6rJ3A0rolRUKJhSF2V10GZKDgWF/VmAEsNWjCRB+A==", "dev": true, + "peer": true, "requires": { "assertion-error": "^1.1.0", "check-error": "^1.0.2", @@ -5329,6 +5332,7 @@ "resolved": "https://registry.npmjs.org/karma/-/karma-6.4.1.tgz", "integrity": "sha512-Cj57NKOskK7wtFWSlMvZf459iX+kpYIPXmkNUzP2WAFcA7nhr/ALn5R7sw3w+1udFDcpMx/tuB8d5amgm3ijaA==", "dev": true, + "peer": true, "requires": { "@colors/colors": "1.5.0", "body-parser": "^1.19.0", diff --git a/js/web/package.json b/js/web/package.json index 
85b1deeec8490..23d89ca57199f 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -7,7 +7,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.24.0", + "version": "1.24.3", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^25.1.24", diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 5ad3647659cee..ace3e4dd4bf46 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -10,7 +10,7 @@ import contextlib -__version__ = "1.24.0" +__version__ = "1.24.3" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). @@ -34,6 +34,8 @@ OrtArenaCfg, # noqa: F401 OrtCompileApiFlags, # noqa: F401 OrtDeviceMemoryType, # noqa: F401 + OrtEpAssignedNode, # noqa: F401 + OrtEpAssignedSubgraph, # noqa: F401 OrtEpDevice, # noqa: F401 OrtExecutionProviderDevicePolicy, # noqa: F401 OrtExternalInitializerInfo, # noqa: F401 diff --git a/onnxruntime/contrib_ops/cpu/activations.h b/onnxruntime/contrib_ops/cpu/activations.h index 7e64235d3fc3d..f00fad809968f 100644 --- a/onnxruntime/contrib_ops/cpu/activations.h +++ b/onnxruntime/contrib_ops/cpu/activations.h @@ -77,17 +77,27 @@ class QuickGelu : public OpKernel { const T* p_input = input_data + start; T* p_output = output_data + start; int64_t count = std::min(length_per_task, elem_count - start); - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * alpha_; - } - MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); + if (alpha_ != 1.0f) { + // TODO: Consider vectorizing this scalar multiplication. + // It needs exposing a new API in MLAS to take in a scalar + // that will be used in the elementwise multiplication. + // Estimate the cost-benefit tradeoff before proceeding + // with that optimization. 
+ for (int64_t i = 0; i < count; i++) { + p_output[i] = p_input[i] * alpha_; + } - for (int64_t i = 0; i < count; i++) { - p_output[i] = p_input[i] * p_output[i]; + MlasComputeLogistic(p_output, p_output, onnxruntime::narrow(count)); + } else { + // SILU activation - this needs no `alpha_` scaling as `alpha_` will be 1.0f + MlasComputeLogistic(p_input, p_output, onnxruntime::narrow(count)); } + + MlasEltwiseMul(p_input, p_output, p_output, onnxruntime::narrow(count)); }, 0); + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h index f237b24b899a0..4ad11dce7e093 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_parameters.h @@ -25,15 +25,16 @@ struct AttentionParameters { int num_splits; // number of splits for splitkv int rotary_dim = 0; // rotary embedding dimension int beam_width; - bool is_unidirectional; - bool past_present_share_buffer; + bool is_unidirectional = false; + bool past_present_share_buffer = false; bool is_packed_qkv = false; // whether qkv is packed - bool do_rotary; - bool broadcast_attn_bias_dim_0; - bool broadcast_attn_bias_dim_1; + bool do_rotary = false; + bool broadcast_attn_bias_dim_0 = false; + bool broadcast_attn_bias_dim_1 = false; float mask_filter_value; float scale; - bool use_tf32; + bool use_tf32 = false; + bool is_output_bnsh = false; // whether the output format is BNSH AttentionMaskType mask_type; AttentionQkvFormat qkv_format; }; @@ -87,9 +88,8 @@ struct GroupQueryAttentionParameters : AttentionParameters { int seqlen_past_kv_cache; // sequence length of past kv tensor int seqlen_present_kv_cache; // sequence length of present kv tensor int local_window_size; // Mask out tokens prior to total_sequence_length - local_window_size - bool kv_share_buffer; - bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 - bool is_first_prompt; // 
indicates whether this is first decoding step + bool is_subsequent_prompt; // indicates whether we have past context and seqlen > 1 + bool is_first_prompt; // indicates whether this is first decoding step bool rotary_interleaved; bool use_smooth_softmax; float softcap; diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h index 257c5a189b3bd..bd30418030dc2 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_helper.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_helper.h @@ -35,44 +35,86 @@ struct MoEParameters { }; namespace moe_helper { +// Helper to check shape dimensions +#define ASSERT_SHAPE_DIMENSION(shape_ptr, dim, name) \ + if (shape_ptr != nullptr) { \ + if (shape_ptr->NumDimensions() != dim) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have ", dim, " dimensions, got ", \ + shape_ptr->NumDimensions()); \ + } \ + } + +#define ASSERT_SHAPE_3D(shape_ptr, name) ASSERT_SHAPE_DIMENSION(shape_ptr, 3, name) + +#define CHECK_SHAPE(shape_ptr, name, ...) 
\ + if (shape_ptr != nullptr) { \ + const TensorShape& expected_shape = make_shape(__VA_ARGS__); \ + if (*shape_ptr != expected_shape) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input '", name, \ + "' is expected to have shape ", expected_shape, \ + ", got ", *shape_ptr); \ + } \ + } + template Status CheckInputs(MoEParameters& parameters, - const Tensor* input, // required - const Tensor* router_probs, // required - const Tensor* fc1_experts_weights, // required - const Tensor* fc1_experts_bias, // optional - const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc1_zero_points, // optional, for qMoE - const Tensor* fc2_experts_weights, // required - const Tensor* fc2_experts_bias, // optional - const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc2_zero_points, // optional, for qMoE - const Tensor* fc3_experts_weights, // optional - const Tensor* fc3_experts_bias, // optional - const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE - const Tensor* fc3_zero_points, // optional, for qMoE - const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const Tensor* input, // required + const Tensor* router_probs, // required + const TensorShape* fc1_experts_weights_shape, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const TensorShape* fc2_experts_weights_shape, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const TensorShape* fc3_experts_weights_shape, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, 
// number of weights packed together (like 2 for uint4 packed to uint8) const bool is_fused_swiglu, const int64_t block_size = 0) { // block size for block-wise quantization - // Check dimensions of input to avoid input_dims index out of range. CHECK_TENSOR_SHAPE will verify each tensor later. + // Required inputs + if (input == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input' is required."); + } ASSERT_TENSOR_2D_OR_3D(input); - ASSERT_TENSOR_3D(fc1_experts_weights); - ASSERT_TENSOR_3D(fc2_experts_weights); + + if (router_probs == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'router_probs' is required."); + } ASSERT_TENSOR_2D(router_probs); + if (fc1_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc1_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc1_experts_weights_shape, "fc1_experts_weights"); + + if (fc2_experts_weights_shape == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'fc2_experts_weights' is required."); + } + ASSERT_SHAPE_3D(fc2_experts_weights_shape, "fc2_experts_weights"); + const auto& input_dims = input->Shape().GetDims(); const auto& router_probs_dims = router_probs->Shape().GetDims(); - const auto& fc1_experts_weights_dims = fc1_experts_weights->Shape().GetDims(); - const auto& fc2_experts_weights_dims = fc2_experts_weights->Shape().GetDims(); int64_t num_rows = input_dims.size() == 2 ? 
input_dims[0] : input_dims[0] * input_dims[1]; int64_t hidden_size = input_dims[input_dims.size() - 1]; - int64_t local_num_experts = fc1_experts_weights_dims[0]; int64_t num_experts = router_probs_dims[1]; - int64_t inter_size = (fc2_experts_weights_dims[1] * fc2_experts_weights_dims[2] * pack_size) / hidden_size; - const bool legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || - (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); + int64_t local_num_experts = fc1_experts_weights_shape->GetDims()[0]; + + int64_t inter_size = (fc2_experts_weights_shape->GetDims()[1] * + fc2_experts_weights_shape->GetDims()[2] * pack_size) / + hidden_size; + + bool legacy_shape = false; + const auto& fc2_experts_weights_dims = fc2_experts_weights_shape->GetDims(); + const auto& fc1_experts_weights_dims = fc1_experts_weights_shape->GetDims(); + legacy_shape = (hidden_size != inter_size && fc2_experts_weights_dims[1] == inter_size) || + (hidden_size == inter_size && is_fused_swiglu && fc1_experts_weights_dims[1] == hidden_size); // Fused swiglu doubles the output dimension of FC1 since it fused two GEMMs into one. const int64_t fc1_inter_size = is_fused_swiglu ? (inter_size + inter_size) : inter_size; @@ -80,13 +122,13 @@ Status CheckInputs(MoEParameters& parameters, if (legacy_shape) { // legacy shape does not match column major memory layout. This is for backward compatibility. 
- CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, hidden_size, fc1_inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, hidden_size, fc1_inter_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, hidden_size, inter_size / pack_size); } else { - CHECK_TENSOR_SHAPE(fc1_experts_weights, num_experts, fc1_inter_size, hidden_size / pack_size); - CHECK_TENSOR_SHAPE(fc2_experts_weights, num_experts, hidden_size, inter_size / pack_size); - CHECK_TENSOR_SHAPE(fc3_experts_weights, num_experts, inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc1_experts_weights_shape, "fc1_experts_weights", num_experts, fc1_inter_size, hidden_size / pack_size); + CHECK_SHAPE(fc2_experts_weights_shape, "fc2_experts_weights", num_experts, hidden_size, inter_size / pack_size); + CHECK_SHAPE(fc3_experts_weights_shape, "fc3_experts_weights", num_experts, inter_size, hidden_size / pack_size); } CHECK_TENSOR_SHAPE(router_probs, num_rows, num_experts); @@ -168,9 +210,11 @@ Status CheckInputs(MoEParameters& parameters, } } - if (fc3_experts_weights == nullptr) { + if (fc3_experts_weights_shape == nullptr) { + // If fc3 weights are not provided, ensure no other fc3 parameters are provided ORT_ENFORCE(fc3_experts_bias == nullptr && fc3_experts_scales == nullptr && fc3_zero_points == nullptr); } else { + // If fc3 weights are provided, ensure scales logic is consistent ORT_ENFORCE(fc1_experts_scales == nullptr || fc3_experts_scales != nullptr); // MOE no scale, or qMOE need scales } @@ -200,6 +244,36 @@ Status CheckInputs(MoEParameters& parameters, return Status::OK(); } +template +Status 
CheckInputs(MoEParameters& parameters, + const Tensor* input, // required + const Tensor* router_probs, // required + const Tensor* fc1_experts_weights, // required + const Tensor* fc1_experts_bias, // optional + const Tensor* fc1_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc1_zero_points, // optional, for qMoE + const Tensor* fc2_experts_weights, // required + const Tensor* fc2_experts_bias, // optional + const Tensor* fc2_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc2_zero_points, // optional, for qMoE + const Tensor* fc3_experts_weights, // optional + const Tensor* fc3_experts_bias, // optional + const Tensor* fc3_experts_scales, // required for qMoE; NULL for MOE + const Tensor* fc3_zero_points, // optional, for qMoE + const int64_t pack_size, // number of weights packed together (like 2 for uint4 packed to uint8) + const bool is_fused_swiglu, + const int64_t block_size = 0) { // block size for block-wise quantization + + const TensorShape* fc1_shape = (fc1_experts_weights != nullptr) ? &fc1_experts_weights->Shape() : nullptr; + const TensorShape* fc2_shape = (fc2_experts_weights != nullptr) ? &fc2_experts_weights->Shape() : nullptr; + const TensorShape* fc3_shape = (fc3_experts_weights != nullptr) ? 
&fc3_experts_weights->Shape() : nullptr; + + return CheckInputs(parameters, input, router_probs, fc1_shape, fc1_experts_bias, fc1_experts_scales, fc1_zero_points, + fc2_shape, fc2_experts_bias, fc2_experts_scales, fc2_zero_points, + fc3_shape, fc3_experts_bias, fc3_experts_scales, fc3_zero_points, + pack_size, is_fused_swiglu, block_size); +} + } // namespace moe_helper } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc index 6d1d191689466..81d2b0f8efdc6 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.cc @@ -13,6 +13,7 @@ #include "core/common/narrow.h" #include "core/framework/tensor_type_and_shape.h" #include "core/util/math.h" +#include "core/platform/env_var_utils.h" #include "contrib_ops/cpu/moe/moe_utils.h" #include "contrib_ops/cpu/moe/moe_helper.h" @@ -69,13 +70,13 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, out_qtype = BlkQ4Sym64; } else if (block_size == 128) { out_qtype = BlkQ4Sym128; - } else if (block_size == 0) { + } else if (block_size == 0 || block_size == 32) { out_qtype = BlkQ4Sym; } else { return false; } - size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(cols), static_cast(rows)); + size_t expected_size = MlasQ4GemmPackBSize(out_qtype, static_cast(rows), static_cast(cols)); return expected_size > 0; } @@ -84,6 +85,8 @@ bool CanUseMlasQ4Gemm(int64_t expert_weight_bits, int64_t block_size, namespace onnxruntime { namespace contrib { +constexpr const char* kUseMlasQ4GemmMoe = "ORT_USE_MLAS_Q4_GEMM_MOE"; + template void DequantizeBlockWithMlas(const uint8_t* quantized_data, const TScale* scales, @@ -118,13 +121,23 @@ Status ConvertToMlasQ4Format(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, temp_float, nullptr); - size_t 
packed_size = MlasQ4GemmPackBSize(qtype, static_cast(cols), static_cast(rows)); + // Transpose from N x K (weights) to K x N. + // DirectQ4Gemm expects weights to be packed in a specific layout ([K, N] logically) + auto transposed_float_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(rows * cols)); + float* transposed_float = transposed_float_buffer.get(); + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + transposed_float[c * rows + r] = temp_float[r * cols + c]; + } + } + + size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); if (packed_size == 0) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "MLAS Q4 packing not supported for this configuration"); } mlas_packed_buffer = IAllocator::MakeUniquePtr(allocator, packed_size); - MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), temp_float, static_cast(cols), static_cast(rows), static_cast(cols)); + MlasQ4GemmPackB(qtype, mlas_packed_buffer.get(), transposed_float, static_cast(rows), static_cast(cols), static_cast(rows)); return Status::OK(); } @@ -354,6 +367,257 @@ void DequantizeBlock(const uint8_t* quantized_data, DequantizeBlockWithMlas(quantized_data, scales, zero_points, block_size, num_bits, rows, cols, dequantized_data, thread_pool); } +template +void DequantizePrePacked(const uint8_t* prepacked_data, + const TScale* scales, + const uint8_t* zero_points, + int64_t block_size, + int64_t rows, + int64_t cols, + float* dequantized_data, + const gsl::span& scale_dims) { + // prepacked_data is [cols, rows] (transposed, unpacked) + // dequantized_data is [cols, rows] (transposed) + // scales, zero_points correspond to original [rows, cols] layout + + const float default_zp_4bit = 8.0f; + const int64_t blocks_per_row = (block_size > 0) ? 
((cols + block_size - 1) / block_size) : 1; + const int64_t zp_pack_size = 2; // Always 2 for 4-bit + + // Iterate over Columns (K) then Rows (N) because prepacked_data is [K, N] + for (int64_t c = 0; c < cols; ++c) { + for (int64_t r = 0; r < rows; ++r) { + uint8_t val = prepacked_data[c * rows + r]; + + int64_t block_idx = (block_size > 0) ? (c / block_size) : 0; + if (block_size > 0) block_idx = std::min(block_idx, blocks_per_row - 1); + + int64_t scale_idx; + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + scale_idx = r * blocks_per_row + block_idx; + } else { // per-channel + scale_idx = r; + } + + float scale = static_cast(scales[scale_idx]); + float zp = default_zp_4bit; + + if (zero_points != nullptr) { + int64_t zp_idx; + bool is_lower_nibble; + + if (scale_dims.size() == 3 && scale_dims[2] > 1) { // block-wise + int64_t zp_blocks_packed = (blocks_per_row + zp_pack_size - 1) / zp_pack_size; + zp_idx = r * zp_blocks_packed + block_idx / 2; + is_lower_nibble = (block_idx % 2 == 0); + } else { + zp_idx = r / 2; + is_lower_nibble = (r % 2 == 0); + } + + uint8_t packed_zp = zero_points[zp_idx]; + zp = is_lower_nibble ? static_cast(packed_zp & 0x0F) : static_cast(packed_zp >> 4); + } + + dequantized_data[c * rows + r] = scale * (static_cast(val) - zp); + } + } +} + +template +Status BuildDirectQ4PackedBCache(const uint8_t* prepacked_weights, + const TScale* scales_data, + int64_t num_experts, + int64_t rows, + int64_t cols, + int64_t block_size, + const gsl::span& scales_dims, + MLAS_BLK_QUANT_TYPE qtype, + AllocatorPtr allocator, + IAllocatorUniquePtr& packed_b) { + const size_t packed_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)); + if (packed_size == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to compute MLAS Q4 packed size for cache"); + } + + const bool is_block_wise = (scales_dims.size() == 3 && scales_dims[2] > 1); + const int64_t scales_expert_stride = is_block_wise ? 
(rows * scales_dims[2]) : rows; + const size_t prepacked_expert_stride = static_cast(rows * cols); + const size_t total_packed_size = packed_size * static_cast(num_experts); + + packed_b = IAllocator::MakeUniquePtr(allocator, total_packed_size, true); + uint8_t* packed_b_ptr = static_cast(packed_b.get()); + + std::vector dequantized_transposed(static_cast(rows * cols)); + for (int64_t expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const uint8_t* expert_prepacked = prepacked_weights + static_cast(expert_idx) * prepacked_expert_stride; + const TScale* expert_scales = scales_data + expert_idx * scales_expert_stride; + + DequantizePrePacked(expert_prepacked, expert_scales, nullptr, block_size, rows, cols, + dequantized_transposed.data(), scales_dims); + + MlasQ4GemmPackB(qtype, packed_b_ptr + expert_idx * packed_size, dequantized_transposed.data(), + static_cast(rows), static_cast(cols), static_cast(rows)); + } + + return Status::OK(); +} + +template +Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) { + is_packed = false; + + // If scales are prepacked, they are constant initializers. 
+ if (input_idx == 3) { + return Status::OK(); + } + if (input_idx == 6) { + return Status::OK(); + } + + // Only support PrePack for FC1 (2) and FC2 (5) weights + // and only if expert_weight_bits_ == 4 (since we unpack to uint8) + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if (input_idx == 2 || input_idx == 5) { + const auto& shape = tensor.Shape(); + const int64_t num_experts = shape[0]; + const int64_t rows = shape[1]; + const int64_t cols_packed = shape[2]; + const int64_t cols = cols_packed * 2; + + size_t packed_size = static_cast(num_experts * rows * cols); + auto packed_buffer = IAllocator::MakeUniquePtr(alloc, packed_size, true); + uint8_t* dst_base = static_cast(packed_buffer.get()); + const uint8_t* src_base = static_cast(tensor.DataRaw()); + + for (int64_t i = 0; i < num_experts; ++i) { + const uint8_t* src = src_base + i * rows * cols_packed; + uint8_t* dst = dst_base + i * rows * cols; + + for (int64_t r = 0; r < rows; ++r) { + for (int64_t c = 0; c < cols; ++c) { + uint8_t packed_val = src[r * cols_packed + (c / 2)]; + uint8_t val = (c % 2 == 0) ? 
(packed_val & 0x0F) : (packed_val >> 4); + + dst[c * rows + r] = val; + } + } + } + + if (input_idx == 2) { + fc1_shape_ = shape; + } else if (input_idx == 5) { + fc2_shape_ = shape; + } + + if (prepacked_weights) { + prepacked_weights->buffers_.push_back(std::move(packed_buffer)); + prepacked_weights->buffer_sizes_.push_back(packed_size); + is_packed = true; + + // Pack Shape (Buffer 1) + auto dims = shape.GetDims(); + size_t rank_bytes = sizeof(int64_t); + size_t dims_bytes = dims.size() * sizeof(int64_t); + size_t shape_size = rank_bytes + dims_bytes; + + auto shape_buffer = IAllocator::MakeUniquePtr(alloc, shape_size); + int64_t* buffer_data = static_cast(shape_buffer.get()); + *buffer_data = static_cast(dims.size()); + memcpy(buffer_data + 1, dims.data(), dims_bytes); + + prepacked_weights->buffers_.push_back(std::move(shape_buffer)); + prepacked_weights->buffer_sizes_.push_back(shape_size); + + // Try build MLAS Q4 cache if scales are available + if (use_mlas_q4_gemm_) { + const Tensor* scales_tensor = nullptr; + MLAS_BLK_QUANT_TYPE qtype = BlkQ4Sym; + int scales_idx = -1; + int zp_idx = -1; + + if (input_idx == 2) { // FC1 + scales_idx = 3; + zp_idx = 11; + } else if (input_idx == 5) { // FC2 + scales_idx = 6; + zp_idx = 12; + } + + if (scales_idx != -1 && + (zp_idx >= static_cast(Info().node().InputDefs().size()) || !Info().node().InputDefs()[zp_idx]->Exists()) && + Info().TryGetConstantInput(scales_idx, &scales_tensor) && + scales_tensor != nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, block_size_, rows, cols, qtype)) { + IAllocatorUniquePtr cache_buffer; + const auto& scales_dims = scales_tensor->Shape().GetDims(); + const T* scales_data = scales_tensor->Data(); + // Use the simple packed buffer we just created (buffer 0) as input + const uint8_t* simple_packed = dst_base; + + if (BuildDirectQ4PackedBCache(simple_packed, scales_data, num_experts, rows, cols, + block_size_, scales_dims, qtype, + alloc, cache_buffer) + .IsOK()) { + // Store the MLAS Q4 
cache as buffer 2 (after unpacked weights and shape). + size_t cache_size = MlasQ4GemmPackBSize(qtype, static_cast(rows), static_cast(cols)) * static_cast(num_experts); + prepacked_weights->buffers_.push_back(std::move(cache_buffer)); + prepacked_weights->buffer_sizes_.push_back(cache_size); + } + } + } + } + } + + return Status::OK(); +} + +template +Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span /*prepacked_buffer_sizes*/, + int input_idx, + /*out*/ bool& used_shared_buffers) { + used_shared_buffers = false; + + if (expert_weight_bits_ != 4) { + return Status::OK(); + } + + if ((input_idx == 2 || input_idx == 5) && !prepacked_buffers.empty()) { + auto parse_shape = [&](TensorShape& shape) { + if (prepacked_buffers.size() > 1) { + int64_t* buffer_data = static_cast(prepacked_buffers[1].get()); + int64_t rank = buffer_data[0]; + std::vector dims(static_cast(rank)); + memcpy(dims.data(), buffer_data + 1, static_cast(rank) * sizeof(int64_t)); + shape = TensorShape(dims); + } + }; + + if (input_idx == 2) { + packed_fc1_ = std::move(prepacked_buffers[0]); + parse_shape(fc1_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc1_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } else if (input_idx == 5) { + packed_fc2_ = std::move(prepacked_buffers[0]); + parse_shape(fc2_shape_); + if (prepacked_buffers.size() > 2) { + packed_fc2_mlas_cache_ = std::move(prepacked_buffers[2]); + } + } + used_shared_buffers = true; + } + + return Status::OK(); +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info), @@ -362,21 +626,32 @@ QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info) ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, "Attribute 'expert_weight_bits' must be 4 or 8."); block_size_ = op_kernel_info.GetAttrOrDefault("block_size", 0); + ORT_ENFORCE(block_size_ >= 0); if (block_size_ > 0) { ORT_ENFORCE(block_size_ >= 16, "block_size must be >= 16 when provided."); 
ORT_ENFORCE((block_size_ & (block_size_ - 1)) == 0, "block_size must be a power of 2."); } + + const auto use_mlas_q4_gemm = ParseEnvironmentVariable(kUseMlasQ4GemmMoe); + if (use_mlas_q4_gemm.has_value()) { + use_mlas_q4_gemm_ = *use_mlas_q4_gemm; + use_mlas_q4_gemm_overridden_ = true; + } else { + // Default policy: enable fast path unless this run hits a known accuracy-loss configuration. + use_mlas_q4_gemm_ = true; + use_mlas_q4_gemm_overridden_ = false; + } } template Status QMoECPU::Compute(OpKernelContext* context) const { const auto* input = context->Input(0); const auto* router_probs = context->Input(1); - const auto* fc1_experts_weights = context->Input(2); + const auto* fc1_experts_weights = packed_fc1_ ? nullptr : context->Input(2); const auto* fc1_scales = context->Input(3); const auto* fc1_experts_bias = context->Input(4); - const auto* fc2_experts_weights = context->Input(5); + const auto* fc2_experts_weights = packed_fc2_ ? nullptr : context->Input(5); const auto* fc2_scales = context->Input(6); const auto* fc2_experts_bias = context->Input(7); const auto* fc3_experts_weights = context->Input(8); @@ -386,17 +661,21 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const auto* fc2_zero_points = context->Input(12); const auto* fc3_zero_points = context->Input(13); + const TensorShape* fc1_shape_ptr = packed_fc1_ ? &fc1_shape_ : (fc1_experts_weights ? &fc1_experts_weights->Shape() : nullptr); + const TensorShape* fc2_shape_ptr = packed_fc2_ ? &fc2_shape_ : (fc2_experts_weights ? &fc2_experts_weights->Shape() : nullptr); + const TensorShape* fc3_shape_ptr = fc3_experts_weights ? 
&fc3_experts_weights->Shape() : nullptr; + MoEParameters moe_params; ORT_RETURN_IF_ERROR(moe_helper::CheckInputs( moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias, fc1_scales, fc1_zero_points, - fc2_experts_weights, fc2_experts_bias, fc2_scales, fc2_zero_points, - fc3_experts_weights, fc3_experts_bias, fc3_scales, fc3_zero_points, + fc1_shape_ptr, fc1_experts_bias, fc1_scales, fc1_zero_points, + fc2_shape_ptr, fc2_experts_bias, fc2_scales, fc2_zero_points, + fc3_shape_ptr, fc3_experts_bias, fc3_scales, fc3_zero_points, expert_weight_bits_ == 4 ? 2 : 1, - true, + activation_type_ == ActivationType::SwiGLU, block_size_)); - if (fc3_experts_weights || fc3_experts_bias || fc3_scales || fc3_zero_points) { + if (fc3_shape_ptr || fc3_experts_bias || fc3_scales || fc3_zero_points) { return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "FC3 gating is not yet implemented on CPU for QMoE"); } @@ -559,8 +838,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const bool is_fc1_block_wise = (fc1_scales_dims.size() == 3 && fc1_scales_dims[2] > 1); const bool is_fc2_block_wise = (fc2_scales_dims.size() == 3 && fc2_scales_dims[2] > 1); - const uint8_t* fc1_weights_data = fc1_experts_weights->Data(); - const uint8_t* fc2_weights_data = fc2_experts_weights->Data(); + const uint8_t* fc1_weights_data = (packed_fc1_ != nullptr) ? nullptr : fc1_experts_weights->template Data(); + const uint8_t* fc2_weights_data = (packed_fc2_ != nullptr) ? nullptr : fc2_experts_weights->template Data(); const T* fc1_scales_data = fc1_scales->Data(); const T* fc2_scales_data = fc2_scales->Data(); const T* fc1_bias_data = fc1_experts_bias ? fc1_experts_bias->Data() : nullptr; @@ -568,6 +847,13 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const uint8_t* fc1_zp_data = fc1_zero_points ? fc1_zero_points->Data() : nullptr; const uint8_t* fc2_zp_data = fc2_zero_points ? 
fc2_zero_points->Data() : nullptr; + // Known loss-prone case from parity testing: 4-bit symmetric path (row-wise and block-wise). + const bool known_accuracy_loss_case = (expert_weight_bits_ == 4) && + (fc1_zp_data == nullptr) && (fc2_zp_data == nullptr); + const bool use_mlas_q4_gemm_effective = use_mlas_q4_gemm_overridden_ + ? use_mlas_q4_gemm_ + : (use_mlas_q4_gemm_ && !known_accuracy_loss_case); + const int64_t pack_unit = (8 / expert_weight_bits_); const int64_t fc1_packed_cols = (hidden_size + pack_unit - 1) / pack_unit; const int64_t fc2_packed_cols = (inter_size + pack_unit - 1) / pack_unit; @@ -595,6 +881,22 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_zp_expert_stride = (hidden_size + zp_pack_size - 1) / zp_pack_size; } + MLAS_BLK_QUANT_TYPE fc1_direct_qtype = BlkQ4Sym; + MLAS_BLK_QUANT_TYPE fc2_direct_qtype = BlkQ4Sym; + + // Use pre-packed MLAS cache if available + const void* fc1_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc1_mlas_cache_ && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, fc1_out_features, hidden_size, fc1_direct_qtype)) { + fc1_direct_q4_cache_ptr = packed_fc1_mlas_cache_.get(); + } + + const void* fc2_direct_q4_cache_ptr = nullptr; + if (use_mlas_q4_gemm_effective && packed_fc2_mlas_cache_ && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? 
block_size_ : 0, hidden_size, inter_size, fc2_direct_qtype)) { + fc2_direct_q4_cache_ptr = packed_fc2_mlas_cache_.get(); + } + std::vector> expert_workload; size_t total_work = 0; @@ -634,6 +936,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* thread_bias2_buffer = thread_bias1_buffer + static_cast(fc1_out_features); for (int64_t expert_idx : expert_batch) { + bool fc2_bias_added_by_mlas = false; const auto& routes = expert_token_map[static_cast(expert_idx)]; if (routes.empty()) { continue; @@ -707,12 +1010,57 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k = static_cast(hidden_size); MLAS_BLK_QUANT_TYPE q_type = BlkQ4Sym; // Initialize to default - // Direct Q4 GEMM only supports symmetric quantization, so we disable it if zero_points are provided. - bool use_direct_q4_gemm = (fc1_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, - fc1_out_features, hidden_size, q_type); - bool fc1_used_direct_q4 = false; - bool fc1_bias_handled_by_q4_gemm = false; + bool use_direct_q4_gemm = use_mlas_q4_gemm_effective && + ((fc1_direct_q4_cache_ptr != nullptr) || + ((packed_fc1_ == nullptr) && (fc1_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? block_size_ : 0, + fc1_out_features, hidden_size, q_type))); + + if (packed_fc1_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc1_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, q_type)) { + if (fc1_direct_q4_cache_ptr != nullptr) { + float* fc1_bias_float = nullptr; + if (has_fc1_bias) { + const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); + } else { + std::memcpy(thread_bias1_buffer, B1_bias, static_cast(fc1_out_features) * sizeof(float)); + } + fc1_bias_float = thread_bias1_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type, static_cast(fc1_out_features), static_cast(hidden_size)); + const uint8_t* packed_b = static_cast(fc1_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A1, packed_b, fc1_bias_float, C1, + num_expert_tokens, fc1_out_features, hidden_size, fc1_direct_qtype, tp); + if (gemm_status.IsOK()) { + goto fc1_gemm_done; + } + } + } + + // Fallback: Dequantize from PrePacked (transposed, unpacked) -> MlasGemm + const uint8_t* current_packed_ptr = static_cast(packed_fc1_.get()) + expert_idx * fc1_out_features * hidden_size; + + DequantizePrePacked(current_packed_ptr, fc1_scales_ptr, fc1_zp_ptr, + is_fc1_block_wise ? 
block_size_ : 0, + fc1_out_features, hidden_size, + B1_dequant, fc1_scales_dims); + + // Use MlasGemm with B1_dequant (which is already float transposed) + MlasGemm(CblasNoTrans, CblasNoTrans, + m, n, k, + 1.0f, A1, k, + B1_dequant, n, + 0.0f, C1, n, + tp); + + goto fc1_bias_handling; + } if (use_direct_q4_gemm) { IAllocatorUniquePtr mlas_packed_fc1; @@ -730,12 +1078,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc1_bias_float = nullptr; - IAllocatorUniquePtr fc1_bias_buffer; if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; - fc1_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(fc1_out_features)); - fc1_bias_float = fc1_bias_buffer.get(); + fc1_bias_float = thread_bias1_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), fc1_bias_float, static_cast(fc1_out_features)); @@ -750,7 +1096,6 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, fc1_out_features, hidden_size, q_type, tp); if (gemm_status.IsOK()) { - fc1_used_direct_q4 = true; goto fc1_gemm_done; } } @@ -797,8 +1142,9 @@ Status QMoECPU::Compute(OpKernelContext* context) const { 0.0f, C1, n, tp); - fc1_bias_handled_by_q4_gemm = fc1_used_direct_q4 && has_fc1_bias; - if (has_fc1_bias && !fc1_bias_handled_by_q4_gemm) { + fc1_bias_handling: + + if (has_fc1_bias) { const T* B1_bias = fc1_bias_data + expert_idx * fc1_out_features; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B1_bias), thread_bias1_buffer, static_cast(fc1_out_features)); @@ -837,22 +1183,30 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc1_gemm_done: - const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); - if (num_expert_tokens >= activation_threshold && tp != nullptr) { - const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); - const 
int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; - - if (num_activation_blocks > 1) { - concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { - const int64_t start_token = block_idx * activation_block_size; - const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); - - for (int64_t i = start_token; i < end_token; ++i) { + if (activation_type_ == ActivationType::SwiGLU) { + const int64_t activation_threshold = std::max(int64_t{4}, 256 / std::max(int64_t{1}, inter_size)); + if (num_expert_tokens >= activation_threshold && tp != nullptr) { + const int64_t activation_block_size = std::max(int64_t{1}, std::min(int64_t{64}, activation_threshold)); + const int64_t num_activation_blocks = (num_expert_tokens + activation_block_size - 1) / activation_block_size; + + if (num_activation_blocks > 1) { + concurrency::ThreadPool::TrySimpleParallelFor(tp, narrow(num_activation_blocks), [&](std::ptrdiff_t block_idx) { + const int64_t start_token = block_idx * activation_block_size; + const int64_t end_token = std::min(start_token + activation_block_size, num_expert_tokens); + + for (int64_t i = start_token; i < end_token; ++i) { + const float* C1_token = C1 + i * fc1_out_features; + float* A2_token = A2 + i * inter_size; + ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); + } + }); + } else { + for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; float* A2_token = A2 + i * inter_size; ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); } - }); + } } else { for (int64_t i = 0; i < num_expert_tokens; ++i) { const float* C1_token = C1 + i * fc1_out_features; @@ -861,11 +1215,8 @@ Status QMoECPU::Compute(OpKernelContext* context) const { } } } else { - for (int64_t i = 
0; i < num_expert_tokens; ++i) { - const float* C1_token = C1 + i * fc1_out_features; - float* A2_token = A2 + i * inter_size; - ApplySwiGLUActivation(C1_token, A2_token, inter_size, true, activation_alpha_, activation_beta_, swiglu_limit_); - } + ApplyActivationVectorized(C1, num_expert_tokens * fc1_out_features); + std::copy(C1, C1 + (num_expert_tokens * fc1_out_features), A2); } const T* fc2_scales_ptr; @@ -888,10 +1239,58 @@ Status QMoECPU::Compute(OpKernelContext* context) const { const size_t k2 = static_cast(inter_size); MLAS_BLK_QUANT_TYPE q_type2 = BlkQ4Sym; // Initialize to default - bool use_direct_q4_gemm_fc2 = (fc2_zp_data == nullptr) && - CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, - hidden_size, inter_size, q_type2); - bool fc2_used_direct_q4 = false; + bool use_direct_q4_gemm_fc2 = use_mlas_q4_gemm_effective && + ((fc2_direct_q4_cache_ptr != nullptr) || + ((packed_fc2_ == nullptr) && (fc2_zp_data == nullptr) && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? block_size_ : 0, + hidden_size, inter_size, q_type2))); + + if (packed_fc2_ != nullptr) { + if (use_mlas_q4_gemm_effective && fc2_zp_data == nullptr && + CanUseMlasQ4Gemm(expert_weight_bits_, is_fc2_block_wise ? 
block_size_ : 0, + hidden_size, inter_size, q_type2)) { + if (fc2_direct_q4_cache_ptr != nullptr) { + float* fc2_bias_float = nullptr; + if (has_fc2_bias) { + const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; + if constexpr (std::is_same_v) { + MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); + } else { + std::memcpy(thread_bias2_buffer, B2_bias, static_cast(hidden_size) * sizeof(float)); + } + fc2_bias_float = thread_bias2_buffer; + } + + size_t packed_size = MlasQ4GemmPackBSize(q_type2, static_cast(hidden_size), static_cast(inter_size)); + const uint8_t* packed_b = static_cast(fc2_direct_q4_cache_ptr) + expert_idx * packed_size; + + Status gemm_status = DirectQ4Gemm(A2, packed_b, fc2_bias_float, C2, + num_expert_tokens, hidden_size, inter_size, fc2_direct_qtype, tp); + if (gemm_status.IsOK()) { + fc2_bias_added_by_mlas = true; + goto fc2_gemm_done; + } + } + } + + // Dequantize from PrePacked (transposed, unpacked) + const uint8_t* current_packed_ptr = static_cast(packed_fc2_.get()) + expert_idx * hidden_size * inter_size; + + DequantizePrePacked(current_packed_ptr, fc2_scales_ptr, fc2_zp_ptr, + is_fc2_block_wise ? 
block_size_ : 0, + hidden_size, inter_size, + B2_dequant, fc2_scales_dims); + + // Fallback + MlasGemm(CblasNoTrans, CblasNoTrans, + m2, n2, k2, + 1.0f, A2, k2, + B2_dequant, n2, + 0.0f, C2, n2, + tp); + + goto fc2_gemm_done; + } if (use_direct_q4_gemm_fc2) { IAllocatorUniquePtr mlas_packed_fc2; @@ -909,12 +1308,10 @@ Status QMoECPU::Compute(OpKernelContext* context) const { if (convert_status.IsOK()) { float* fc2_bias_float = nullptr; - IAllocatorUniquePtr fc2_bias_buffer; if (has_fc2_bias) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; - fc2_bias_buffer = IAllocator::MakeUniquePtr(allocator, static_cast(hidden_size)); - fc2_bias_float = fc2_bias_buffer.get(); + fc2_bias_float = thread_bias2_buffer; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), fc2_bias_float, static_cast(hidden_size)); @@ -929,7 +1326,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { num_expert_tokens, hidden_size, inter_size, q_type2, tp); if (gemm_status.IsOK()) { - fc2_used_direct_q4 = true; + fc2_bias_added_by_mlas = true; goto fc2_gemm_done; } } @@ -979,8 +1376,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { fc2_gemm_done: - bool fc2_bias_handled_by_q4_gemm = fc2_used_direct_q4 && has_fc2_bias; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const T* B2_bias = fc2_bias_data + expert_idx * hidden_size; if constexpr (std::is_same_v) { MlasConvertHalfToFloatBuffer(reinterpret_cast(B2_bias), thread_bias2_buffer, static_cast(hidden_size)); @@ -1015,7 +1411,7 @@ Status QMoECPU::Compute(OpKernelContext* context) const { float* dest = thread_local_outputs + static_cast(thread_id) * output_buffer_size + buffer_offset; const float* src = C2 + i * hidden_size; - if (has_fc2_bias && !fc2_bias_handled_by_q4_gemm) { + if (has_fc2_bias && !fc2_bias_added_by_mlas) { const size_t unroll_factor = narrow(GetUnrollFactor(hidden_size)); size_t j = 0; for (; j + 
unroll_factor <= narrow(hidden_size); j += unroll_factor) { @@ -1109,10 +1505,22 @@ Status QMoECPU::Compute(OpKernelContext* context) const { return Status::OK(); } +template +void QMoECPU::ApplyActivationVectorized(float* data, int64_t size) const { + for (int64_t i = 0; i < size; ++i) { + data[i] = ApplyActivation(data[i], activation_type_); + } +} + template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); + template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); template QMoECPU::QMoECPU(const OpKernelInfo& op_kernel_info); template Status QMoECPU::Compute(OpKernelContext* context) const; +template Status QMoECPU::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, bool& is_packed, PrePackedWeights* prepacked_weights); +template Status QMoECPU::UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, gsl::span prepacked_buffer_sizes, int input_idx, bool& used_shared_buffers); // Kernel Registration ONNX_OPERATOR_TYPED_KERNEL_EX( diff --git a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h index 890580e051a8e..f678a27190c90 100644 --- a/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h +++ b/onnxruntime/contrib_ops/cpu/moe/moe_quantization_cpu.h @@ -5,7 +5,9 @@ #include "core/common/common.h" #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas_q4.h" #include "contrib_ops/cpu/moe/moe_base_cpu.h" +#include namespace onnxruntime { namespace contrib { @@ -26,8 +28,30 @@ class QMoECPU final : public OpKernel, public MoEBaseCPU { Status Compute(OpKernelContext* context) const override; private: + Status PrePack(const Tensor& tensor, int input_idx, 
AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* prepacked_weights) override; + + Status UseSharedPrePackedBuffers_V2(std::vector& prepacked_buffers, + gsl::span prepacked_buffer_sizes, + int input_idx, + /*out*/ bool& used_shared_buffers) override; + + void ApplyActivationVectorized(float* data, int64_t size) const; + int64_t expert_weight_bits_; int64_t block_size_; + bool use_mlas_q4_gemm_{false}; + bool use_mlas_q4_gemm_overridden_{false}; + + IAllocatorUniquePtr packed_fc1_; + IAllocatorUniquePtr packed_fc2_; + + TensorShape fc1_shape_; + TensorShape fc2_shape_; + + IAllocatorUniquePtr packed_fc1_mlas_cache_; + IAllocatorUniquePtr packed_fc2_mlas_cache_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc index 13748b43b1ae6..7531d63fb5fc8 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.cc +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.cc @@ -202,6 +202,12 @@ Status NchwcConv::Compute(OpKernelContext* context) const { } } +#if defined(__aarch64__) && defined(__linux__) + const bool use_bf16 = use_fastmath_mode_; +#else + const bool use_bf16 = false; +#endif + MlasNchwcConv( X_shape.GetDims().data(), kernel_shape.data(), @@ -216,7 +222,8 @@ Status NchwcConv::Compute(OpKernelContext* context) const { y_data.data(), &activation_, Sum == nullptr, - context->GetOperatorThreadPool()); + context->GetOperatorThreadPool(), + use_bf16); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/nchwc_ops.h b/onnxruntime/contrib_ops/cpu/nchwc_ops.h index 4827d70489674..169eecdeaa02f 100644 --- a/onnxruntime/contrib_ops/cpu/nchwc_ops.h +++ b/onnxruntime/contrib_ops/cpu/nchwc_ops.h @@ -7,6 +7,7 @@ #include "core/framework/op_kernel.h" #include "core/providers/cpu/nn/conv_attributes.h" #include "core/providers/cpu/nn/pool.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "contrib_ops/cpu/fused_activation.h" namespace onnxruntime { @@ -43,6 
+44,10 @@ class NchwcConv final : public OpKernel { public: NchwcConv(const OpKernelInfo& info) : OpKernel(info), conv_attrs_(info) { ORT_ENFORCE(GetFusedActivationAttr(info, activation_).IsOK()); +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status Compute(OpKernelContext* context) const override; @@ -51,6 +56,9 @@ class NchwcConv final : public OpKernel { ConvAttributes conv_attrs_; MLAS_ACTIVATION activation_; +#if defined(__aarch64__) && defined(__linux__) + bool use_fastmath_mode_{false}; +#endif }; class NchwcPoolBase : public PoolBase { diff --git a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc index 8a69263ab2f37..e069adcb82863 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/dynamic_quantize_matmul.cc @@ -164,132 +164,23 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { Status Compute(OpKernelContext* context) const override; #if defined(USE_KLEIDIAI) - Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - /*out*/ bool& is_packed, - /*out*/ PrePackedWeights* prepacked_weights) override { - // only pack Matrix B - if (input_idx == GetBIdx()) { - const Tensor* b_zp_constant_tensor{nullptr}; - bool b_quantization_might_be_asymmetric = false; - - const OrtValue* b_zp; - if (Info().TryGetConstantInput(IN_B_ZERO_POINT, &b_zp)) { - b_zp_constant_tensor = &b_zp->Get(); - } - - // MlasDynamicQgemm requires symmetric quantization for B, so the B zero point value should either be all zeros - // or not provided. - if (b_zp_constant_tensor != nullptr) { - // B zero point is constant. Check if it is all zeros. 
- assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); - const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); - const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); - b_quantization_might_be_asymmetric = std::any_of(zp_bytes, zp_bytes + zp_size_in_bytes, - [](std::byte v) { return v != std::byte{0}; }); - } else { - // B zero point input is not constant. If it exists, we can't assume symmetric quantization. - const auto input_defs = Info().node().InputDefs(); - const bool b_zp_input_exists = input_defs.size() > IN_B_ZERO_POINT && input_defs[IN_B_ZERO_POINT]->Exists(); - b_quantization_might_be_asymmetric = b_zp_input_exists; - } - - // MlasDynamicQgemm requires scale data to be available at packing stage - const Tensor* b_scale_tensor = nullptr; - const bool b_scale_available = Info().TryGetConstantInput(IN_B_SCALE, &b_scale_tensor); - - can_use_dynamic_quant_mlas_ = (!b_quantization_might_be_asymmetric && b_scale_available); - - // Kleidi dynamic path requires strictly positive, finite scales. - // Disable if any invalid scale is detected. - if (can_use_dynamic_quant_mlas_) { - const auto bs = b_scale_tensor->DataAsSpan(); - const bool has_invalid = - std::any_of(bs.begin(), bs.end(), - [](float s) { return !std::isfinite(s) || s <= 0.0f; }); - - if (has_invalid) { - can_use_dynamic_quant_mlas_ = false; - } - } - - if (!MlasIsDynamicQGemmAvailable()) { - can_use_dynamic_quant_mlas_ = false; - } - - // Only handle the common case of a 2D weight matrix. Additional matrices - // could be handled by stacking the packed buffers. - b_shape_ = tensor.Shape(); - if (b_shape_.NumDimensions() >= 2) { - for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) { - if (b_shape_[i] != 1) { - can_use_dynamic_quant_mlas_ = false; - break; - } - } - } else { - can_use_dynamic_quant_mlas_ = false; - } - - // Can we use the mlas dynamic Q gemm interface supported with float output ? 
- if (!can_use_dynamic_quant_mlas_) { - // default to piece wise mlas interface with separate int matmul, quantize and float conversion - return MatMulIntegerToFloatBase::PrePack(tensor, input_idx, alloc, is_packed, prepacked_weights); - } - is_packed = false; - - // Default to all zeros for bias - const Tensor* bias_tensor{nullptr}; - const OrtValue* bias; - if (Info().TryGetConstantInput(IN_BIAS, &bias)) { - bias_tensor = &bias->Get(); - dynamic_quant_mlas_bias_data_was_packed_ = true; - } - size_t K = static_cast(b_shape_[0]); - size_t N = static_cast(b_shape_[1]); - - const auto* b_data = static_cast(tensor.DataRaw()); - - std::optional b_trans_buffer; - if (IsBTransposed()) { - std::swap(K, N); - b_data = quantization::TransPoseInputData(b_data, b_trans_buffer, alloc, N, K); - } + bool SupportsKleidiaiDynamicQuant() const override { + if (!MlasIsDynamicQGemmAvailable()) { + return false; + } + return true; + } - const size_t packed_b_size = MlasDynamicQgemmPackBSize(N, K); - if (packed_b_size == 0) { - return Status::OK(); - } + int GetBScaleIdx() const override { + return IN_B_SCALE; + } - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); - // Initialize memory to 0 as there could be some padding associated with pre-packed - // buffer memory and we do not want it uninitialized and generate different hashes - // if and when we try to cache this pre-packed buffer for sharing between sessions. - memset(packed_b_.get(), 0, packed_b_size); - - const auto scales = static_cast(b_scale_tensor->Shape().Size()) == N ? std::vector(&b_scale_tensor->Data()[0], - &b_scale_tensor->Data()[N]) - : - // Broadcast matrix scale to all channels - std::vector(N, b_scale_tensor->Data()[0]); - - const auto biases = bias_tensor != nullptr ? 
std::vector(&bias_tensor->Data()[0], - &bias_tensor->Data()[N]) - : - // Broadcast zero to all channels - no bias data is available - std::vector(N, 0.f); - - MlasDynamicQgemmPackB(N, K, reinterpret_cast(b_data), scales.data(), biases.data(), - packed_b_.get()); - - bool share_prepacked_weights = (prepacked_weights != nullptr); - if (share_prepacked_weights) { - prepacked_weights->buffers_.push_back(std::move(packed_b_)); - prepacked_weights->buffer_sizes_.push_back(packed_b_size); - } + int GetBZeroPointIdx() const override { + return IN_B_ZERO_POINT; + } - is_packed = true; - } - return Status::OK(); + int GetBiasIdx() const override { + return IN_BIAS; } #endif @@ -303,14 +194,6 @@ class DynamicQuantizeMatMul final : public MatMulIntegerToFloatBase { protected: int GetBIdx() const override { return IN_B; } - - private: - // Indicates when MlasDynamicQGemmBatch() can be used - bool can_use_dynamic_quant_mlas_{false}; -#if defined(USE_KLEIDIAI) - // Indicates that the biases are a constant input and thus already quantized / packed - bool dynamic_quant_mlas_bias_data_was_packed_{false}; -#endif }; class MatMulIntegerToFloat final : public MatMulIntegerToFloatBase { @@ -380,8 +263,8 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { ScaleOutput(*b_scale_tensor, *ctx->Output(0)); } } - // Guard against KleidiAI functions being called in non kleidi builds - // TODO: migrate to a suitable override function call for kleidi dynamic qgemm function calls + // Guard against KleidiAI functions being called in non-Kleidi builds + // migrate to a suitable override function call for KleidiAI dynamic QGEMM function calls #if defined(USE_KLEIDIAI) else { MatMulComputeHelper helper; @@ -390,10 +273,10 @@ Status DynamicQuantizeMatMul::Compute(OpKernelContext* ctx) const { // deleted during session init post prepacking nullptr, nullptr)); - + // allocate the kernel’s output tensor from the execution context Tensor* y = ctx->Output(OUT_Y, helper.OutputShape()); - 
// Bail out early if the output is going to be empty + // Bail out early if any dimension is 0, the product (and hence the total number of elements) is 0 if (y->Shape().Size() == 0) return Status::OK(); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index fc0b7e40c628b..cc93799059f43 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -4,6 +4,7 @@ #include "contrib_ops/cpu/quantization/matmul_nbits_impl.h" #include +#include #include #include "core/common/common.h" @@ -15,7 +16,10 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/common.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "contrib_ops/cpu/quantization/matmul_nbits_helper.h" +#include "core/platform/threadpool.h" +#include "core/util/thread_utils.h" namespace onnxruntime { namespace contrib { @@ -100,6 +104,11 @@ class MatMulNBits final : public OpKernel { nbits_{narrow(info.GetAttr("bits"))}, has_g_idx_{info.GetInputCount() > InputIndex::g_idx && info.node().InputDefs()[InputIndex::g_idx]->Exists()}, has_bias_{info.GetInputCount() > InputIndex::bias && info.node().InputDefs()[InputIndex::bias]->Exists()}, + prefer_lut_gemm_{info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasLutGemm) == "1" && + MlasIsLutGemmAvailable(narrow(info.GetAttr("N")), + narrow(info.GetAttr("K")), + narrow(info.GetAttr("bits")), + narrow(info.GetAttr("block_size")))}, compute_type_{GetComputeType(nbits_, block_size_, info.GetAttr("accuracy_level"))} { const auto& node = info.node(); auto input_defs = node.InputDefs(); @@ -135,6 +144,7 @@ class MatMulNBits final : public OpKernel { const bool has_g_idx_; const bool has_bias_; bool scales_are_packed_{false}; + const bool prefer_lut_gemm_{false}; const MLAS_QNBIT_GEMM_COMPUTE_TYPE compute_type_; bool 
has_unquantized_zero_point_{false}; const bool column_wise_quant_{true}; @@ -167,6 +177,11 @@ class MatMulNBits final : public OpKernel { AllocatorPtr& allocator, concurrency::ThreadPool* thread_pool, const MatMulComputeHelper& helper) const; + + Status ComputeBPackedLUT(const Tensor* a, + Tensor* y, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const; }; template @@ -179,22 +194,76 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } - if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) { return Status::OK(); } + + // Create a temporary threadpool for parallel packing + // This is used during model load time to speed up weight prepacking + std::unique_ptr temp_threadpool; + concurrency::ThreadPool* threadpool_ptr = nullptr; + + // Only create threadpool for LUT GEMM path which can benefit from parallel packing + // TODO: Consider extending threadpool usage to non-LUT path (CompInt8) with appropriate tests + if (prefer_lut_gemm_) { + OrtThreadPoolParams tpo; + tpo.thread_pool_size = Env::Default().GetNumPhysicalCpuCores(); + tpo.allow_spinning = false; // Don't spin during model load + tpo.auto_set_affinity = false; + + temp_threadpool = concurrency::CreateThreadPool( + &Env::Default(), + tpo, + concurrency::ThreadPoolType::INTRA_OP); + + threadpool_ptr = temp_threadpool.get(); + } + if (input_idx == InputIndex::B) { const Tensor* scales = nullptr; OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); - packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); - if (packed_b_size_ == 0) { - return Status::OK(); + if (prefer_lut_gemm_) { + MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + + packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); + if (packed_b_size_ 
== 0) { + return Status::OK(); + } + + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + + const float* scales_ptr = scales ? scales->Data() : nullptr; + const uint8_t* zp_ptr = nullptr; + if (scales_ptr != nullptr && has_zp_input_) { + const Tensor* zero_points = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); + zp_ptr = zero_points ? zero_points->Data() : nullptr; + } + + MlasLutGemmPack( + N_, K_, nbits_, block_size_, has_zp_input_, + static_cast(tensor.DataRaw()), + scales_ptr, + zp_ptr, + static_cast(packed_b_.get()), + threadpool_ptr); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size_); + } + } else { + packed_b_size_ = MlasQNBitGemmPackQuantBDataSize(N_, K_, nbits_, block_size_, has_zp_input_, compute_type_); + if (packed_b_size_ == 0) { + return Status::OK(); + } + auto qptr = tensor.DataRaw(); + auto scale_ptr = scales ? scales->DataRaw() : nullptr; + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); + MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, + has_zp_input_, nullptr, threadpool_ptr); } - auto qptr = tensor.DataRaw(); - auto scale_ptr = scales ? 
scales->DataRaw() : nullptr; - packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size_, true); - MlasQNBitGemmPackQuantBData(N_, K_, nbits_, block_size_, compute_type_, qptr, packed_b_.get(), scale_ptr, - has_zp_input_, nullptr, nullptr); is_packed = true; } else if (compute_type_ == SQNBIT_CompInt8) { // Packing scales and zero points @@ -230,8 +299,30 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All is_packed = true; } #endif // MLAS_TARGET_ARM64 + } else if (prefer_lut_gemm_) { + // Pack scales/zero_points for LUT GEMM if B was already packed but scales weren't available then + if (input_idx == InputIndex::scales && packed_b_ != nullptr) { + auto scales_ptr = tensor.Data(); + const uint8_t* zp_ptr = nullptr; + if (has_zp_input_) { + const Tensor* zero_points = nullptr; + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zero_points); + zp_ptr = zero_points ? zero_points->Data() : nullptr; + } + // Pack only scales (QuantBData is nullptr) + MlasLutGemmPack( + N_, K_, nbits_, block_size_, has_zp_input_, + nullptr, // QuantBData already packed + scales_ptr, + zp_ptr, + static_cast(packed_b_.get()), + nullptr); // No threadpool needed for scales only + is_packed = false; // scales tensor can be released but not "packed" in the ORT sense + } } + // Threadpool will be automatically destroyed when temp_threadpool goes out of scope + return Status::OK(); } @@ -268,9 +359,9 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*ou OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales); if (scales && MlasQNBitGemmScalesPacked(K_, nbits_, block_size_, compute_type_, has_zp_input_)) { auto sptr = scales->Data(); - auto tensor_size = static_cast(tensor.Shape().Size()); - auto ptr = IAllocator::MakeUniquePtr(alloc, tensor_size, true); - MlasConvertHalfToFloatBuffer(sptr, ptr.get(), tensor_size); + auto scales_size = static_cast(scales->Shape().Size()); + auto ptr = IAllocator::MakeUniquePtr(alloc, 
scales_size, true); + MlasConvertHalfToFloatBuffer(sptr, ptr.get(), scales_size); scales_fp32_ = std::move(ptr); } @@ -307,14 +398,34 @@ Status MatMulNBits::UseSharedPrePackedBuffers(std::vector& /*out*/ bool& used_shared_buffers) { used_shared_buffers = false; - if (input_idx == 1) { - used_shared_buffers = true; + if (input_idx == InputIndex::B && !prepacked_buffers.empty()) { packed_b_ = std::move(prepacked_buffers[0]); + used_shared_buffers = true; + + if (prefer_lut_gemm_) { + MlasInitLutGemmKernelConfig(N_, K_, nbits_, block_size_, has_zp_input_); + packed_b_size_ = MlasLutGemmPackedSize(N_, K_, nbits_, block_size_, has_zp_input_); + } } return Status::OK(); } +template +Status MatMulNBits::ComputeBPackedLUT(const Tensor* a, + Tensor* y, + concurrency::ThreadPool* thread_pool, + const MatMulComputeHelper& helper) const { + const auto* a_data = a->Data(); + auto* y_data = y->MutableData(); + const int M = static_cast(helper.M()); + const int N = static_cast(helper.N()); + const int K = static_cast(helper.K()); + + MlasLutGemm(a_data, block_size_, packed_b_.get(), y_data, K, M, N, has_zp_input_, thread_pool); + return Status::OK(); +} + template Status MatMulNBits::ComputeBPacked(const Tensor* a, const Tensor* scales, @@ -740,7 +851,7 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If B is prepacked, B would have been removed from the context const bool is_b_prepacked = packed_b_size_ > 0; const Tensor* b = is_b_prepacked ? nullptr : ctx->Input(InputIndex::B); - const Tensor* scales = scales_are_packed_ ? nullptr : ctx->Input(InputIndex::scales); + const Tensor* scales = (scales_are_packed_ || (prefer_lut_gemm_ && packed_b_)) ? 
nullptr : ctx->Input(InputIndex::scales); const Tensor* zero_points = ctx->Input(InputIndex::zero_points); const Tensor* reorder_idx = ctx->Input(InputIndex::g_idx); const Tensor* bias = ctx->Input(InputIndex::bias); @@ -774,6 +885,10 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const { // If this changes, i.e., if MlasIsQNBitGemmAvailable() can return true while // MlasQNBitGemmPackQuantBDataSize() returns 0, we can consider calling MlasQNBitGemmBatch() // with B directly too. + if (prefer_lut_gemm_) { + return ComputeBPackedLUT(a, y, thread_pool, helper); + } + if (MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_)) { return ComputeBPacked(a, scales, zero_points, bias, y, allocator, thread_pool, helper); } diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc index b7b839a4f366b..e9ef220a2187e 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.cc @@ -73,7 +73,7 @@ void Dequantize4BitsKernelReOrder( } } -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input @@ -102,17 +102,17 @@ void DequantizeBlockwise( }); } -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const uint8_t* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const uint8_t* quant_data, const float* scales_data, const float* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); -template void DequantizeBlockwise( +template void DequantizeBlockwise( float* output, const 
uint8_t* quant_data, const float* scales_data, const MLFloat16* zero_points, const int32_t* reorder_idx, int32_t block_size, bool columnwise, int32_t K, int32_t N, onnxruntime::concurrency::ThreadPool* thread_pool); diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h index 5061ac5c800a6..be77ec03d006b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits_impl.h @@ -6,7 +6,7 @@ namespace onnxruntime { namespace contrib { -template +template void DequantizeBlockwise( inputT* output, // dequantized output const uint8_t* quant_data, // quantized input diff --git a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h index 47d0fc5e4008c..415612582ee4b 100644 --- a/onnxruntime/contrib_ops/cpu/utils/debug_macros.h +++ b/onnxruntime/contrib_ops/cpu/utils/debug_macros.h @@ -1,12 +1,9 @@ #pragma once +#include #include "core/common/make_string.h" -// #define DEBUG_GENERATION 1 // uncomment it for debugging generation (like beam search etc) - -#ifdef DEBUG_GENERATION -#define DUMP_TENSOR_LEVEL 2 -#else -#define DUMP_TENSOR_LEVEL 0 // change it to 1 or 2 if want to enable dumping for code not in generation. +#if !defined(DUMP_TENSOR_LEVEL) +#define DUMP_TENSOR_LEVEL 0 #endif #define DUMP_CPU_TENSOR_LEVEL DUMP_TENSOR_LEVEL @@ -48,3 +45,12 @@ #else #define DUMP_TENSOR_D(...) #endif + +#if (defined(__GNUC__) || defined(__clang__)) && !defined(NDEBUG) +#define DEBUG_PRINTF(fmt, ...) \ + std::printf("[DEBUG] " fmt "\n", ##__VA_ARGS__) +#else +#define DEBUG_PRINTF(fmt, ...) 
\ + do { \ + } while (0) +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_data.h b/onnxruntime/contrib_ops/cuda/bert/attention_data.h index 2344b425ed263..1622bb6622412 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_data.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_data.h @@ -179,35 +179,20 @@ struct GroupQueryAttentionData { // Memory Efficient buffers T* fmha_buffer = nullptr; - T* unpacked_qkv_buffer = nullptr; - T* rotary_buffer = nullptr; - int64_t* position_ids_buffer = nullptr; // Separate buffer for generated position IDs + T* qkv_buffer = nullptr; + T* k = nullptr; T* v = nullptr; -#ifndef NDEBUG - // Buffer size tracking for debug validation - // Allocated sizes are set during buffer allocation in group_query_attention.cc - // Max used sizes are updated during kernel calls in group_query_attention_impl.cu - // Validated before operator returns to ensure usage exactly matches allocation - size_t unpacked_qkv_buffer_size = 0; // Allocated size - size_t rotary_buffer_size = 0; // Allocated size - size_t position_ids_buffer_size = 0; // Allocated size - mutable size_t unpacked_qkv_max_used = 0; // Max offset accessed (updated by kernels) - mutable size_t rotary_max_used = 0; // Max offset accessed (updated by kernels) - mutable size_t position_ids_max_used = 0; // Max offset accessed (updated by kernels) -#endif - // Output Tensors T* output = nullptr; - T* present_key = nullptr; - T* present_value = nullptr; + void* present_key = nullptr; + void* present_value = nullptr; // Kernel Flags bool use_flash_attention = false; bool use_memory_efficient_attention = false; bool use_flash_attention_fast_decode = false; - bool disable_fused_kv = false; }; template diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index 91cac731054e6..fcc470b19a7b4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ 
b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -533,7 +533,6 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.is_seqlens_k_cumulative = seqlens_k_ == nullptr; if (seqlens_k_ != nullptr) { params.cu_seqlens_k = static_cast(seqlens_k_); - params.seqused_k = static_cast(seqlens_k_); } if (rotary_cos != nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 22b075d8533f9..83f94a31d1786 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -132,9 +132,12 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, size_t get_softmax_lse_size(size_t max_seqlen_q, size_t batch_size, size_t num_heads); size_t get_softmax_lse_size(size_t token_count, size_t num_heads); +size_t get_softmax_lse_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, size_t seqlen_q); +size_t get_out_accum_size(size_t num_splits, size_t batch_size, size_t num_heads, + size_t seqlen_q, size_t head_size_rounded); -std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, size_t num_heads, - size_t head_size, size_t num_SMs); +std::tuple get_num_splits_and_buffer_sizes(size_t batch_size, size_t seqlen_q, size_t seqlen_k, + size_t num_heads, size_t head_size, size_t num_SMs); bool is_supported(const cudaDeviceProp& dprops, size_t head_size, size_t num_heads, size_t num_heads_k); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index c99db85f93421..29ef660e562e0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. 
#include +#include #include "core/providers/cuda/cuda_common.h" #include "core/platform/env_var_utils.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" @@ -39,8 +40,17 @@ REGISTER_KERNEL_TYPED(MLFloat16) REGISTER_KERNEL_TYPED(BFloat16) constexpr const char* kDisableFlashDecode = "ORT_DISABLE_FLASH_DECODE"; -constexpr const char* kDisableFusedKv = "ORT_DISABLE_FUSED_KV"; +// Group Query Attention (GQA) Operator +// +// This operator implements Group Query Attention, a variation of Multi-Head Attention (MHA) +// where the number of key/value heads is smaller than the number of query heads. +// It supports: +// - Rotary Positional Embeddings (RoPE) +// - KV Cache (past/present key/value) +// - Quantized KV Cache (Int8/Int4) via GroupQueryAttentionData +// - Flash Attention and Memory Efficient Attention backends +// template GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) : CudaKernel(info) { @@ -63,7 +73,7 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) disable_flash_attention_ = sizeof(T) != 2 || !kernel_options_->UseFlashAttention(); - // Memory efficient attention supports float and float16. BFloat16 support is added for SM80+ via cutlass kernels. + // Memory efficient attention supports float and float16. BFloat16 support added for SM80+. disable_memory_efficient_attention_ = !kernel_options_->UseEfficientAttention(); if (!disable_flash_attention_) { @@ -72,9 +82,23 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) } disable_flash_decode_ = ParseEnvironmentVariableWithDefault(kDisableFlashDecode, false); - disable_fused_kv_ = ParseEnvironmentVariableWithDefault(kDisableFusedKv, false); } +// ComputeInternal executes the GQA kernel. +// +// Inputs: +// 0. query (Tensor) [batch, sequence_length, hidden_size] +// 1. key (Tensor) [batch, sequence_length, kv_hidden_size] (Optional) +// 2. value (Tensor) [batch, sequence_length, kv_hidden_size] (Optional) +// 3. 
past_key (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional) +// 4. past_value (Tensor) [batch, num_kv_heads, max_seq_len, head_size] (Optional) +// 5. seqlens_k (Tensor) [batch] - Total sequence length minus 1 (for historical compatibility) +// 6. total_seqlen (Tensor) - Max total sequence length +// 7. cos_cache (Tensor) - Precomputed cosine table for RoPE +// 8. sin_cache (Tensor) - Precomputed sine table for RoPE +// 9. position_ids (Tensor) - Position indices for RoPE +// 10. attention_bias (Tensor) - Not supported in this kernel +// 11. head_sink (Tensor) - Attention sink for GPT-OSS template Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* query = context->Input(0); @@ -162,7 +186,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { IAllocatorUniquePtr k_buffer; IAllocatorUniquePtr v_buffer; IAllocatorUniquePtr rotary_buffer; - IAllocatorUniquePtr position_ids_buffer; IAllocatorUniquePtr fmha_buffer; IAllocatorUniquePtr unpacked_qkv_buffer; IAllocatorUniquePtr seq_lens_buffer; @@ -185,24 +208,39 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { data.past_value = (past_value == nullptr) ? nullptr : reinterpret_cast(past_value->Data()); data.present_value = reinterpret_cast(context->Output(2)->MutableData()); + // Compute past_present_share_buffer early since it's needed for flash attention path selection. + // This compares the final pointer values after quantization handling. 
+ parameters.past_present_share_buffer = (data.past_key == data.present_key); + #if USE_FLASH_ATTENTION bool use_flash_attention = !disable_flash_attention_ && onnxruntime::flash::is_supported(device_prop, parameters.head_size, parameters.num_heads, parameters.kv_num_heads); - data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.kv_share_buffer; - if (use_flash_attention) { - data.use_flash_attention = true; - data.use_memory_efficient_attention = false; + data.use_flash_attention = use_flash_attention; + data.use_flash_attention_fast_decode = use_flash_attention && !disable_flash_decode_ && !parameters.is_first_prompt && parameters.past_present_share_buffer; + + if (use_flash_attention) { // Allocate Flash specific buffers (Softmax LSE, Accum) size_t softmax_lse_bytes = onnxruntime::flash::get_softmax_lse_size(parameters.sequence_length, parameters.batch_size, parameters.num_heads); + + int num_heads_for_split = data.use_flash_attention_fast_decode ? parameters.kv_num_heads : parameters.num_heads; auto [num_splits, softmax_lse_accum_bytes, out_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes( - parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, parameters.num_heads, + parameters.batch_size, parameters.sequence_length, parameters.total_sequence_length, num_heads_for_split, parameters.head_size, device_prop.multiProcessorCount); + parameters.num_splits = static_cast(num_splits); + if (data.use_flash_attention_fast_decode && num_splits > 1) { + // The heuristic used kv_num_heads to maximize occupancy for the GQA-aware kernel. + // However, the LSE and Accum buffers must store results for ALL num_heads. 
+ softmax_lse_accum_bytes = onnxruntime::flash::get_softmax_lse_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length); + auto round_multiple = [](size_t x, size_t m) { return (x + m - 1) / m * m; }; + out_accum_bytes = onnxruntime::flash::get_out_accum_size(num_splits, parameters.batch_size, parameters.num_heads, parameters.sequence_length, round_multiple(parameters.head_size, 32)); + } + softmax_lse_buffer = GetScratchBuffer(softmax_lse_bytes, context->GetComputeStream()); softmax_lse_accum_buffer = GetScratchBuffer(softmax_lse_accum_bytes, context->GetComputeStream()); out_accum_buffer = GetScratchBuffer(out_accum_bytes, context->GetComputeStream()); @@ -214,11 +252,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { #endif if (data.use_flash_attention_fast_decode && parameters.sequence_length == 1) { - // FlashAttentionDecoding Fast Path: + // FlashDecoding Fast Path: // - Uses Flash Attention's internal KV append logic, so total_seq_lens and padded_seq_lens are not needed. - // - Past_seq_lens is passed as seqlens_k to Flash Attention, which uses it to: - // 1. Determine where to append new K/V in the cache - // 2. Apply correct causal masking (attention only to positions [0, past_seq_len]) // - The input seqlens_k from ONNX graph is (total_len - 1), which equals past_seq_len when seq_len == 1. // - This optimization avoids launching GetSequenceLengths kernel for single-token decoding. 
data.past_seq_lens = const_cast(total_seq_lens_minus_one->Data()); @@ -239,16 +274,20 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { parameters.is_first_prompt, cuda_stream, device_prop.maxThreadsPerBlock)); + DUMP_TENSOR_INIT(); + DUMP_TENSOR("total_seq_lens", data.total_seq_lens, parameters.batch_size, 1); + DUMP_TENSOR("past_seq_lens", data.past_seq_lens, parameters.batch_size, 1); + DUMP_TENSOR("padded_seq_lens", data.padded_seq_lens, parameters.batch_size, 1); } - if (!use_flash_attention) { - // Fall back to memory efficient attention. #if USE_MEMORY_EFFICIENT_ATTENTION + if (!data.use_flash_attention) { + // Fall back to memory efficient attention. int sm = (device_prop.major * 10) + device_prop.minor; bool use_memory_efficient_attention = - !use_flash_attention && !disable_memory_efficient_attention_ && has_memory_efficient_attention(sm, std::is_same::value, std::is_same::value, parameters.head_size, parameters.head_size); + data.use_memory_efficient_attention = use_memory_efficient_attention; // KV buffer for head expansion (when num_heads != kv_num_heads) size_t kv_buffer_bytes = (use_memory_efficient_attention && (parameters.num_heads != parameters.kv_num_heads)) @@ -262,49 +301,30 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream()); fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream()); -#else - constexpr bool use_memory_efficient_attention = false; -#endif - - data.use_memory_efficient_attention = use_memory_efficient_attention; - data.use_flash_attention = false; data.k = reinterpret_cast(k_buffer.get()); data.v = reinterpret_cast(v_buffer.get()); data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); - data.disable_fused_kv = disable_fused_kv_; } +#endif + // ------------- // Centralized scratch buffer allocation 
using GQABufferRequirements // This ensures allocation logic stays in sync with kernel usage auto buffer_req = GQABufferRequirements::Compute( parameters, - use_flash_attention, + data.use_flash_attention, data.use_flash_attention_fast_decode, data.use_memory_efficient_attention); - if (buffer_req.unpacked_qkv_bytes > 0) { - unpacked_qkv_buffer = GetScratchBuffer(buffer_req.unpacked_qkv_bytes, context->GetComputeStream()); - data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); - } - if (buffer_req.rotary_buffer_bytes > 0) { - rotary_buffer = GetScratchBuffer(buffer_req.rotary_buffer_bytes, context->GetComputeStream()); - data.rotary_buffer = reinterpret_cast(rotary_buffer.get()); + if (buffer_req.qkv_buffer_bytes > 0) { + unpacked_qkv_buffer = GetScratchBuffer(buffer_req.qkv_buffer_bytes, context->GetComputeStream()); + data.qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get()); } - if (buffer_req.position_ids_bytes > 0) { - position_ids_buffer = GetScratchBuffer(buffer_req.position_ids_bytes, context->GetComputeStream()); - data.position_ids_buffer = reinterpret_cast(position_ids_buffer.get()); - } -#ifndef NDEBUG - // Track allocated sizes for validation - data.unpacked_qkv_buffer_size = buffer_req.unpacked_qkv_bytes; - data.rotary_buffer_size = buffer_req.rotary_buffer_bytes; - data.position_ids_buffer_size = buffer_req.position_ids_bytes; -#endif if (kernel_options_->AllowDebugInfo()) { AttentionKernelDebugInfo debug_info; - debug_info.use_flash_attention = use_flash_attention; + debug_info.use_flash_attention = data.use_flash_attention; debug_info.use_efficient_attention = data.use_memory_efficient_attention; debug_info.Print("GroupQueryAttention", @@ -313,12 +333,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { std::is_same::value); } - if (data.past_key == data.present_key) { - parameters.kv_share_buffer = true; - ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be 
the same tensor when kv_share_buffer is true"); + // Validate past_value pointer consistency (past_present_share_buffer was computed early after pointer setup) + if (parameters.past_present_share_buffer) { + ORT_ENFORCE(data.past_value == data.present_value, "past_value and present_value must be the same tensor when past_present_share_buffer is true"); } else { - parameters.kv_share_buffer = false; - ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when kv_share_buffer is false"); + ORT_ENFORCE(data.past_value != data.present_value, "past_value and present_value must be different tensors when past_present_share_buffer is false"); } data.output = reinterpret_cast(output->MutableData()); @@ -337,19 +356,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { ORT_RETURN_IF_ERROR(QkvToContext( device_prop, cublas, context->GetComputeStream(), parameters, data)); -#ifndef NDEBUG - // Validate buffer usage matches allocation exactly - ORT_ENFORCE(data.unpacked_qkv_max_used == data.unpacked_qkv_buffer_size, - "unpacked_qkv_buffer: used ", data.unpacked_qkv_max_used, - " bytes but allocated ", data.unpacked_qkv_buffer_size); - ORT_ENFORCE(data.rotary_max_used == data.rotary_buffer_size, - "rotary_buffer: used ", data.rotary_max_used, - " bytes but allocated ", data.rotary_buffer_size); - ORT_ENFORCE(data.position_ids_max_used == data.position_ids_buffer_size, - "position_ids_buffer: used ", data.position_ids_max_used, - " bytes but allocated ", data.position_ids_buffer_size); -#endif - return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 5bf26e8c6edac..2536da9fe1379 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -35,7 +35,6 @@ class GroupQueryAttention final : public CudaKernel { bool 
disable_flash_attention_; bool disable_memory_efficient_attention_; bool disable_flash_decode_; - bool disable_fused_kv_; static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) IAllocatorUniquePtr zeros_; diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index c8a1629f21bce..0b6da63b31af6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -27,11 +27,11 @@ limitations under the License. #include #include -#include // For getenv #include #include +#include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/utils/debug_macros.h" #include "contrib_ops/cuda/bert/add_bias_transpose.h" #include "contrib_ops/cuda/bert/attention_impl.h" @@ -40,14 +40,16 @@ limitations under the License. #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/flash_attention/flash_api.h" #include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/group_query_attention_qkv.cuh" #include "contrib_ops/cuda/bert/rotary_embedding_impl.h" #include "contrib_ops/cuda/bert/rotary_common.cuh" #include "contrib_ops/cuda/bert/transformer_common.h" +#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "core/providers/cuda/cuda_common.h" + #include "core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "contrib_ops/cuda/utils/dump_cuda_tensor.h" using namespace onnxruntime::cuda; @@ -59,6 +61,100 @@ namespace onnxruntime { namespace contrib { namespace cuda { +// ============================================================================ +// QKV Preprocessing Helpers +// 
============================================================================ + +// Internal helper to get Q, K, V pointers, handling packed input +// +// This function orchestrates the preparation of Q, K, and V tensors for attention kernels. +// It performs: +// 1. Handling packed vs. unpacked QKV inputs. +// 2. Managing KV cache updates (appending new tokens). +// 3. Ensuring synchronization between past and present KV caches when necessary. +// 4. Launching the UnpackRoPEQuantizeAppend kernel to unpack, apply RoPE, and update caches. +// 5. Returning strict Q, K, V pointers ready for the core attention operation. +template +Status PrepareQKV( + cudaStream_t stream, + const int max_threads_per_block, + const GroupQueryAttentionParameters& parameters, + GroupQueryAttentionData& data, + const T*& q, + const T*& k, + const T*& v) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + + using CudaT = typename ToCudaType::MappedType; + CudaT* q_out = data.qkv_buffer; + + if (!parameters.is_packed_qkv && !parameters.do_rotary) { + q_out = nullptr; + } + + CudaT* k_final_ptr = reinterpret_cast(data.present_key); + CudaT* v_final_ptr = reinterpret_cast(data.present_value); + int final_max_seqlen = parameters.seqlen_present_kv_cache; + bool final_is_bnsh = (parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + + if (!parameters.past_present_share_buffer) { + size_t kv_buffer_size = (size_t)batch_size * kv_num_heads * final_max_seqlen * head_size * sizeof(CudaT); + CUDA_CALL_THROW(cudaMemsetAsync(data.present_key, 0, kv_buffer_size, stream)); + CUDA_CALL_THROW(cudaMemsetAsync(data.present_value, 0, kv_buffer_size, stream)); + } + + if (!parameters.past_present_share_buffer && data.past_key != nullptr && parameters.seqlen_past_kv_cache > 0) { + bool is_bnsh = 
(parameters.past_kv_format == AttentionQkvFormat::Q_K_V_BNSH); + if (is_bnsh) { + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * head_size * sizeof(CudaT); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * head_size * sizeof(CudaT); + size_t width = src_pitch; + size_t height = (size_t)batch_size * kv_num_heads; + + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + } else { + size_t src_pitch = (size_t)parameters.seqlen_past_kv_cache * kv_num_heads * head_size * sizeof(CudaT); + size_t dst_pitch = (size_t)parameters.seqlen_present_kv_cache * kv_num_heads * head_size * sizeof(CudaT); + size_t width = src_pitch; + size_t height = (size_t)batch_size; + + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_key, dst_pitch, data.past_key, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + CUDA_CALL_THROW(cudaMemcpy2DAsync(data.present_value, dst_pitch, data.past_value, src_pitch, width, height, + cudaMemcpyDeviceToDevice, stream)); + } + } + + ORT_RETURN_IF_ERROR(LaunchUnpackRoPEAppendKV( + parameters.is_packed_qkv ? reinterpret_cast(data.query) : nullptr, + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.query), + parameters.is_packed_qkv ? nullptr : reinterpret_cast(data.key), + parameters.is_packed_qkv ? 
nullptr : reinterpret_cast(data.value), + q_out, k_final_ptr, v_final_ptr, + num_heads, kv_num_heads, head_size, sequence_length, batch_size, + final_max_seqlen, data.past_seq_lens, + reinterpret_cast(data.cos_cache), reinterpret_cast(data.sin_cache), + parameters.rotary_dim, data.position_ids, parameters.rotary_interleaved, + final_is_bnsh, + stream, max_threads_per_block)); + + if (q_out != nullptr) { + q = reinterpret_cast(q_out); + } else { + q = reinterpret_cast(data.query); + } + k = reinterpret_cast(k_final_ptr); + v = reinterpret_cast(v_final_ptr); + return Status::OK(); +} + ////////// Auxiliary Kernels for KV prep // Concat new to past in present. Supports past BSNH or past BNSH @@ -393,267 +489,6 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp return CUDA_CALL(cudaGetLastError()); } -// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache -// This eliminates 4 kernel launches: Unpack -> Rotate Q -> Rotate K -> Append K -> Append V -// Becomes: Single kernel that does all operations in one pass -// -// Bounds Safety: -// - cache_s = past_seq_len + s is guaranteed < max_seqlen by the caller (group_query_attention.cc) -// because present_sequence_length = max(past + new_seq_len) across batches, and the present -// buffer is allocated with seqlen_present_kv_cache >= total_seq_lens[b] for all b. -// - The kernel processes exactly batch_size * sequence_length * (Q+K+V hidden) elements, -// which matches the packed_qkv input size allocated by the model. 
-// -// RoPE Contiguity Requirement: -// - packed_qkv MUST be strictly contiguous with layout [B, S, (H_q + 2*H_kv) * D] -// - The half-split RoPE logic (RotaryDispatcher::apply) fetches pair elements at offset -// (h + rotary_dim/2) relative to the start of each head -// - If strided/non-contiguous inputs are ever supported, this pointer arithmetic must change -// -// Performance Optimization: -// Uses 3D grid layout to eliminate expensive integer divisions: -// - blockIdx.z = batch index (b) -// - blockIdx.y = sequence index (s) -// - blockIdx.x * blockDim.x + threadIdx.x = offset within QKV hidden dimension -// This removes 4 divisions (/, %) per thread that would otherwise be needed. -template -__global__ void UnpackQKVWithRoPEAndAppendKV( - const T* packed_qkv, // Input: packed QKV [B, S, (Q+K+V) hidden] - T* unpacked_q, // Output: rotated Q [B, S, Q_heads, H] (BSNH) - T* k_cache, // Output: K cache [B, N, MaxS, H] or [B, MaxS, N, H] - T* v_cache, // Output: V cache [B, N, MaxS, H] or [B, MaxS, N, H] - const int num_heads, - const int kv_num_heads, - const int head_size, - const int d, // QKV hidden stride = (num_heads + 2*kv_num_heads) * head_size - const int max_seqlen, // KV cache max sequence length - const int* past_seq_lens, - // RoPE params - const T* cos_cache, - const T* sin_cache, - const int rotary_dim, - const int64_t* position_ids, - const bool interleaved, - const bool is_cache_bnsh) { - // Vectorized load/store using float4 (16 bytes) - using LoadT = float4; - constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); - - // 3D grid layout eliminates integer division: - // - blockIdx.z = batch index (b) - obtained from grid dimension, no division needed - // - blockIdx.y = sequence index (s) - obtained from grid dimension, no division needed - // - linear thread index within (b, s) gives offset directly - const int b = blockIdx.z; - const int s = blockIdx.y; - const int offset_vec_idx = blockIdx.x * blockDim.x + threadIdx.x; // Vector index 
within d - const int offset = offset_vec_idx * elements_per_thread; // Element offset within d - - // Bounds check: offset must be within the QKV hidden dimension - if (offset >= d) return; - - const int q_hidden = num_heads * head_size; - const int k_hidden = kv_num_heads * head_size; - const int sequence_length = gridDim.y; // Get from grid dimension - - // Calculate linear index for packed_qkv load - const int64_t packed_idx = static_cast(b) * sequence_length * d + - static_cast(s) * d + offset; - - // Load vector from packed buffer - LoadT val_vec = reinterpret_cast(packed_qkv)[packed_idx / elements_per_thread]; - - // Common RoPE Calculations - const int past_seq_len = past_seq_lens[b]; - int pos_id = 0; - if (position_ids != nullptr) { - pos_id = static_cast(position_ids[b * sequence_length + s]); - } else { - pos_id = past_seq_len + s; - } - - // Determine Q, K, or V based on offset - if (offset < q_hidden) { - // Q: Apply RoPE and write to unpacked_q buffer (BSNH format) - const int q_head_idx = offset / head_size; - const int h = offset % head_size; - const int h_idx = h / elements_per_thread; - - if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { - // For half-split RoPE, pair values should be read relative to the START of the current Q head. - // Calculate offset to head start: (b, s, q_head_n, 0) in packed QKV. 
- const int64_t q_head_start_in_packed = static_cast(b) * sequence_length * d + - static_cast(s) * d + - static_cast(q_head_idx) * head_size; - RotaryDispatcher::apply(val_vec, - reinterpret_cast(cos_cache), - reinterpret_cast(sin_cache), - rotary_dim, h_idx, pos_id, interleaved, - reinterpret_cast(packed_qkv), - q_head_start_in_packed / elements_per_thread); - } - - const int64_t q_idx = static_cast(b) * sequence_length * num_heads * head_size + - static_cast(s) * num_heads * head_size + offset; - // Vector store to unpacked_q - reinterpret_cast(unpacked_q)[q_idx / elements_per_thread] = val_vec; - - } else if (offset < q_hidden + k_hidden) { - // K: Apply RoPE and write DIRECTLY to K cache - const int k_offset = offset - q_hidden; - const int n = k_offset / head_size; - const int h = k_offset % head_size; - const int h_idx = h / elements_per_thread; - - if (cos_cache != nullptr && rotary_dim > 0 && h < rotary_dim) { - // For half-split RoPE, pair values should be read relative to the START of the current K head. - // Calculate offset to head start: (b, s, k_head_n, 0) in packed QKV. 
- const int64_t k_head_start_in_packed = static_cast(b) * sequence_length * d + - static_cast(s) * d + - q_hidden + - static_cast(n) * head_size; - RotaryDispatcher::apply(val_vec, - reinterpret_cast(cos_cache), - reinterpret_cast(sin_cache), - rotary_dim, h_idx, pos_id, interleaved, - reinterpret_cast(packed_qkv), - k_head_start_in_packed / elements_per_thread); - } - - const int cache_s = past_seq_len + s; - int64_t cache_idx; - if (is_cache_bnsh) { - cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + - static_cast(n) * max_seqlen * head_size + - static_cast(cache_s) * head_size + h; - } else { // BSNH - cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + - static_cast(cache_s) * kv_num_heads * head_size + - static_cast(n) * head_size + h; - } - // Vector store to k_cache - reinterpret_cast(k_cache)[cache_idx / elements_per_thread] = val_vec; - - } else { - // V: Write DIRECTLY to V cache (no rotation) - const int v_offset = offset - q_hidden - k_hidden; - const int n = v_offset / head_size; - const int h = v_offset % head_size; - - const int cache_s = past_seq_len + s; - int64_t cache_idx; - if (is_cache_bnsh) { - cache_idx = static_cast(b) * kv_num_heads * max_seqlen * head_size + - static_cast(n) * max_seqlen * head_size + - static_cast(cache_s) * head_size + h; - } else { // BSNH - cache_idx = static_cast(b) * max_seqlen * kv_num_heads * head_size + - static_cast(cache_s) * kv_num_heads * head_size + - static_cast(n) * head_size + h; - } - // Vector store to v_cache - reinterpret_cast(v_cache)[cache_idx / elements_per_thread] = val_vec; - } -} - -// Launcher for fused UnpackQKV + RoPE + KV Append -template -Status LaunchUnpackQKVWithRoPEAndAppendKV( - const T* packed_qkv, - T* unpacked_q, - T* k_cache, - T* v_cache, - const int num_heads, - const int kv_num_heads, - const int head_size, - const int sequence_length, - const int batch_size, - const int max_seqlen, - const int* past_seq_lens, - const T* cos_cache, - const T* 
sin_cache, - const int rotary_dim, - const int64_t* position_ids, - const bool interleaved, - const bool is_cache_bnsh, - cudaStream_t stream, - const int max_threads_per_block) { - // Determine vectorization factor (float4 is 16 bytes) - constexpr int vector_bytes = sizeof(float4); - constexpr int element_bytes = sizeof(T); - constexpr int elements_per_vector = vector_bytes / element_bytes; - - // Validate head_size alignment - if (head_size % elements_per_vector != 0) { - // If strict alignment is not met (unlikely given GQA constraints), we should fall back or fail. - // Typically GQA enforces head_size % 8 == 0. - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be divisible by ", elements_per_vector, " for vectorized GQA kernel."); - } - - // Validate grid dimensions - CUDA limits gridDim.y to 65535 - if (sequence_length > 65535) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Sequence length ", sequence_length, - " exceeds CUDA grid dimension limit (65535) for fused UnpackQKV kernel."); - } - -#ifndef NDEBUG - // Debug-mode alignment assertions for vectorized memory access - assert(reinterpret_cast(packed_qkv) % 16 == 0 && "packed_qkv must be 16-byte aligned"); - assert(reinterpret_cast(unpacked_q) % 16 == 0 && "unpacked_q must be 16-byte aligned"); - assert(reinterpret_cast(k_cache) % 16 == 0 && "k_cache must be 16-byte aligned"); - assert(reinterpret_cast(v_cache) % 16 == 0 && "v_cache must be 16-byte aligned"); - if (cos_cache != nullptr) { - assert(reinterpret_cast(cos_cache) % 16 == 0 && "cos_cache must be 16-byte aligned"); - assert(reinterpret_cast(sin_cache) % 16 == 0 && "sin_cache must be 16-byte aligned"); - } -#endif - - // QKV hidden dimension stride - const int d = (num_heads + 2 * kv_num_heads) * head_size; - const int d_vectors = d / elements_per_vector; // Number of vectors per (b, s) - - // 3D grid layout for eliminating integer divisions in kernel: - // grid.x = number of blocks needed to cover d_vectors with threads_per_block 
threads - // grid.y = sequence_length - // grid.z = batch_size - const int threads_per_block = std::min(max_threads_per_block, d_vectors); - const int blocks_x = (d_vectors + threads_per_block - 1) / threads_per_block; - const dim3 grid(blocks_x, sequence_length, batch_size); - const dim3 block(threads_per_block); - - UnpackQKVWithRoPEAndAppendKV<<>>( - packed_qkv, - unpacked_q, - k_cache, - v_cache, - num_heads, - kv_num_heads, - head_size, - d, - max_seqlen, - past_seq_lens, - cos_cache, - sin_cache, - rotary_dim, - position_ids, - interleaved, - is_cache_bnsh); - - return CUDA_CALL(cudaGetLastError()); -} - -// Explicit template instantiations -template Status LaunchUnpackQKVWithRoPEAndAppendKV( - const half*, half*, half*, half*, - int, int, int, int, int, int, const int*, - const half*, const half*, int, const int64_t*, bool, bool, - cudaStream_t, int); - -template Status LaunchUnpackQKVWithRoPEAndAppendKV( - const BFloat16*, BFloat16*, BFloat16*, BFloat16*, - int, int, int, int, int, int, const int*, - const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, - cudaStream_t, int); - // ============================================================================ // GetSequenceLengths Kernel // ============================================================================ @@ -697,6 +532,7 @@ __global__ void GetSequenceLengths(const int* total_seq_lens_minus_one, padded_seq_lens[i] = sequence_length; } else { past_seq_lens[i] = total_len - sequence_length; + padded_seq_lens[i] = 0; } } } @@ -716,20 +552,32 @@ Status LaunchGetSequenceLengths( return CUDA_CALL(cudaGetLastError()); } -////////// Kernels (supports right padding but not left padding) +// Trace function for debugging +#define ORT_GQA_TRACE(func_name) \ + DEBUG_PRINTF("[GQA %s] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, past_present_share_buffer: %d", \ + func_name, \ + static_cast(parameters.is_packed_qkv), \ + static_cast(parameters.is_first_prompt), \ + 
static_cast(parameters.is_subsequent_prompt), \ + static_cast(parameters.past_present_share_buffer)); +////////// Kernels (supports right padding but not left padding) +// Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. +// Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. #if USE_FLASH_ATTENTION // Use flash attention for all workloads (rotary, kv append, attention, etc.). No extra kernel is used in this path. // Currently, only decoding or subsequent prompt can use this path. First prompt will not use this path. template -Status FlashAttentionDecoding( +Status FlashDecoding( const cudaDeviceProp& device_prop, cudaStream_t stream, GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, float scale) { - assert(!parameters.is_first_prompt && parameters.kv_share_buffer); + assert(!parameters.is_first_prompt && parameters.past_present_share_buffer); + + ORT_GQA_TRACE("FlashDecoding"); const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; @@ -757,8 +605,8 @@ Status FlashAttentionDecoding( void* seqlens_k = reinterpret_cast(data.past_seq_lens); - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* present_key = data.present_key; + void* present_value = data.present_value; void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); void* head_sink = reinterpret_cast(const_cast(data.head_sink)); @@ -773,7 +621,8 @@ Status FlashAttentionDecoding( parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), - 
parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv)); + parameters.local_window_size - 1, parameters.rotary_interleaved, parameters.is_packed_qkv, + 0, 1)); return Status::OK(); } @@ -799,242 +648,21 @@ Status FlashAttention( bool is_causal = parameters.is_unidirectional; bool is_bf16 = std::is_same::value; - void* query = reinterpret_cast(const_cast(data.query)); - void* key; - void* value; - - if (!parameters.is_packed_qkv) { - key = reinterpret_cast(const_cast(data.key)); - value = reinterpret_cast(const_cast(data.value)); - } else { - const size_t key_offset = static_cast(num_heads * head_size); - const size_t value_offset = static_cast(kv_num_heads * head_size); - key = reinterpret_cast(query) + key_offset; - value = reinterpret_cast(key) + value_offset; - } - -#if DUMP_TENSOR_LEVEL > 0 - printf("[GQA FlashAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", - static_cast(parameters.is_packed_qkv), - static_cast(parameters.is_first_prompt), - static_cast(parameters.is_subsequent_prompt), - static_cast(parameters.kv_share_buffer)); -#endif DUMP_TENSOR_INIT(); - // Track whether we keep packed QKV for FA kernels - bool use_packed_for_fa = parameters.is_packed_qkv; - - // Track if we used the fully fused path (packed + share_buffer + rotary) - bool used_fused_packed_path = false; - - // ========================================================================= - // Handle Packed QKV Input - // ========================================================================= - if (parameters.is_packed_qkv) { - T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); - if (unpacked_buffer != nullptr) { - T* unpacked_q = unpacked_buffer; - - // Check if we can use the fully fused path - if (parameters.kv_share_buffer && parameters.do_rotary && !data.disable_fused_kv) { - // FULLY FUSED PATH: Unpack + RoPE Q + RoPE K + Append KV in single kernel - // This eliminates 4 kernel 
launches! - ORT_RETURN_IF_ERROR(LaunchUnpackQKVWithRoPEAndAppendKV( - reinterpret_cast(data.query), // packed QKV - unpacked_q, // Q output buffer (rotated) - data.present_key, // K cache (direct write) - data.present_value, // V cache (direct write) - num_heads, - kv_num_heads, - head_size, - sequence_length, - batch_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - parameters.rotary_dim, - data.position_ids, - parameters.rotary_interleaved, - !past_bsnh, // is_cache_bnsh - stream, - max_threads_per_block)); - - // Update query to point to rotated Q - query = unpacked_q; - use_packed_for_fa = false; - used_fused_packed_path = true; - - // Track buffer usage: Only Q is stored in unpacked_qkv_buffer (fused path writes K/V to cache) - size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, q_bytes); - - // K and V are already in cache - no need to set key/value pointers + const T* q_prep = nullptr; + const T* k_prep = nullptr; + const T* v_prep = nullptr; + ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); - } else { - // Standard path: Unpack first, then process K/V separately - size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* unpacked_k = unpacked_buffer + q_size; + void* query = const_cast(q_prep); + (void)k_prep; // Key/value are now processed by PrepareQKV + (void)v_prep; - size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - T* unpacked_v = unpacked_k + k_size; + bool use_packed_for_fa = false; - // If we need Q rotation, we MUST unpack Q as well. - T* q_dst = parameters.do_rotary ? 
unpacked_q : nullptr; - - // Always unpack to BSNH as LaunchConcatNewToPastKV expects contiguous BSNH input - ORT_RETURN_IF_ERROR((LaunchUnpackQKV(reinterpret_cast(data.query), q_dst, unpacked_k, unpacked_v, num_heads, kv_num_heads, head_size, sequence_length, batch_size, stream, max_threads_per_block))); - - // Update key/value to point to unpacked buffers - key = unpacked_k; - value = unpacked_v; - - if (parameters.do_rotary) { - query = unpacked_q; - use_packed_for_fa = false; - } - - // Track buffer usage: Q+K+V unpacked - size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } - } - } - // ========================================================================= - // Handle Unpacked Q, K, V Input (with optional RoPE) - // ========================================================================= - else { - if (parameters.do_rotary) { - // For unpacked input, we need to rotate Q and K. - // The rotated Q and K will be stored in unpacked_qkv_buffer with layout [Q (B*S*H*D), K (B*S*H_kv*D)]. - T* unpacked_buffer = reinterpret_cast(data.unpacked_qkv_buffer); - if (unpacked_buffer != nullptr) { - query = unpacked_buffer; - // Do not update key here for Unpacked path. - // key must remain pointing to data.key (Input) for Explicit K Rotation (k_src). - // k_dst will be calculated from unpacked_buffer explicitly. - } - } - } - - const int64_t* position_ids = data.position_ids; - - // Explicit Q Rotation (skip if fused path already applied RoPE) - if (parameters.do_rotary && !used_fused_packed_path) { - // Rotate Q - // Q ptr is already set to the destination buffer (unpacked_buffer) above. - // Input for Rotation: - // If packed: we unpacked into `query` buffer. So Input==Output (In-place). - // If unpacked: we set `query = unpacked_buffer`. But Input is `data.query`. - const T* q_input_for_rope = parameters.is_packed_qkv ? 
reinterpret_cast(query) : reinterpret_cast(data.query); - T* q_output_for_rope = reinterpret_cast(query); // Destination - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - q_output_for_rope, - q_input_for_rope, - nullptr, // position_ids unused for format 2/3 - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - num_heads, - head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - 2, // position_ids_format = 2 (Implicit: past_seq_lens[b] + s) - parameters.rotary_interleaved, - max_threads_per_block, - false // is_input_bnsh_format (Q is BSNH) - )); - DUMP_TENSOR("Rotated Q", q_output_for_rope, batch_size, sequence_length, num_heads, head_size); - - // Rotate K will be done later in fused kernel. - } - - // Skip KV append if we used the fully fused path (KV already in cache) - if (!used_fused_packed_path) { - if (parameters.kv_share_buffer && !parameters.is_first_prompt) { - constexpr bool is_new_kv_bnsh_format = false; - if (parameters.do_rotary) { - // Explicit K Rotation (replacing internal RoPE in fused kernel) - size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* k_dst = reinterpret_cast(data.unpacked_qkv_buffer) + q_elements; - const T* k_src = reinterpret_cast(key); - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - k_dst, - k_src, - position_ids, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - kv_num_heads, - head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - position_ids != nullptr ? 
1 : 2, - parameters.rotary_interleaved, - max_threads_per_block, - false)); - - if (!data.disable_fused_kv) { - // Use fused kernel for K (rotated) + V append - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( - batch_size, - kv_num_heads, - head_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.total_seq_lens, - sequence_length, - k_dst, - reinterpret_cast(data.value), - data.present_key, - data.present_value, - !past_bsnh, - is_new_kv_bnsh_format, - stream, - max_threads_per_block)); - } else { - // Unfused Fallback: LaunchConcatKVInPlace - // We must pass the ROTATED K (k_dst) to it. - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - - // Track buffer usage: Q + K rotated in unpacked_qkv_buffer - size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - size_t total_bytes = (q_elements + k_elements) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } else { - // No RoPE - use original kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - } else { - // ORT MUST perform the append (using unpacked data for packed case) - bool skip_new_append = false; - // FUSED ROPE: Pass RoPE params to ConcatKV (applies RoPE to K as it is appended) - // IMPORTANT: For Fused RoPE with unpacked input, we must pass data.key (the original input), - // not the scratch buffer 'key' which is empty since explicit rotation was skipped. - const void* key_for_concat = parameters.is_packed_qkv ? 
key : data.key; - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key_for_concat, value, stream, max_threads_per_block, skip_new_append, - data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); - } - } - - DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); - DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); - DUMP_TENSOR("Present Key", data.present_key, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - DUMP_TENSOR("Present Value", data.present_value, batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* present_key = data.present_key; + void* present_value = data.present_value; // Disable internal RoPE in Flash Attention (pass nullptr) void* cos_cache = nullptr; @@ -1047,7 +675,6 @@ Status FlashAttention( void* kernel_new_v = nullptr; // Use padded seq lens for first prompt since mha_fwd_kvcache assumes uniform seqlen_q. - // The causal mask offset (seqlen_k - seqlen_q) becomes negative when seqlen_k < seqlen_q, causing incorrect masking. int* seq_lens = parameters.is_first_prompt ? 
data.padded_seq_lens : data.total_seq_lens; ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( @@ -1057,12 +684,16 @@ Status FlashAttention( /*cache_batch_idx*/ nullptr, /*leftpad_k*/ nullptr, head_sink, /*block_table*/ nullptr, batch_size, num_heads, kv_num_heads, head_size, sequence_length, parameters.seqlen_present_kv_cache, kv_sequence_length, - parameters.rotary_dim, scale, parameters.softcap, is_causal, is_bf16, + 0, // rotary_dim = 0 as it is already rotated + scale, parameters.softcap, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), reinterpret_cast(data.out_accum), parameters.local_window_size - 1, parameters.rotary_interleaved, use_packed_for_fa, 0, 1)); + DUMP_TENSOR("Total Seq Lens", data.total_seq_lens, batch_size, 1); + DUMP_TENSOR("Past Seq Lens", data.past_seq_lens, batch_size, 1); + return Status::OK(); } #endif @@ -1084,164 +715,17 @@ Status EfficientAttention( const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; -#if DUMP_TENSOR_LEVEL > 0 - printf("[GQA EfficientAttention] is_packed_qkv: %d, is_first_prompt: %d, is_subsequent_prompt: %d, kv_share_buffer: %d\n", - static_cast(parameters.is_packed_qkv), - static_cast(parameters.is_first_prompt), - static_cast(parameters.is_subsequent_prompt), - static_cast(parameters.kv_share_buffer)); -#endif - - const void* query; - const void* key; - const void* value; - - if (!parameters.is_packed_qkv) { - query = reinterpret_cast(data.query); - key = reinterpret_cast(data.key); - value = reinterpret_cast(data.value); - } else { - size_t q_size = static_cast(batch_size) * sequence_length * num_heads * head_size; - size_t k_size = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - auto q = reinterpret_cast(data.unpacked_qkv_buffer); - auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size); - auto v = reinterpret_cast(data.unpacked_qkv_buffer + 
q_size + k_size); - - Status status = LaunchUnpackQKV( - reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads, - head_size, sequence_length, batch_size, stream, max_threads_per_block); - if (status != Status::OK()) { - return status; - } - - query = reinterpret_cast(q); - key = reinterpret_cast(k); - value = reinterpret_cast(v); - - // Track buffer usage: Q+K+V unpacked - size_t total_bytes = (q_size + 2 * k_size) * sizeof(T); - UpdateUnpackedQkvMaxUsed(data, total_bytes); - } - - const int64_t* position_ids = data.position_ids; - if (parameters.do_rotary) { - auto q_buffer = reinterpret_cast(data.rotary_buffer); - - // Launch rotary embedding kernel for Q - if (position_ids != nullptr) { - // User provided explicit position_ids - Use Format 1 - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, q_buffer, reinterpret_cast(query), - position_ids, nullptr /*past_seq_lens not used in format 1*/, - data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, - 1, // Format 1: Explicit position_ids - parameters.rotary_interleaved, - max_threads_per_block, - false)); - } else { - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, q_buffer, reinterpret_cast(query), - nullptr, data.past_seq_lens, - data.cos_cache, data.sin_cache, - parameters.batch_size, parameters.sequence_length, - parameters.num_heads, parameters.head_size, - parameters.rotary_dim, parameters.max_sequence_length, - 2, // Format 2: Implicit (past_seq_lens[b] + s) - parameters.rotary_interleaved, - max_threads_per_block, - false)); - } - query = reinterpret_cast(q_buffer); - - // For kv_share_buffer path, we use Fused RoPE in LaunchConcatKVInPlaceWithRoPE. - // For non-share-buffer path, we use Fused RoPE in LaunchConcatNewToPastKVHelper. - // No explicit K rotation needed here - handled by fused kernels. 
- - // key remains pointing to original source for use in fused kernel below - - // Track rotary buffer usage: Q rotated (K rotation is fused in KV append) - size_t q_bytes = static_cast(batch_size) * sequence_length * num_heads * head_size * sizeof(T); - size_t k_bytes = static_cast(batch_size) * sequence_length * kv_num_heads * head_size * sizeof(T); - // Note: rotary_buffer layout is [Q_rotated, K_rotated] - no position_ids here - UpdateRotaryMaxUsed(data, q_bytes + k_bytes); + ORT_GQA_TRACE("EfficientAttention"); - // Track position_ids_buffer usage - size_t pos_ids_bytes = static_cast(batch_size) * sequence_length * sizeof(int64_t); - UpdatePositionIdsMaxUsed(data, pos_ids_bytes); - } + const T* q_prep = nullptr; + const T* k_prep = nullptr; + const T* v_prep = nullptr; + ORT_RETURN_IF_ERROR(PrepareQKV(stream, max_threads_per_block, parameters, data, q_prep, k_prep, v_prep)); - if (parameters.kv_share_buffer) { - // Concatenate new kv in place - constexpr bool is_new_kv_bnsh_format = false; - - if (parameters.do_rotary) { - // Explicit K Rotation - size_t q_elements = static_cast(batch_size) * sequence_length * num_heads * head_size; - T* k_dst = reinterpret_cast(data.rotary_buffer) + q_elements; - const T* k_src = reinterpret_cast(key); - - ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel( - stream, - k_dst, - k_src, - position_ids, - data.past_seq_lens, - data.cos_cache, - data.sin_cache, - batch_size, - sequence_length, - parameters.kv_num_heads, - parameters.head_size, - parameters.rotary_dim, - parameters.max_sequence_length, - position_ids != nullptr ? 
1 : 2, - parameters.rotary_interleaved, - max_threads_per_block, - false)); - - if (!data.disable_fused_kv) { - // Use truly fused kernel for K (already rotated) + V append in single kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlaceFused( - batch_size, - parameters.kv_num_heads, - parameters.head_size, - parameters.seqlen_present_kv_cache, - data.past_seq_lens, - data.total_seq_lens, - parameters.sequence_length, - k_dst, - reinterpret_cast(value), - data.present_key, - data.present_value, - past_kv_format != AttentionQkvFormat::Q_K_V_BSNH, // is_past_kv_bnsh_format - is_new_kv_bnsh_format, - stream, - max_threads_per_block)); - } else { - // Unfused Fallback - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, k_dst, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } + const void* query = reinterpret_cast(q_prep); + const void* key = reinterpret_cast(k_prep); + const void* value = reinterpret_cast(v_prep); - // Track rotary buffer usage: Q + K rotated (no position_ids in rotary_buffer) - size_t k_elements = static_cast(batch_size) * sequence_length * kv_num_heads * head_size; - UpdateRotaryMaxUsed(data, (q_elements + k_elements) * sizeof(T)); - } else { - // No RoPE - use original kernel - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace( - parameters, data, key, value, is_new_kv_bnsh_format, stream, max_threads_per_block)); - } - } else { - // Copy past and concat new KV to present buffer - // FUSED ROPE: Pass RoPE params to ConcatKV - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKVHelper(parameters, data, key, value, stream, max_threads_per_block, false, - data.cos_cache, data.sin_cache, parameters.rotary_dim, nullptr, parameters.rotary_interleaved)); - } - - // Ungroup if grouped, otherwise use present kv directly const bool is_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; if (num_heads == kv_num_heads) { // Use present kv directly if not grouped @@ -1309,7 +793,7 @@ Status QkvToContext( #if USE_FLASH_ATTENTION if 
(data.use_flash_attention_fast_decode) { - return FlashAttentionDecoding(device_prop, stream, parameters, data, scale); + return FlashDecoding(device_prop, stream, parameters, data, scale); } if (data.use_flash_attention) { @@ -1327,6 +811,7 @@ Status QkvToContext( } template struct GroupQueryAttentionData; +template struct GroupQueryAttentionData; template Status QkvToContext( const cudaDeviceProp& device_prop, @@ -1335,24 +820,15 @@ template Status QkvToContext( contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); - -template struct GroupQueryAttentionData; - template Status QkvToContext( const cudaDeviceProp& device_prop, cublasHandle_t& cublas, Stream* ort_stream, - GroupQueryAttentionParameters& parameters, + contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data); -template Status LaunchUnpackQKV( - const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, - const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, - cudaStream_t stream, const int max_threads_per_block); +template Status LaunchUnpackQKV(const half* packed_qkv, half* unpacked_q, half* unpacked_k, half* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int max_threads_per_block); +template Status LaunchUnpackQKV(const BFloat16* packed_qkv, BFloat16* unpacked_q, BFloat16* unpacked_k, BFloat16* unpacked_v, const int num_heads, const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size, cudaStream_t stream, const int 
max_threads_per_block); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index c42fe53e4b625..4ad71c5003e0e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -39,12 +39,9 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp // auto req = GQABufferRequirements::Compute(params, use_flash, fast_decode, use_mea, disable_fused); // unpacked_qkv_buffer = GetScratchBuffer(req.unpacked_qkv_bytes, ...); // rotary_buffer = GetScratchBuffer(req.rotary_buffer_bytes, ...); -// position_ids_buffer = GetScratchBuffer(req.position_ids_bytes, ...); // ============================================================================ struct GQABufferRequirements { - size_t unpacked_qkv_bytes = 0; - size_t rotary_buffer_bytes = 0; - size_t position_ids_bytes = 0; + size_t qkv_buffer_bytes = 0; template static GQABufferRequirements Compute( @@ -53,6 +50,9 @@ struct GQABufferRequirements { bool use_flash_attention_fast_decode, bool use_memory_efficient_attention) { GQABufferRequirements req; + if (use_flash_attention_fast_decode) { + return req; // All zeros - no scratch buffers needed + } const size_t elem_size = sizeof(T); const size_t batch_size = static_cast(params.batch_size); @@ -61,49 +61,36 @@ struct GQABufferRequirements { const size_t kv_num_heads = static_cast(params.kv_num_heads); const size_t head_size = static_cast(params.head_size); - // Fast decode path: Flash Attention handles everything internally - if (use_flash_attention_fast_decode) { - return req; // All zeros - no scratch buffers needed - } - - // Q, K, V element counts + // Base requirements for all paths const size_t q_elements = batch_size * seq_len * num_heads * head_size; const size_t k_elements = batch_size * seq_len * kv_num_heads * head_size; const 
size_t v_elements = k_elements; if (use_flash_attention) { // Flash Attention path: - // - unpacked_qkv_buffer is used for: - // 1. Unpacking packed QKV input - // 2. Storing rotated Q (and K for non-fused path) - // - rotary_buffer is NOT used (rotations go to unpacked_qkv_buffer) - // - position_ids_buffer is NOT used (flash attention uses implicit position IDs) - - if (params.is_packed_qkv) { - // Need full Q+K+V for unpacking - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); - } else if (params.do_rotary) { - // Unpacked input with RoPE: need Q+K for rotation output - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements); + // qkv_buffer is used for: + // 1. Unpacking packed Q (and K/V if needed) + // 2. Storing rotated Q + // + // Logic: + // - we generally only need Q buffer (for rotary Q) if we can write K/V directly to cache/output. + + if (params.do_rotary || params.is_packed_qkv) { + // Just Q buffer needed for rotation/unpacking. + // K and V are written directly to present_key/value (unpacked/rotated/quantized/appended). + req.qkv_buffer_bytes = elem_size * q_elements; } - // Note: unpacked + no-RoPE case does NOT need unpacked_qkv_buffer - } else if (use_memory_efficient_attention) { // Memory Efficient Attention path: - // - unpacked_qkv_buffer: for unpacking packed QKV - // - rotary_buffer: for Q and K rotation output (separate from unpack buffer) - // - position_ids_buffer: for explicit position IDs if needed + // - qkv_buffer: for unpacking packed QKV or Q rotation + // MEA path usually needs Q, and also K, V if they need unpacking. + // Current MEA implementation can handle separate K/V, but if packed, we unpack all. 
if (params.is_packed_qkv) { - req.unpacked_qkv_bytes = elem_size * (q_elements + k_elements + v_elements); - } - - if (params.do_rotary) { + req.qkv_buffer_bytes = elem_size * (q_elements + k_elements + v_elements); + } else if (params.do_rotary) { // Q rotation + K rotation - // Note: K uses kv_num_heads which may be less than num_heads - req.rotary_buffer_bytes = elem_size * (q_elements + k_elements); - // Position IDs space (always allocated for MEA + RoPE path) - req.position_ids_bytes = sizeof(int64_t) * batch_size * seq_len; + req.qkv_buffer_bytes = elem_size * (q_elements + k_elements); } } @@ -111,47 +98,6 @@ struct GQABufferRequirements { } }; -// ============================================================================ -// Debug helper for tracking buffer usage -// ============================================================================ -// Call these after buffer access to record the maximum offset used. -// In release builds, these are no-ops. -// -// Example: -// T* unpacked_q = data.unpacked_qkv_buffer; -// // ... kernel writes to unpacked_q[0..Q_size-1] ... 
-// UpdateUnpackedQkvMaxUsed(data, Q_size * sizeof(T)); -// ============================================================================ -#ifndef NDEBUG -template -inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.unpacked_qkv_max_used) { - data.unpacked_qkv_max_used = bytes_used; - } -} - -template -inline void UpdateRotaryMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.rotary_max_used) { - data.rotary_max_used = bytes_used; - } -} - -template -inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData& data, size_t bytes_used) { - if (bytes_used > data.position_ids_max_used) { - data.position_ids_max_used = bytes_used; - } -} -#else -template -inline void UpdateUnpackedQkvMaxUsed(GroupQueryAttentionData&, size_t) {} -template -inline void UpdateRotaryMaxUsed(GroupQueryAttentionData&, size_t) {} -template -inline void UpdatePositionIdsMaxUsed(GroupQueryAttentionData&, size_t) {} -#endif - Status LaunchGetSequenceLengths( const int* total_seq_lens_minus_one, int* past_seq_lens, @@ -163,6 +109,16 @@ Status LaunchGetSequenceLengths( cudaStream_t stream, const int max_threads_per_block); +template +Status LaunchUnpackRoPEAppendKV( + const T* packed_qkv, const T* query, const T* key, const T* value, + T* unpacked_q, T* k_cache, T* v_cache, + const int num_heads, const int kv_num_heads, const int head_size, + const int sequence_length, const int batch_size, const int max_seqlen, + const int* past_seq_lens, const T* cos_cache, const T* sin_cache, + const int rotary_dim, const int64_t* position_ids, const bool interleaved, + const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh new file mode 100644 index 0000000000000..ddf24aff27442 
--- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_qkv.cuh @@ -0,0 +1,248 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include + +#include "contrib_ops/cuda/bert/group_query_attention_impl.h" +#include "contrib_ops/cpu/bert/attention_common.h" +#include "contrib_ops/cuda/bert/rotary_common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Fused kernel: Unpack QKV + Apply RoPE to Q and K + Append K/V directly to cache +// +// OPTIMIZATION: This version uses Shared Memory to store the current head being processed. +// Shared memory allows RoPE dispatcher to access paired elements in non-interleaved mode +// (element i pairs with i ± rotary_dim/2) without global memory gathers. +// +// Alignment Note: This kernel assumes that base pointers (packed_qkv, query, etc.) +// are 16-byte aligned and that head_size is a multiple of elements_per_thread. 
+// +// Grid Layout: +// blockIdx.x: sequence index (s) -> Max 2^31-1 (Supports very long context) +// blockIdx.y: head index (head_idx) -> Max 65535 +// blockIdx.z: batch index (b) -> Max 65535 +template +__global__ void UnpackRoPEAppend( + const T* packed_qkv, + const T* query, + const T* key, + const T* value, + T* unpacked_q, + T* k_cache, + T* v_cache, + const int num_heads, + const int kv_num_heads, + const int head_size, + const int d, // packed QKV hidden stride = (num_heads + 2*kv_num_heads) * head_size + const int max_seqlen, // KV cache max sequence length + const int* past_seq_lens, + const T* cos_cache, + const T* sin_cache, + const int rotary_dim, + const int64_t* position_ids, + const bool interleaved, + const bool is_cache_bnsh) { + using LoadT = float4; + constexpr int elements_per_thread = sizeof(LoadT) / sizeof(T); + + const int s = blockIdx.x; + const int head_idx = blockIdx.y; + const int b = blockIdx.z; + const int tid = threadIdx.x; + const int h = tid * elements_per_thread; + + // Guard work with 'valid' instead of early return to ensure all threads reach __syncthreads() + const bool valid = (h < head_size); + + const int q_hidden = num_heads * head_size; + const int k_hidden = kv_num_heads * head_size; + const int sequence_length = gridDim.x; + + __shared__ T shared_head[MAX_HEAD_SIZE]; + + // Determine Head Type and Offset within hidden dimension + enum HeadType { QUERY, + KEY, + VALUE }; + HeadType head_type; + int n; // Index within its specific type + int offset_in_hidden; + + if (head_idx < num_heads) { + head_type = QUERY; + n = head_idx; + offset_in_hidden = n * head_size; + } else if (head_idx < num_heads + kv_num_heads) { + head_type = KEY; + n = head_idx - num_heads; + offset_in_hidden = q_hidden + n * head_size; + } else { + head_type = VALUE; + n = head_idx - (num_heads + kv_num_heads); + offset_in_hidden = q_hidden + k_hidden + n * head_size; + } + + // 1. 
Load data into Registers + T vals[elements_per_thread]; + if (valid) { + if (packed_qkv != nullptr) { + const int64_t packed_idx = static_cast(b) * sequence_length * d + + static_cast(s) * d + + static_cast(offset_in_hidden) + h; + *reinterpret_cast(vals) = reinterpret_cast(packed_qkv)[packed_idx / elements_per_thread]; + } else { + if (head_type == QUERY) { + const int64_t q_idx = static_cast(b) * sequence_length * q_hidden + + static_cast(s) * q_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(query)[q_idx / elements_per_thread]; + } else if (head_type == KEY) { + const int64_t k_idx = static_cast(b) * sequence_length * k_hidden + + static_cast(s) * k_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(key)[k_idx / elements_per_thread]; + } else { + const int64_t v_idx = static_cast(b) * sequence_length * k_hidden + + static_cast(s) * k_hidden + + static_cast(n) * head_size + h; + *reinterpret_cast(vals) = reinterpret_cast(value)[v_idx / elements_per_thread]; + } + } + } + + // 2. Process RoPE + // Optimization: Only use shared memory for non-interleaved mode + const bool is_qk = (head_type == QUERY || head_type == KEY); + if (valid && rotary_dim > 0 && is_qk && !interleaved) { + T* shared_ptr = &shared_head[h]; + *reinterpret_cast(shared_ptr) = *reinterpret_cast(vals); + } + + // CRITICAL: Barrier must be outside the 'if(valid)' and 'if(is_qk)' blocks + // to ensure every thread in the block participates. + __syncthreads(); + + if (valid && rotary_dim > 0 && is_qk) { + const int past_seq_len = past_seq_lens[b]; + const int64_t pos_base = static_cast(b) * sequence_length; + int pos_id = (position_ids != nullptr) ? 
static_cast(position_ids[pos_base + s]) : (past_seq_len + s); + const int h_idx = h / elements_per_thread; + + onnxruntime::contrib::cuda::RotaryDispatcher::apply( + *reinterpret_cast(vals), + reinterpret_cast(cos_cache), + reinterpret_cast(sin_cache), + rotary_dim, h_idx, pos_id, interleaved, + reinterpret_cast(shared_head), + 0); + } + + // 3. Store results back to Global Memory + if (valid) { + if (head_type == QUERY) { + if (unpacked_q != nullptr) { + const int64_t q_out_idx = static_cast(b) * sequence_length * q_hidden + + static_cast(s) * q_hidden + + static_cast(n) * head_size + h; + reinterpret_cast(unpacked_q)[q_out_idx / elements_per_thread] = *reinterpret_cast(vals); + } + } else { + const int cache_s = past_seq_lens[b] + s; + if (cache_s < max_seqlen) { + T* cache_ptr = (head_type == KEY) ? k_cache : v_cache; + if (cache_ptr != nullptr) { + int64_t cache_idx = is_cache_bnsh ? (static_cast(b) * kv_num_heads * max_seqlen * head_size + static_cast(n) * max_seqlen * head_size + static_cast(cache_s) * head_size + h) : (static_cast(b) * max_seqlen * kv_num_heads * head_size + static_cast(cache_s) * kv_num_heads * head_size + static_cast(n) * head_size + h); + reinterpret_cast(cache_ptr)[cache_idx / elements_per_thread] = *reinterpret_cast(vals); + } + } + } + } +} + +template +Status LaunchUnpackRoPEAppendKV( + const T* packed_qkv, const T* query, const T* key, const T* value, + T* unpacked_q, T* k_cache, T* v_cache, + const int num_heads, const int kv_num_heads, const int head_size, + const int sequence_length, const int batch_size, const int max_seqlen, + const int* past_seq_lens, const T* cos_cache, const T* sin_cache, + const int rotary_dim, const int64_t* position_ids, const bool interleaved, + const bool is_cache_bnsh, cudaStream_t stream, const int max_threads_per_block) { + constexpr int elements_per_vector = sizeof(float4) / sizeof(T); + + if (head_size % elements_per_vector != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size must be 
divisible by vector size (16 bytes)."); + } + + // rotary_dim <= head_size check to prevent out-of-bounds in shared memory + if (rotary_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "rotary_dim (", rotary_dim, ") cannot exceed head_size (", head_size, ")."); + } + + if (!interleaved && rotary_dim % 2 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Non-interleaved RoPE requires even rotary_dim."); + } + + const int total_heads = num_heads + 2 * kv_num_heads; + const int d = total_heads * head_size; + + const int threads_per_block = (head_size + elements_per_vector - 1) / elements_per_vector; + if (threads_per_block > max_threads_per_block) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size too large for current block configuration."); + } + + if (total_heads > 65535) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Total heads (", total_heads, ") exceeds CUDA grid limit (65535)."); + } + if (batch_size > 65535) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "batch_size (", batch_size, ") exceeds CUDA grid limit (65535)."); + } + + const dim3 grid(sequence_length, total_heads, batch_size); + const dim3 block(threads_per_block); + + // Dynamic dispatch for MAX_HEAD_SIZE templates to improve occupancy for common LLM head sizes + if (head_size <= 64) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else if (head_size <= 128) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else if (head_size <= 256) { + UnpackRoPEAppend<<>>( + packed_qkv, query, key, value, unpacked_q, k_cache, v_cache, + num_heads, kv_num_heads, head_size, d, max_seqlen, past_seq_lens, + cos_cache, 
sin_cache, rotary_dim, position_ids, interleaved, is_cache_bnsh); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Head size (", head_size, ") exceeds maximum supported MAX_HEAD_SIZE (256)."); + } + + return CUDA_CALL(cudaGetLastError()); +} + +// Explicit template instantiations +template Status LaunchUnpackRoPEAppendKV( + const half*, const half*, const half*, const half*, half*, half*, half*, + int, int, int, int, int, int, const int*, const half*, const half*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +template Status LaunchUnpackRoPEAppendKV( + const BFloat16*, const BFloat16*, const BFloat16*, const BFloat16*, BFloat16*, BFloat16*, BFloat16*, + int, int, int, int, int, int, const int*, const BFloat16*, const BFloat16*, int, const int64_t*, bool, bool, + cudaStream_t, int); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index ce6b4724af705..0c1e346503194 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -17,7 +17,7 @@ namespace onnxruntime { namespace contrib { namespace cuda { -template +template __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const T* input, // BxSxNxH const T* cos_cache, // Mx(H/2) @@ -38,15 +38,30 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx const int i = threadIdx.x; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; + + [[maybe_unused]] extern __shared__ char smem_[]; + [[maybe_unused]] T* smem = reinterpret_cast(smem_); + + if constexpr (use_smem) { + // Load to shared memory for safe in-place update + if (i < head_size) { + smem[i] = input_data[i]; + } + __syncthreads(); + } + if (i >= head_size) { return; } - const T* 
input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; - T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; - if (i >= rotary_embedding_dim) { - output_data[i] = input_data[i]; + if constexpr (use_smem) { + output_data[i] = smem[i]; + } else { + output_data[i] = input_data[i]; + } return; } @@ -79,7 +94,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNx sign = (i < half_rotary_embedding_dim) ? -1 : 1; j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; } - output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + + // Use values from shared memory + if constexpr (use_smem) { + output_data[i] = smem[i] * cos_data[cache_idx] + sign * smem[j] * sin_data[cache_idx]; + } else { + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } } template @@ -137,9 +158,21 @@ Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* inpu const dim3 grid(sequence_length, batch_size, num_heads); assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, - num_heads, head_size, rotary_embedding_dim, position_ids_format, - interleaved, in_strides, out_strides); + + if (output == input) { + // In-place operation: use shared memory to avoid read-after-write hazards + size_t smem_size = head_size * sizeof(T); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, in_strides, out_strides); + } else { + // Separate buffers: no shared memory needed + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, past_sequence_lengths, sequence_length, + num_heads, head_size, rotary_embedding_dim, position_ids_format, + interleaved, in_strides, 
out_strides); + } return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc index e413ccf580870..f4c3eb9914118 100644 --- a/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc +++ b/onnxruntime/contrib_ops/cuda/collective/distributed_reshape.cc @@ -495,7 +495,7 @@ std::tuple ComputeRepeatAndRepeatStride( const std::vector& device_elements) { int64_t first_device_id = device_elements.at(0); int64_t first_device_id_count = 0; - for (size_t i = 0; i < device_elements.size(); ++i) { + for (size_t i = 0; i < static_cast(device_elements.size()); ++i) { if (device_elements.at(i) == first_device_id) { ++first_device_id_count; } @@ -505,8 +505,8 @@ std::tuple ComputeRepeatAndRepeatStride( // Check if the device mesh pattern is supported. // Supported examples: [0, 1, 2] and [0, 1, 0, 1, 0, 1]. // Unsupported examples: [0, 1, 2, 1, 2, 0] and [0, 1, 2, 0]. - for (size_t repeat = 0; repeat < first_device_id_count; ++repeat) { - for (size_t device_id = 0; device_id < repeat_stride; ++device_id) { + for (size_t repeat = 0; repeat < static_cast(first_device_id_count); ++repeat) { + for (size_t device_id = 0; device_id < static_cast(repeat_stride); ++device_id) { ORT_ENFORCE( device_elements.at(repeat * repeat_stride + device_id) == device_elements.at(device_id), "Unsupported device mesh pattern."); @@ -556,7 +556,7 @@ std::tuple ComputeNativeSpecForTwoAxisDecomposition( // S[0], shape=[16], device=[0, 1] -> S[0]R, shape=[4, 4], device=[0, 1] std::vector dst_axis_specs; for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { - if (src_axis != decomposed_axis_in_src) { + if (src_axis != static_cast(decomposed_axis_in_src)) { // Sharding spec is copied if the axis is not decomposed. // E.g, shape [5, 6] -> Reshape -> shape [5, 3, 2] // The spec for "5" is copied. 
@@ -606,7 +606,7 @@ std::tuple ComputeNativeSpecForTwoAxisDecomposition( DeviceMesh dst_device_mesh; std::tie(repeats, repeat_stride) = ComputeRepeatAndRepeatStride(src_spec.device_mesh.device_mesh_elements); for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { - if (src_axis != decomposed_axis_in_src) { + if (src_axis != static_cast(decomposed_axis_in_src)) { dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); } else if (dst_shape[decomposition_axis_in_dst] == 1) { // S[0] -> RS[0] @@ -660,7 +660,7 @@ std::tuple ComputeNativeSpecForTwoAxisDecomposition( // Source tensor is sharded on non-decomposed axis. std::vector dst_axis_specs; for (size_t src_axis = 0; src_axis < src_shape.size(); ++src_axis) { - if (src_axis != decomposed_axis_in_src) { + if (src_axis != static_cast(decomposed_axis_in_src)) { dst_axis_specs.push_back(AxisPartitionSpec::CreateCopy(src_spec.GetAxisSpec(src_axis))); } else { // R -> RR diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc index 167b2af946183..5170c982f248d 100644 --- a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc +++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc @@ -73,9 +73,9 @@ Status ShardedMoE::ComputeInternal(OpKernelContext* context) const { MoEParameters moe_params(tensor_shards_); ORT_RETURN_IF_ERROR(::onnxruntime::contrib::moe_helper::CheckInputs( moe_params, input, router_probs, - fc1_experts_weights, fc1_experts_bias_optional, nullptr, - fc2_experts_weights, fc2_experts_bias_optional, nullptr, - fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, + fc1_experts_weights, fc1_experts_bias_optional, nullptr, nullptr, + fc2_experts_weights, fc2_experts_bias_optional, nullptr, nullptr, + fc3_experts_weights_optional, fc3_experts_bias_optional, nullptr, nullptr, 1, // no quantization so pack size is 1 activation_type_ == ort_fastertransformer::ActivationType::SwiGLU, 
0)); // no block-wise quantization for sharded MoE diff --git a/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h index 1fe8035cbcdae..7722cd5a84f07 100644 --- a/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h +++ b/onnxruntime/contrib_ops/cuda/llm/cutlass_type_conversion.h @@ -29,7 +29,14 @@ #if defined(ENABLE_FP4) #include "cutlass/float_subbyte.h" +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif #include +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif #endif namespace onnxruntime::llm { diff --git a/onnxruntime/contrib_ops/webgpu/moe/moe.h b/onnxruntime/contrib_ops/webgpu/moe/moe.h index 5e329dc12b5c9..332aa39a8d23e 100755 --- a/onnxruntime/contrib_ops/webgpu/moe/moe.h +++ b/onnxruntime/contrib_ops/webgpu/moe/moe.h @@ -3,6 +3,8 @@ #pragma once +#include + #include "core/providers/webgpu/program.h" #include "core/providers/webgpu/webgpu_kernel.h" @@ -31,7 +33,7 @@ class MoE : public WebGpuKernel { activation_alpha_ = static_cast(info.GetAttrOrDefault("activation_alpha", 1.0)); activation_beta_ = static_cast(info.GetAttrOrDefault("activation_beta", 1.0)); swiglu_fusion_ = static_cast(info.GetAttrOrDefault("swiglu_fusion", 0)); - swiglu_limit_ = info.GetAttrOrDefault("swiglu_limit", 0); + swiglu_limit_ = info.GetAttrOrDefault("swiglu_limit", std::numeric_limits::infinity()); k_ = static_cast(info.GetAttrOrDefault("k", 4)); normalize_routing_weights_ = info.GetAttrOrDefault("normalize_routing_weights", 0) == 1; use_sparse_mixer_ = info.GetAttrOrDefault("use_sparse_mixer", 0) == 1; diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index afea9f62419fa..00711e416e4e3 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -366,7 +366,11 @@ CPUIDInfo::CPUIDInfo() { #endif // defined(CPUINFO_SUPPORTED) // Note: This should be run after 
cpuinfo initialization if cpuinfo is enabled. + // On Wasm/Emscripten, cpuinfo cannot detect the CPU vendor so skip to avoid + // an unhelpful "Unknown CPU vendor" warning. +#if !defined(__wasm__) VendorInfoInit(); +#endif #ifdef CPUIDINFO_ARCH_X86 X86Init(); diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index ca9315c7ef95d..be301019df5c0 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -168,7 +168,7 @@ class CPUIDInfo { bool has_arm_sme2_{false}; std::string vendor_; - uint32_t vendor_id_; + uint32_t vendor_id_{0}; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index a656abb098911..7648aaf8f9d33 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -237,7 +237,19 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::AMD, device_id), mem_type1); } else if (strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || - strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { + strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0 || + // PR #27207 (merged to main/1.25.x, not in 1.24.x) shortened the WebGPU/WebNN + // memory info names from "WebGPU_Buffer"/"WebNN_Tensor" to "WebGPU_Buf"/"WebNN_Ten" + // to enable Small String Optimization (SSO) on wasm32 (emscripten), where strings + // must be <= 10 chars for SSO. + // + // A WebGPU/WebNN plugin EP built against 1.25.x will use the new short names. + // Accept both old and new names here so that plugin EPs targeting either 1.24.x + // or 1.25.x can work with this 1.24.x runtime. 
+ // + // See: https://github.com/microsoft/onnxruntime/pull/27207 + strcmp(name1, "WebGPU_Buf") == 0 || + strcmp(name1, "WebNN_Ten") == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, device_id), diff --git a/onnxruntime/core/framework/data_types.cc b/onnxruntime/core/framework/data_types.cc index 30896b37654ff..e5a8e718bd024 100644 --- a/onnxruntime/core/framework/data_types.cc +++ b/onnxruntime/core/framework/data_types.cc @@ -645,6 +645,9 @@ ORT_REGISTER_TENSOR_TYPE(Float4E2M1x2); ORT_REGISTER_TENSOR_TYPE(Int4x2); ORT_REGISTER_TENSOR_TYPE(UInt4x2); +ORT_REGISTER_TENSOR_TYPE(Int2x4); +ORT_REGISTER_TENSOR_TYPE(UInt2x4); + #if !defined(DISABLE_SPARSE_TENSORS) ORT_REGISTER_SPARSE_TENSOR_TYPE(int32_t); ORT_REGISTER_SPARSE_TENSOR_TYPE(float); @@ -708,6 +711,9 @@ ORT_REGISTER_SEQ_TENSOR_TYPE(Float8E5M2FNUZ); ORT_REGISTER_SEQ_TENSOR_TYPE(Int4x2); ORT_REGISTER_SEQ_TENSOR_TYPE(UInt4x2); +ORT_REGISTER_SEQ_TENSOR_TYPE(Int2x4); +ORT_REGISTER_SEQ_TENSOR_TYPE(UInt2x4); + #if !defined(DISABLE_ML_OPS) ORT_REGISTER_SEQ(VectorMapStringToFloat); ORT_REGISTER_SEQ(VectorMapInt64ToFloat); @@ -735,7 +741,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Float8E5M2FNUZ); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int2x4); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt2x4); #else @@ -755,7 +763,9 @@ ORT_REGISTER_SEQ(VectorMapInt64ToFloat); ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, MLFloat16); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, BFloat16); \ ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int4x2); \ - ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt4x2); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, Int2x4); \ + ORT_REGISTER_OPTIONAL_TYPE(ORT_TYPE, UInt2x4); 
#endif @@ -825,6 +835,8 @@ void RegisterAllProtos(const std::function& reg_fn) { #endif REGISTER_TENSOR_PROTO(Int4x2, reg_fn); REGISTER_TENSOR_PROTO(UInt4x2, reg_fn); + REGISTER_TENSOR_PROTO(Int2x4, reg_fn); + REGISTER_TENSOR_PROTO(UInt2x4, reg_fn); #if !defined(DISABLE_SPARSE_TENSORS) REGISTER_SPARSE_TENSOR_PROTO(int32_t, reg_fn); @@ -886,6 +898,8 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_SEQ_TENSOR_PROTO(Int4x2, reg_fn); REGISTER_SEQ_TENSOR_PROTO(UInt4x2, reg_fn); + REGISTER_SEQ_TENSOR_PROTO(Int2x4, reg_fn); + REGISTER_SEQ_TENSOR_PROTO(UInt2x4, reg_fn); #if !defined(DISABLE_ML_OPS) REGISTER_ONNX_PROTO(VectorMapStringToFloat, reg_fn); @@ -916,7 +930,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Float8E5M2FNUZ, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int2x4, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt2x4, reg_fn); #else @@ -936,7 +952,9 @@ void RegisterAllProtos(const std::function& reg_fn) { REGISTER_OPTIONAL_PROTO(ORT_TYPE, MLFloat16, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, BFloat16, reg_fn); \ REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int4x2, reg_fn); \ - REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt4x2, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, Int2x4, reg_fn); \ + REGISTER_OPTIONAL_PROTO(ORT_TYPE, UInt2x4, reg_fn); #endif @@ -1003,6 +1021,10 @@ const char* DataTypeImpl::ToString(MLDataType type) { return "Int4x2"; case TensorProto_DataType_UINT4: return "UInt4x2"; + case TensorProto_DataType_INT2: + return "Int2x4"; + case TensorProto_DataType_UINT2: + return "UInt2x4"; default: break; } @@ -1077,6 +1099,10 @@ const TensorTypeBase* DataTypeImpl::TensorTypeFromONNXEnum(int type) { return 
DataTypeImpl::GetTensorType()->AsTensorType(); case TensorProto_DataType_UINT4: return DataTypeImpl::GetTensorType()->AsTensorType(); + case TensorProto_DataType_INT2: + return DataTypeImpl::GetTensorType()->AsTensorType(); + case TensorProto_DataType_UINT2: + return DataTypeImpl::GetTensorType()->AsTensorType(); default: ORT_NOT_IMPLEMENTED("tensor type ", type, " is not supported"); @@ -1130,6 +1156,10 @@ const SequenceTensorTypeBase* DataTypeImpl::SequenceTensorTypeFromONNXEnum(int t return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); case TensorProto_DataType_UINT4: return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); + case TensorProto_DataType_INT2: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); + case TensorProto_DataType_UINT2: + return DataTypeImpl::GetSequenceTensorType()->AsSequenceTensorType(); default: ORT_NOT_IMPLEMENTED("sequence tensor type ", type, " is not supported"); @@ -1232,6 +1262,8 @@ ORT_REGISTER_PRIM_SUBBYTE_TYPE(Float4E2M1x2, 2); ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int4x2, 2); ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt4x2, 2); +ORT_REGISTER_PRIM_SUBBYTE_TYPE(Int2x4, 4); +ORT_REGISTER_PRIM_SUBBYTE_TYPE(UInt2x4, 4); namespace { template @@ -1334,6 +1366,12 @@ const std::vector& DataTypeImpl::AllTensorTypesIRv11() { return all_tensor_types; } +const std::vector& DataTypeImpl::AllTensorTypesIRv13() { + static std::vector all_tensor_types = + GetTensorTypesFromTypeList(); + return all_tensor_types; +} + const std::vector& DataTypeImpl::AllFixedSizeSequenceTensorTypes() { return AllFixedSizeSequenceTensorTypesIRv4(); } diff --git a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc index 38dd8de01147c..5137c22d6cf61 100644 --- a/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc +++ b/onnxruntime/core/framework/debug_node_inputs_outputs_utils.cc @@ -621,8 +621,8 @@ void DumpNodeInputs( std::cout << " is 
non-tensor type.\n"; } } else { - // this could happen with an empty Optional input - std::cout << " was missing data type\n"; + // this could happen with an empty Optional input or the tensor is removed after pre-packing. + std::cout << " was missing data type (maybe pre-packed).\n"; } } else { std::cout << "Input " << i << " is optional and was not provided.\n"; diff --git a/onnxruntime/core/framework/device_stream_collection.cc b/onnxruntime/core/framework/device_stream_collection.cc index a32973ddb8c9e..76da5702634aa 100644 --- a/onnxruntime/core/framework/device_stream_collection.cc +++ b/onnxruntime/core/framework/device_stream_collection.cc @@ -5,6 +5,8 @@ #include "core/framework/device_stream_collection.h" #include "core/framework/session_state.h" +#include + namespace onnxruntime { struct DummyNotification : public synchronize::Notification { @@ -50,7 +52,11 @@ class DeviceStreamCollectionImpl { Status CleanUp(bool sync_streams) { if (sync_streams) { - for (auto& device_stream : device_streams_) { + for (size_t i = 0, lim = device_streams_.size(); i < lim; ++i) { + Stream* device_stream = device_streams_[i]; + if (stream_override_ && i == stream_override_->first) { + device_stream = stream_override_->second; + } if (device_stream) { ORT_RETURN_IF_ERROR(device_stream->CleanUpOnRunEnd()); if (is_main_graph_) { @@ -76,11 +82,39 @@ class DeviceStreamCollectionImpl { void SetDeviceStream(size_t idx, Stream* stream) { ORT_ENFORCE(idx < num_streams_); + if (stream_override_) { + if (idx == stream_override_->first) { + ORT_THROW("Cannot set device stream for index ", idx, + " when there is an active stream override for the same index."); + } + } device_streams_[idx] = stream; } + Status SetStreamOverride(Stream* stream) { + ORT_ENFORCE(stream != nullptr); + for (size_t i = 0, lim = device_streams_.size(); i < lim; ++i) { + if (device_streams_[i] != nullptr && + // Exact match + device_streams_[i]->GetDevice() == stream->GetDevice()) { + 
stream_override_.emplace(i, stream); + return Status::OK(); + } + } + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "No matching stream found to override from OrtRunOptions"); + } + + void ResetStreamOverride() { + stream_override_.reset(); + } + Stream* GetStream(size_t stream_idx) const { ORT_ENFORCE(stream_idx < num_streams_); + if (stream_override_) { + if (stream_idx == stream_override_->first) { + return stream_override_->second; + } + } return device_streams_[stream_idx]; } @@ -94,6 +128,11 @@ class DeviceStreamCollectionImpl { size_t num_streams_; std::vector device_streams_; InlinedVector> owned_streams_; + // RunOptions allow specifying a stream override for a specific run. + // if this is present, it would be used as a stream for a given stream_id + // we declare it sepately as the original stream in device_streams_ should stay + // intact for future runs as we cache it in SessionState. + std::optional> stream_override_; const AllocatorMap& allocators_; bool is_main_graph_ = false; // This is used in ExecutionFrame when memory pattern is enabled, to allocate the peak size memory @@ -117,6 +156,14 @@ void DeviceStreamCollection::SetDeviceStream(size_t idx, Stream* stream) { impl_->SetDeviceStream(idx, stream); } +Status DeviceStreamCollection::SetStreamOverride(Stream* stream) { + return impl_->SetStreamOverride(stream); +} + +void DeviceStreamCollection::ResetStreamOverride() { + impl_->ResetStreamOverride(); +} + size_t DeviceStreamCollection::NumStreams() const { return impl_->NumStreams(); } @@ -140,6 +187,7 @@ DeviceStreamCollectionHolder::DeviceStreamCollectionHolder(const SessionState* s DeviceStreamCollectionHolder::~DeviceStreamCollectionHolder() { if (p_) { + p_->ResetStreamOverride(); session_state_->RecycleDeviceStreamCollection(std::move(p_)); } } diff --git a/onnxruntime/core/framework/device_stream_collection.h b/onnxruntime/core/framework/device_stream_collection.h index c76c7c731571c..34d2ecba13476 100644 --- 
a/onnxruntime/core/framework/device_stream_collection.h +++ b/onnxruntime/core/framework/device_stream_collection.h @@ -28,6 +28,15 @@ class DeviceStreamCollection { // a EP which doesn't support Stream, i.e. CPU based EPs. void SetDeviceStream(size_t stream_idx, Stream* stream); + // override the stream for matching device. + // only one override is allowed at a time presumably coming from + // OrtRunOptions + // returns an error if no matching stream + Status SetStreamOverride(Stream* stream); + + // Remove the override before caching/reusing the collection. + void ResetStreamOverride(); + // get the Stream instance on given stream index // The return value could be nullptr, which means the EP on this // logic sequence doesn't support Stream. diff --git a/onnxruntime/core/framework/element_type_lists.h b/onnxruntime/core/framework/element_type_lists.h index ce7c243849d5e..67358d045ba68 100644 --- a/onnxruntime/core/framework/element_type_lists.h +++ b/onnxruntime/core/framework/element_type_lists.h @@ -12,6 +12,7 @@ #include "core/common/float8.h" #include "core/common/float16.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" #include "core/framework/float4.h" namespace onnxruntime { @@ -99,6 +100,13 @@ using AllIRv11 = using AllIRv11 = AllIRv10; #endif +// IR v13 adds INT2/UINT2 (2-bit integer types) +using AllIRv13 = + boost::mp11::mp_push_back< + AllIRv11, + UInt2x4, + Int2x4>; + // TODO: This needs upgrade to some newer version ,buit it has been // at this version for a while and it needs changes at the use sites // where-in the types in the newer IR versions are not supported. 
diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 43caf4766d5c0..9cb2111670ba6 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -68,6 +68,7 @@ struct PartitionParams { std::reference_wrapper transform_layout_function; std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) + std::reference_wrapper on_partition_assignment_fn; }; } // namespace @@ -426,6 +427,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_fn, const layout_transformation::DebugGraphFn& debug_graph_fn, const CheckLoadCancellationFn& check_load_cancellation_fn, + const OnPartitionAssignmentFunction& on_partition_assignment_fn, const logging::Logger& logger, IResourceAccountant* resource_accountant, const GraphOptimizerRegistry& graph_optimizer_registry, bool disable_model_compile) { @@ -444,6 +446,7 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, fused_kernel_registry, current_ep, mode, fused_node_unique_id, transform_layout_fn, debug_graph_fn, check_load_cancellation_fn, + on_partition_assignment_fn, logger, resource_accountant, graph_optimizer_registry, disable_model_compile)); } @@ -518,6 +521,12 @@ static Status PartitionOnnxFormatModelImpl(Graph& graph, FuncManager& func_mgr, Node* n = nullptr; if (sub_graph_available_for_assignment) { + if (on_partition_assignment_fn) { + // Call custom function provided by owner of GraphPartitioner whenever a subgraph is assigned to an EP. + // This can be used, for example, to collect partitioning information. 
+ on_partition_assignment_fn(graph, *capability, type); + } + n = PlaceNode(graph, *capability->sub_graph, fusion_style, type, mode, fused_node_unique_id); } @@ -1018,6 +1027,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, auto& fused_node_unique_id = partition_params.fused_node_unique_id.get(); const auto& transform_layout_function = partition_params.transform_layout_function; const CheckLoadCancellationFn& check_load_cancellation_fn = partition_params.check_load_cancellation_fn; + const OnPartitionAssignmentFunction& on_partition_assignment_fn = partition_params.on_partition_assignment_fn; do { // process full graph with each EP @@ -1034,6 +1044,7 @@ static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, transform_layout_function, partition_params.debug_graph_fn, check_load_cancellation_fn, + on_partition_assignment_fn, logger, resource_accountant, graph_optimizer_registry, disable_model_compile)); } @@ -1280,7 +1291,8 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, std::ref(*fused_kernel_registry), std::ref(fused_node_unique_id), std::cref(transform_layout_function), - std::cref(debug_graph_fn)}; + std::cref(debug_graph_fn), + std::cref(on_partition_assignment_fn_)}; #else // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1290,6 +1302,7 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, PartitionParams partition_params{ std::ref(graph), std::cref(check_load_cancellation_fn), + std::cref(on_partition_assignment_fn_), }; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index abe46cea58ab2..eb70b9f89933d 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -20,6 +20,12 @@ namespace epctx { struct ModelGenOptions; } +// 
OnPartitionAssignmentFunction is called by GraphPartitioner when a subgraph is assigned to +// an execution provider. Can be used to collect partitioning information. +using OnPartitionAssignmentFunction = std::function; + class GraphPartitioner { public: enum class Mode { @@ -40,11 +46,13 @@ class GraphPartitioner { GraphPartitioner(KernelRegistryManager& kernel_registry_mgr, const ExecutionProviders& providers, std::unique_ptr graph_optimizer_registry, - CheckLoadCancellationFn check_load_cancellation_fn) + CheckLoadCancellationFn check_load_cancellation_fn, + OnPartitionAssignmentFunction on_partition_assignment_fn = {}) : kernel_registry_mgr_(kernel_registry_mgr), providers_(providers), graph_optimizer_registry_(std::move(graph_optimizer_registry)), - check_load_cancellation_fn_(std::move(check_load_cancellation_fn)) { + check_load_cancellation_fn_(std::move(check_load_cancellation_fn)), + on_partition_assignment_fn_(std::move(on_partition_assignment_fn)) { } // Run partitioning. @@ -89,6 +97,7 @@ class GraphPartitioner { const ExecutionProviders& providers_; std::unique_ptr graph_optimizer_registry_; CheckLoadCancellationFn check_load_cancellation_fn_; + OnPartitionAssignmentFunction on_partition_assignment_fn_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/onnxruntime_map_type_info.cc b/onnxruntime/core/framework/onnxruntime_map_type_info.cc index 461e82d72dc83..ffeb8b5b4a193 100644 --- a/onnxruntime/core/framework/onnxruntime_map_type_info.cc +++ b/onnxruntime/core/framework/onnxruntime_map_type_info.cc @@ -87,6 +87,12 @@ ToONNXTensorElementDataType(ONNX_NAMESPACE::TensorProto_DataType data_type) { case TensorType::TensorProto_DataType_FLOAT4E2M1: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; } // maps to a pair of float4 (size == 1 byte) + case TensorType::TensorProto_DataType_INT2: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2; + } // maps to 4 packed int2 values (size == 1 
byte) + case TensorType::TensorProto_DataType_UINT2: { + return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; + } // maps to 4 packed uint2 values (size == 1 byte) default: { return ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } diff --git a/onnxruntime/core/framework/print_tensor_statistics_utils.h b/onnxruntime/core/framework/print_tensor_statistics_utils.h index 64d60e048a112..0f524f231f13d 100644 --- a/onnxruntime/core/framework/print_tensor_statistics_utils.h +++ b/onnxruntime/core/framework/print_tensor_statistics_utils.h @@ -4,6 +4,7 @@ #include #include "core/framework/print_tensor_utils.h" +#include "core/framework/int2.h" namespace onnxruntime { namespace utils { @@ -94,36 +95,38 @@ void PrintCommonStats(const T* data, size_t count, TensorStatisticsData& tensor_ } } -#define DEF_PRINT_COMMON_STATS_4BIT(FOUR_BIT_TYPE) \ - template <> \ - inline void PrintCommonStats( \ - const FOUR_BIT_TYPE* data, size_t count, TensorStatisticsData&) { \ - using UnpackedType = typename FOUR_BIT_TYPE::UnpackedType; \ - UnpackedType min = data[0].GetElem(0); \ - UnpackedType max = min; \ - for (size_t i = 1; i < count; i++) { \ - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(i); \ - auto value = data[indices.first].GetElem(indices.second); \ - if (value > max) { \ - max = value; \ - } \ - if (value < min) { \ - min = value; \ - } \ - } \ - \ - std::cout << "Min="; \ - PrintValue(min); \ - \ - std::cout << ",Max="; \ - PrintValue(max); \ +#define DEF_PRINT_COMMON_STATS_PACKED(PACKED_TYPE) \ + template <> \ + inline void PrintCommonStats( \ + const PACKED_TYPE* data, size_t count, TensorStatisticsData&) { \ + using UnpackedType = typename PACKED_TYPE::UnpackedType; \ + UnpackedType min = data[0].GetElem(0); \ + UnpackedType max = min; \ + for (size_t i = 1; i < count; i++) { \ + auto indices = PACKED_TYPE::GetTensorElemIndices(i); \ + auto value = data[indices.first].GetElem(indices.second); \ + if (value > max) { \ + max = value; 
\ + } \ + if (value < min) { \ + min = value; \ + } \ + } \ + \ + std::cout << "Min="; \ + PrintValue(min); \ + \ + std::cout << ",Max="; \ + PrintValue(max); \ } -DEF_PRINT_COMMON_STATS_4BIT(Int4x2) -DEF_PRINT_COMMON_STATS_4BIT(UInt4x2) +DEF_PRINT_COMMON_STATS_PACKED(Int4x2) +DEF_PRINT_COMMON_STATS_PACKED(UInt4x2) #if !defined(DISABLE_FLOAT4_TYPES) -DEF_PRINT_COMMON_STATS_4BIT(Float4E2M1x2) +DEF_PRINT_COMMON_STATS_PACKED(Float4E2M1x2) #endif +DEF_PRINT_COMMON_STATS_PACKED(Int2x4) +DEF_PRINT_COMMON_STATS_PACKED(UInt2x4) template void PrintHalfStats(const T* data, size_t count) { diff --git a/onnxruntime/core/framework/print_tensor_utils.h b/onnxruntime/core/framework/print_tensor_utils.h index 47be8b8dc2057..0c0f9e2a13cbb 100644 --- a/onnxruntime/core/framework/print_tensor_utils.h +++ b/onnxruntime/core/framework/print_tensor_utils.h @@ -5,6 +5,7 @@ #include #include #include +#include "core/framework/int2.h" namespace onnxruntime { namespace utils { @@ -74,31 +75,33 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t std::cout << std::endl; } -// 4 BIT TYPE - Print snippet of 2D tensor with shape (dim0, dim1) -#define DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(FOUR_BIT_TYPE) \ - template <> \ - inline void PrintCpuTensorSnippet(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, \ - int64_t edge_items) { \ - for (int64_t i = 0; i < dim0; i++) { \ - SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \ - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - for (int64_t j = 1; j < dim1; j++) { \ - SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \ - std::cout << ", "; \ - indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 + j)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ +// PACKED TYPE - Print snippet of 2D tensor with shape (dim0, dim1) +#define 
DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(PACKED_TYPE) \ + template <> \ + inline void PrintCpuTensorSnippet(const PACKED_TYPE* tensor, int64_t dim0, int64_t dim1, \ + int64_t edge_items) { \ + for (int64_t i = 0; i < dim0; i++) { \ + SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \ + auto indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + for (int64_t j = 1; j < dim1; j++) { \ + SKIP_NON_EDGE_ITEMS_LAST_DIM(dim1, j, edge_items); \ + std::cout << ", "; \ + indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 + j)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ } -DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Int4x2) -DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(UInt4x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(Int4x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(UInt4x2) #if !defined(DISABLE_FLOAT4_TYPES) -DEF_PRINT_CPU_TENSOR_SNIPPET_2D_4BIT(Float4E2M1x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(Float4E2M1x2) #endif +DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(Int2x4) +DEF_PRINT_CPU_TENSOR_SNIPPET_2D_PACKED(UInt2x4) // Print snippet of 3D tensor with shape (dim0, dim1, dim2) template @@ -120,35 +123,37 @@ void PrintCpuTensorSnippet(const T* tensor, int64_t dim0, int64_t dim1, int64_t std::cout << std::endl; } -// 4 BIT TYPE - Print snippet of 3D tensor with shape (dim0, dim1, dim2) -#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(FOUR_BIT_TYPE) \ - template <> \ - inline void PrintCpuTensorSnippet(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \ - int64_t edge_items) { \ - for (int64_t i = 0; i < dim0; i++) { \ - SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \ - for (int64_t j = 0; j < dim1; j++) { \ - SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \ - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - 
for (int64_t k = 1; k < dim2; k++) { \ - SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \ - std::cout << ", "; \ - indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2 + k)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ +// PACKED TYPE - Print snippet of 3D tensor with shape (dim0, dim1, dim2) +#define DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(PACKED_TYPE) \ + template <> \ + inline void PrintCpuTensorSnippet(const PACKED_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2, \ + int64_t edge_items) { \ + for (int64_t i = 0; i < dim0; i++) { \ + SKIP_NON_EDGE_ITEMS(dim0, i, edge_items); \ + for (int64_t j = 0; j < dim1; j++) { \ + SKIP_NON_EDGE_ITEMS(dim1, j, edge_items); \ + auto indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + for (int64_t k = 1; k < dim2; k++) { \ + SKIP_NON_EDGE_ITEMS_LAST_DIM(dim2, k, edge_items); \ + std::cout << ", "; \ + indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2 + k)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ } -DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Int4x2) -DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(UInt4x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(Int4x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(UInt4x2) #if !defined(DISABLE_FLOAT4_TYPES) -DEF_PRINT_CPU_TENSOR_SNIPPET_3D_4BIT(Float4E2M1x2) +DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(Float4E2M1x2) #endif +DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(Int2x4) +DEF_PRINT_CPU_TENSOR_SNIPPET_3D_PACKED(UInt2x4) // Print 2D tensor template @@ -164,28 +169,30 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1) { std::cout << std::endl; } -// 4 BIT TYPE - Print 2D tensor 
-#define DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(FOUR_BIT_TYPE) \ - template <> \ - inline void PrintCpuTensorFull(const FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1) { \ - for (int64_t i = 0; i < dim0; i++) { \ - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - for (int64_t j = 1; j < dim1; j++) { \ - std::cout << ", "; \ - indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 + j)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ +// PACKED TYPE - Print 2D tensor +#define DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(PACKED_TYPE) \ + template <> \ + inline void PrintCpuTensorFull(const PACKED_TYPE* tensor, int64_t dim0, int64_t dim1) { \ + for (int64_t i = 0; i < dim0; i++) { \ + auto indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + for (int64_t j = 1; j < dim1; j++) { \ + std::cout << ", "; \ + indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 + j)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ } -DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Int4x2) -DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(UInt4x2) +DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(Int4x2) +DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(UInt4x2) #if !defined(DISABLE_FLOAT4_TYPES) -DEF_PRINT_CPU_TENSOR_FULL_2D_4BIT(Float4E2M1x2) +DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(Float4E2M1x2) #endif +DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(Int2x4) +DEF_PRINT_CPU_TENSOR_FULL_2D_PACKED(UInt2x4) // Print 3D tensor template @@ -204,31 +211,33 @@ void PrintCpuTensorFull(const T* tensor, int64_t dim0, int64_t dim1, int64_t dim std::cout << std::endl; } -// 4 BIT TYPE - Print 3D tensor -#define DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(FOUR_BIT_TYPE) \ - template <> \ - inline void PrintCpuTensorFull(const 
FOUR_BIT_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \ - for (int64_t i = 0; i < dim0; i++) { \ - for (int64_t j = 0; j < dim1; j++) { \ - auto indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - for (int64_t k = 1; k < dim2; k++) { \ - std::cout << ", "; \ - indices = FOUR_BIT_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2 + k)); \ - PrintValue(tensor[indices.first].GetElem(indices.second)); \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ - } \ - std::cout << std::endl; \ +// PACKED TYPE - Print 3D tensor +#define DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(PACKED_TYPE) \ + template <> \ + inline void PrintCpuTensorFull(const PACKED_TYPE* tensor, int64_t dim0, int64_t dim1, int64_t dim2) { \ + for (int64_t i = 0; i < dim0; i++) { \ + for (int64_t j = 0; j < dim1; j++) { \ + auto indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + for (int64_t k = 1; k < dim2; k++) { \ + std::cout << ", "; \ + indices = PACKED_TYPE::GetTensorElemIndices(static_cast(i * dim1 * dim2 + j * dim2 + k)); \ + PrintValue(tensor[indices.first].GetElem(indices.second)); \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ + } \ + std::cout << std::endl; \ } -DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Int4x2) -DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(UInt4x2) +DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(Int4x2) +DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(UInt4x2) #if !defined(DISABLE_FLOAT4_TYPES) -DEF_PRINT_CPU_TENSOR_FULL_3D_4BIT(Float4E2M1x2) +DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(Float4E2M1x2) #endif +DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(Int2x4) +DEF_PRINT_CPU_TENSOR_FULL_3D_PACKED(UInt2x4) template void PrintCpuTensor(const onnxruntime::Tensor& tensor, diff --git a/onnxruntime/core/framework/run_options.cc b/onnxruntime/core/framework/run_options.cc index 
0a2bb9507ac85..45635e973d09d 100644 --- a/onnxruntime/core/framework/run_options.cc +++ b/onnxruntime/core/framework/run_options.cc @@ -58,6 +58,10 @@ ORT_API_STATUS_IMPL(OrtApis::RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* op return nullptr; } +ORT_API(void, OrtApis::RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream) { + options->sync_stream = sync_stream; +} + ORT_API_STATUS_IMPL(OrtApis::AddRunConfigEntry, _Inout_ OrtRunOptions* options, _In_z_ const char* config_key, _In_z_ const char* config_value) { return onnxruntime::ToOrtStatus(options->config_options.AddConfigEntry(config_key, config_value)); diff --git a/onnxruntime/core/framework/tensor_type_and_shape.cc b/onnxruntime/core/framework/tensor_type_and_shape.cc index cbf1a953819d3..16817ba1707bd 100644 --- a/onnxruntime/core/framework/tensor_type_and_shape.cc +++ b/onnxruntime/core/framework/tensor_type_and_shape.cc @@ -229,6 +229,12 @@ constexpr ONNXTensorElementDataType TensorDataTypeToOnnxRuntimeTensorElementData case o::TensorProto_DataType_FLOAT4E2M1: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT4E2M1; break; + case o::TensorProto_DataType_INT2: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2; + break; + case o::TensorProto_DataType_UINT2: + type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; + break; default: type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; break; @@ -304,6 +310,64 @@ std::unique_ptr OrtTensorTypeAndShapeInfo::GetTensorS return GetTensorShapeAndTypeHelper(type, shape, dim_params); } +ORT_API_STATUS_IMPL(OrtApis::GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count) { + API_IMPL_BEGIN + if (!value->IsAllocated() || (!value->IsTensor() && !value->IsSparseTensor())) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Input parameter `value` must contain a constructed tensor or sparse tensor"); + } + + if 
(elem_type == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `elem_type` must not be NULL"); + } + + if (shape_data == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `shape_data` must not be NULL"); + } + + if (shape_data_count == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Output parameter `shape_data_count` must not be NULL"); + } + + gsl::span shape_span; + onnxruntime::MLDataType ml_data_type = nullptr; + ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + + if (value->IsTensor()) { + const Tensor& tensor = value->Get(); + ml_data_type = tensor.DataType(); + shape_span = tensor.Shape().GetDims(); + } else { +#if !defined(DISABLE_SPARSE_TENSORS) + const SparseTensor& tensor = value->Get(); + ml_data_type = tensor.DataType(); + shape_span = tensor.DenseShape().GetDims(); +#else + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "SparseTensor is not supported in this build."); +#endif + } + + if (ml_data_type != nullptr) { + type = MLDataTypeToOnnxRuntimeTensorElementDataType(ml_data_type); + } + + if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) { + return OrtApis::CreateStatus(ORT_FAIL, "Tensor does not have a valid or supported tensor element data type"); + } + + *elem_type = type; + *shape_data = shape_span.empty() ? 
nullptr : shape_span.data(); + *shape_data_count = shape_span.size(); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::GetTensorTypeAndShape, _In_ const OrtValue* v, _Outptr_ OrtTensorTypeAndShapeInfo** out) { API_IMPL_BEGIN diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 0f5622ec2ed45..e4c7830ffbb55 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -73,6 +73,21 @@ TensorProto ToScalarTensor(TensorProto_DataType datatype, int32_t value) { return t; \ } +// 2-bit types use the same storage pattern as 4-bit types +#define TO_TENSOR_ORT_TYPE_2BIT_TYPE(TYPE) \ + template <> \ + TensorProto ToTensor(const onnxruntime::TYPE& value) { \ + return ToScalarTensor(ToTensorProtoElementType(), static_cast(value.ToBits())); \ + } \ + template <> \ + TensorProto ToTensor(const std::vector& values) { \ + TensorProto t = ToTensorInitialize(ToTensorProtoElementType()); \ + for (const onnxruntime::TYPE& val : values) { \ + t.add_int32_data(static_cast(val.ToBits())); \ + } \ + return t; \ + } + namespace ONNX_NAMESPACE { // Provide template specializations for onnxruntime-specific types. 
@@ -90,6 +105,9 @@ TO_TENSOR_ORT_TYPE_4BIT_TYPE(Float4E2M1x2) TO_TENSOR_ORT_TYPE_4BIT_TYPE(Int4x2) TO_TENSOR_ORT_TYPE_4BIT_TYPE(UInt4x2) +TO_TENSOR_ORT_TYPE_2BIT_TYPE(Int2x4) +TO_TENSOR_ORT_TYPE_2BIT_TYPE(UInt2x4) + bool operator==(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, const ONNX_NAMESPACE::TensorShapeProto_Dimension& r) { if (l.has_dim_value()) { @@ -167,6 +185,10 @@ Status UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int4x2, CalcNumInt4Pairs) DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt4x2, CalcNumInt4Pairs) +// 2-bit types use the same pattern - CalcNumInt2Quads gives number of packed bytes +DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Int2x4, CalcNumInt2Quads) +DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(UInt2x4, CalcNumInt2Quads) + #if !defined(DISABLE_FLOAT4_TYPES) DEFINE_4BIT_UNPACK_TENSOR_WITH_RAW_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs) #endif @@ -307,7 +329,8 @@ Status TensorProtoWithExternalDataToTensorProto( } Status ValidateExternalDataPath(const std::filesystem::path& base_dir, - const std::filesystem::path& location) { + const std::filesystem::path& location, + const std::filesystem::path& model_path) { // Reject absolute paths ORT_RETURN_IF(location.is_absolute(), "Absolute paths not allowed for external data location"); @@ -315,14 +338,54 @@ Status ValidateExternalDataPath(const std::filesystem::path& base_dir, // Resolve and verify the path stays within model directory auto base_canonical = std::filesystem::weakly_canonical(base_dir); // If the symlink exists, it resolves to the target path; - // so if the symllink is outside the directory it would be caught here. + // so if the symlink is outside the directory it would be caught here. 
auto resolved = std::filesystem::weakly_canonical(base_dir / location); + // Check that resolved path starts with base directory auto [base_end, resolved_it] = std::mismatch( base_canonical.begin(), base_canonical.end(), resolved.begin(), resolved.end()); - ORT_RETURN_IF(base_end != base_canonical.end(), - "External data path: ", location, " escapes model directory: ", base_dir); + + if (base_end != base_canonical.end()) { + // If validation against logical base_dir fails, we check against the + // real (canonical) path of the model file to support symlinked models + // (e.g. models in Hugging Face Hub local cache). + if (!model_path.empty()) { + auto real_model_dir = std::filesystem::weakly_canonical(model_path).parent_path(); + + auto [real_base_end, real_resolved_it] = std::mismatch( + real_model_dir.begin(), real_model_dir.end(), + resolved.begin(), resolved.end()); + + if (real_base_end == real_model_dir.end()) { + return Status::OK(); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "External data path: ", location, " (resolved path: ", resolved, + ") escapes both model directory: ", base_dir, + " and real model directory: ", real_model_dir); + } + + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "External data path: ", location, " (resolved path: ", resolved, + ") escapes model directory: ", base_dir); + } + } else { + // The basedir is empty, which occurs when 1) the session loads a model from bytes and 2) the application does not + // set an external file folder path via the session config option + // `kOrtSessionOptionsModelExternalInitializersFileFolderPath`. + + // We conservatively check that the normalized relative path does not contain ".." path components that would allow + // access to arbitrary files outside of the current working directory. Based on ONNX checker validation. 
+ auto norm_location = location.lexically_normal(); + + for (const auto& path_component : norm_location) { + if (path_component == ORT_TSTR("..")) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "External data path: ", location, + " (model loaded from bytes) escapes working directory"); + } + } } return Status::OK(); } @@ -395,6 +458,8 @@ void ConvertRawDataInTensorProto(TensorProto& tensor) { {TensorProto_DataType_FLOAT8E5M2FNUZ, sizeof(uint8_t)}, {TensorProto_DataType_UINT4, sizeof(uint8_t)}, {TensorProto_DataType_INT4, sizeof(uint8_t)}, + {TensorProto_DataType_UINT2, sizeof(uint8_t)}, + {TensorProto_DataType_INT2, sizeof(uint8_t)}, }; auto pos = tensorproto_data_size.find(tensor.data_type()); if (pos == tensorproto_data_size.end()) { @@ -418,6 +483,8 @@ void ConvertRawDataInTensorProto(TensorProto& tensor) { case TensorProto_DataType_BOOL: case TensorProto_DataType_UINT4: case TensorProto_DataType_INT4: + case TensorProto_DataType_UINT2: + case TensorProto_DataType_INT2: case TensorProto_DataType_UINT8: case TensorProto_DataType_INT8: case TensorProto_DataType_UINT16: @@ -515,6 +582,10 @@ Status UnpackTensorWithExternalData(const ONNX_NAMESPACE::TensorProto& tensor, DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int4x2, CalcNumInt4Pairs) DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt4x2, CalcNumInt4Pairs) +// 2-bit types +DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Int2x4, CalcNumInt2Quads) +DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(UInt2x4, CalcNumInt2Quads) + #if !defined(DISABLE_FLOAT4_TYPES) DEFINE_4BIT_UNPACK_TENSOR_WITH_EXT_DATA_IMPL(Float4E2M1x2, CalcNumFloat4Pairs) #endif @@ -899,6 +970,41 @@ DEFINE_INT4_UNPACK_TENSOR_IMPL(Int4x2, TensorProto_DataType_INT4) // UnpackTensor DEFINE_INT4_UNPACK_TENSOR_IMPL(UInt4x2, TensorProto_DataType_UINT4) +// 2-bit type unpack implementation +#define DEFINE_INT2_UNPACK_TENSOR_IMPL(INT2_TYPE, ONNX_INT2_TYPE) \ + template <> \ + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t 
raw_data_len, \ + /*out*/ INT2_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? raw_data_len : tensor.int32_data_size(); \ + return size == 0 ? Status::OK() : Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_INT2_TYPE != tensor.data_type()) { \ + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int2_quads = INT2_TYPE::CalcNumInt2Quads(expected_num_elems); \ + \ + if (raw_data != nullptr) { \ + return UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + } \ + \ + ORT_RETURN_IF_NOT(static_cast(tensor.int32_data_size()) == expected_int2_quads, \ + "UnpackTensor: the pre-allocated size does not match the size in proto"); \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT2_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + \ + return Status::OK(); \ + } + +// UnpackTensor +DEFINE_INT2_UNPACK_TENSOR_IMPL(Int2x4, TensorProto_DataType_INT2) + +// UnpackTensor +DEFINE_INT2_UNPACK_TENSOR_IMPL(UInt2x4, TensorProto_DataType_UINT2) + #if !defined(DISABLE_FLOAT4_TYPES) template <> @@ -985,6 +1091,9 @@ INSTANTIATE_UNPACK_TENSOR(Float8E5M2FNUZ) INSTANTIATE_UNPACK_TENSOR(Int4x2) INSTANTIATE_UNPACK_TENSOR(UInt4x2) +INSTANTIATE_UNPACK_TENSOR(Int2x4) +INSTANTIATE_UNPACK_TENSOR(UInt2x4) + #define CASE_PROTO_TRACE(X, Y) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ if (!IAllocator::CalcMemSizeForArrayWithAlignment(size, sizeof(Y), out)) { \ @@ -1008,6 +1117,14 @@ INSTANTIATE_UNPACK_TENSOR(UInt4x2) break; #endif +// 2-bit types +#define CASE_PROTO_TRACE_INT2(X, Y) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!IAllocator::CalcMemSizeForArrayWithAlignment(Y::CalcNumInt2Quads(size), sizeof(Y), out)) { \ + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Invalid TensorProto"); \ + } \ + 
break; + template common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, int32_t element_type, size_t* out) { const auto size = narrow(shape.Size()); @@ -1034,6 +1151,8 @@ common::Status GetSizeInBytesFromTensorShapeAndType(const TensorShape& shape, in #endif CASE_PROTO_TRACE_INT4(UINT4, UInt4x2); CASE_PROTO_TRACE_INT4(INT4, Int4x2); + CASE_PROTO_TRACE_INT2(UINT2, UInt2x4); + CASE_PROTO_TRACE_INT2(INT2, Int2x4); #if !defined(DISABLE_FLOAT4_TYPES) CASE_PROTO_TRACE_FLOAT4(FLOAT4E2M1, Float4E2M1x2); @@ -1428,6 +1547,8 @@ Status TensorProtoToTensor(const Env& env, const std::filesystem::path& model_pa #endif CASE_PROTO(INT4, Int4x2); CASE_PROTO(UINT4, UInt4x2); + CASE_PROTO(INT2, Int2x4); + CASE_PROTO(UINT2, UInt2x4); #if !defined(DISABLE_FLOAT4_TYPES) CASE_PROTO(FLOAT4E2M1, Float4E2M1x2); @@ -1513,6 +1634,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { #endif CASE_TYPE(UINT4) CASE_TYPE(INT4) + CASE_TYPE(UINT2) + CASE_TYPE(INT2) #if !defined(DISABLE_FLOAT4_TYPES) CASE_TYPE(FLOAT4E2M1) @@ -1659,117 +1782,140 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) { } #if !defined(DISABLE_SPARSE_TENSORS) -static Status CopySparseData(size_t n_sparse_elements, +static Status CopySparseData(const std::string& name, + int64_t nnz_elements, const ONNX_NAMESPACE::TensorProto& indices, const std::filesystem::path& model_path, - gsl::span - dims, - std::function - copier) { + gsl::span dense_dims, + int64_t dense_elements, + std::function copier) { Status status = Status::OK(); TensorShape indices_shape(indices.dims().data(), indices.dims().size()); - const auto elements = narrow(indices_shape.Size()); + const int64_t indices_elements = indices_shape.Size(); - std::vector indices_values; // used for conversion of smaller size indices + InlinedVector indices_values; // used for conversion of smaller size indices std::vector unpack_buffer; gsl::span indices_data; - const bool has_raw_data = indices.has_raw_data(); + 
const bool needs_unpack = utils::HasRawData(indices) || utils::HasExternalData(indices); switch (indices.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT64: - if (has_raw_data) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int64_t)), - "Sparse Indices raw data size does not match expected."); + if (needs_unpack) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow(indices_elements) * sizeof(int64_t)), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int64_t)); ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); indices_data = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); } else { - ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast(elements), - "Sparse indices int64 data size does not match expected"); - indices_data = gsl::make_span(indices.int64_data().data(), elements); + ORT_RETURN_IF_NOT(indices.int64_data_size() == indices_elements, + "Sparse tensor: ", name, " indices int64 data size does not match expected: ", + indices_elements); + indices_data = gsl::make_span(indices.int64_data().data(), narrow(indices_elements)); } break; case ONNX_NAMESPACE::TensorProto_DataType_INT32: { - if (has_raw_data) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int32_t)), - "Sparse Indices raw data size does not match expected."); + if (needs_unpack) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow(indices_elements) * sizeof(int32_t)), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int32_t)); ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); auto int32_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int32_span.begin(), int32_span.end()); unpack_buffer.clear(); unpack_buffer.shrink_to_fit(); } else { - ORT_RETURN_IF_NOT(indices.int32_data_size() == 
static_cast(elements), - "Sparse indices int32 data size does not match expected"); + ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements, + "Sparse tensor: ", name, " indices int32 data size does not match expected: ", + indices_elements); indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend()); } indices_data = gsl::make_span(indices_values); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT16: { - if (has_raw_data) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int16_t)), - "Sparse Indices raw data size does not match expected."); + if (needs_unpack) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == (narrow(indices_elements) * sizeof(int16_t)), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int16_t)); ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); auto int16_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int16_span.begin(), int16_span.end()); - indices_data = gsl::make_span(indices_values); unpack_buffer.clear(); unpack_buffer.shrink_to_fit(); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, - "Invalid SparseTensor indices. 
INT16 indices must be in the raw data of indices tensor"); + ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements, + "Sparse tensor: ", name, " indices int16 data size does not match expected: ", + indices_elements); + indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend()); } + indices_data = gsl::make_span(indices_values); break; } case ONNX_NAMESPACE::TensorProto_DataType_INT8: { - if (has_raw_data) { - ORT_RETURN_IF_NOT(indices.raw_data().size() == elements, - "Sparse Indices raw data size does not match expected."); + if (needs_unpack) { + ORT_RETURN_IF_NOT(indices.raw_data().size() == narrow(indices_elements), + "Sparse tensor: ", name, " indices raw data size does not match expected: ", + indices_elements * sizeof(int8_t)); ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer)); auto int8_span = ReinterpretAsSpan(gsl::make_span(unpack_buffer)); indices_values.insert(indices_values.cend(), int8_span.begin(), int8_span.end()); - indices_data = gsl::make_span(indices_values); unpack_buffer.clear(); unpack_buffer.shrink_to_fit(); } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, - "Invalid SparseTensor indices. INT8 indices must be in the raw data of indices tensor"); + ORT_RETURN_IF_NOT(indices.int32_data_size() == indices_elements, + "Sparse tensor: ", name, " indices int8 data size does not match expected: ", + indices_elements); + indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend()); } + indices_data = gsl::make_span(indices_values); break; } default: return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_GRAPH, - "Invalid SparseTensor indices. Should one of the following types: int8, int16, int32 or int64"); + "Sparse tensor: ", name, " indices. 
Should be one of the following types: int8, int16, int32 or int64"); } - if (indices_shape.NumDimensions() == 1) { + const auto indices_rank = indices_shape.NumDimensions(); + if (indices_rank == 1) { // flattened indexes - for (size_t i = 0; i < n_sparse_elements; ++i) { - copier(i, narrow(indices_data[i])); + for (size_t i = 0, lim = narrow(nnz_elements); i < lim; ++i) { + const auto idx = indices_data[i]; + ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements, + "Sparse tensor: ", name, " index is out of bounds. Got:", idx, + " expected to be in [0, ", dense_elements, ")"); + + copier(i, narrow(idx)); } - } else if (indices_shape.NumDimensions() == 2) { + } else if (indices_rank == 2) { // entries in format {NNZ, rank} - ORT_ENFORCE(indices_shape[1] > 0 && static_cast(indices_shape[1]) == dims.size()); - auto rank = static_cast(indices_shape[1]); + ORT_ENFORCE(indices_shape[1] > 0 && static_cast(indices_shape[1]) == dense_dims.size()); + const auto rank = static_cast(indices_shape[1]); auto cur_index = indices_data.begin(); - std::vector multipliers; + InlinedVector multipliers; multipliers.resize(rank); // calculate sum of inner dimension elements for each dimension. // e.g. if shape {2,3,4}, the result should be {3*4, 4, 1} multipliers[rank - 1] = 1; for (auto r = rank - 1; r > 0; --r) { - multipliers[r - 1] = SafeInt(dims[r]) * multipliers[r]; + multipliers[r - 1] = SafeInt(dense_dims[r]) * multipliers[r]; } // calculate the offset for the entry // e.g. if shape was {2,3,4} and entry was (1, 0, 2) the offset is 14 // as there are 2 rows, each with 12 entries per row - for (size_t i = 0; i < n_sparse_elements; ++i) { + for (size_t i = 0, lim = narrow(nnz_elements); i < lim; ++i) { SafeInt idx = 0; for (size_t j = 0; j < rank; ++j) { - idx += SafeInt(cur_index[j]) * multipliers[j]; + const auto dim_index = cur_index[j]; + ORT_RETURN_IF_NOT(dim_index >= 0 && dim_index < dense_dims[j], + "Sparse tensor: ", name, " index is out of bounds. 
Got:", dim_index, + " expected to be in [0, ", dense_dims[j], ")"); + idx += SafeInt(dim_index) * multipliers[j]; } + ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements, + "Sparse tensor: ", name, " index is out of bounds. Got:", static_cast(idx), + " expected to be in [0, ", dense_elements, ")"); copier(i, static_cast(idx)); cur_index += rank; @@ -1778,7 +1924,7 @@ static Status CopySparseData(size_t n_sparse_elements, ORT_ENFORCE(cur_index == indices_data.end()); } else { status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, - "Invalid SparseTensor indices. Should be rank 0 or 1. Got:", indices_shape); + "Sparse tensor: ", name, " indices shape. Expected to be rank 1 or 2. Got:", indices_shape); } return status; @@ -1787,53 +1933,110 @@ static Status CopySparseData(size_t n_sparse_elements, common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& dense) { - Status status = Status::OK(); + Status status; const auto& sparse_values = sparse.values(); - auto type = sparse_values.data_type(); - dense.set_data_type(type); - *dense.mutable_name() = sparse_values.name(); + const auto& name = sparse_values.name(); - SafeInt n_sparse_elements = 1; - for (auto dim : sparse_values.dims()) { - n_sparse_elements *= dim; + const auto values_rank = sparse_values.dims_size(); + if (values_rank != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, " values should be rank 1 for COO format. Got:", values_rank); } - SafeInt n_dense_elements = 1; + auto type = sparse_values.data_type(); + dense.set_data_type(type); + *dense.mutable_name() = name; + SafeInt dense_elements = 1; + for (auto dim : sparse.dims()) { - n_dense_elements *= dim; + if (dim < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, " dense dims expected to be non-negative. 
Got:", dim); + } + dense_elements *= dim; dense.add_dims(dim); } + const auto dense_dims = gsl::make_span(dense.dims().data(), dense.dims().size()); + + SafeInt nnz_elements = 1; + for (auto dim : sparse_values.dims()) { + if (dim < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, " tensor dims expected to be non-negative. Got:", dim); + } + nnz_elements *= dim; + } + const auto& indices = sparse.indices(); - auto dims = gsl::make_span(dense.dims().data(), dense.dims().size()); + const auto indices_rank = indices.dims_size(); + if (indices_rank != 1 && indices_rank != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, " indices should be rank 1 or 2 for supported COO format. Got:", indices_rank); + } - if (type != TensorProto_DataType_STRING) { - auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); - size_t element_size = ml_data->Size(); + const auto indices_dims = gsl::make_span(indices.dims().data(), indices.dims().size()); + + if (indices_dims[0] != nnz_elements) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " indices outer dimension should match the number of non-zero values. Got:", + indices_dims[0], " expected: ", static_cast(nnz_elements)); + } - // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data - std::vector sparse_data_storage; - ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, sparse_data_storage)); - void* sparse_data = sparse_data_storage.data(); + if (indices_rank == 2 && dense_dims.size() != narrow(indices_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " indices is rank 2, its inner dimension should match the rank of the dense tensor. 
Got:", + indices_dims[1], " expected: ", dense_dims.size()); + } + + if (indices_rank == 2) { + const auto num_indices = TensorShape(indices_dims).Size(); + const int64_t expected_indices_entries = SafeInt(nnz_elements) * indices_dims[1]; + if (num_indices != expected_indices_entries) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, + "Sparse tensor: ", name, + " indices is rank 2, it should have NNZ values * indices_dims[1] entries. Got:", + num_indices, " expected: ", expected_indices_entries); + } + } + + if (dense_elements == 0) { + // if there are no elements in the dense tensor, we can return early with an empty tensor proto + return status; + } + + if (type != ONNX_NAMESPACE::TensorProto_DataType_STRING) { + auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType(); + const size_t element_size = ml_data->Size(); // by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move // into the TensorProto. - std::string dense_data_storage(n_dense_elements * element_size, 0); - if (n_sparse_elements > 0) { + std::string dense_data_storage(narrow(dense_elements) * element_size, 0); + if (nnz_elements > 0) { + // need to read in sparse data first as it could be in a type specific field, in raw data, or in external data + std::vector values_data; + ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, values_data)); + ORT_RETURN_IF_NOT(values_data.size() == static_cast(nnz_elements) * element_size, + "Sparse tensor: ", name, " values data size does not match expected: ", + static_cast(nnz_elements) * element_size); + void* sparse_data = values_data.data(); void* dense_data = dense_data_storage.data(); switch (element_size) { case 1: { status = CopySparseData( - n_sparse_elements, indices, model_path, dims, [sparse_data, dense_data](size_t from_idx, size_t to_idx) { + name, nnz_elements, indices, model_path, dense_dims, dense_elements, + [sparse_data, dense_data](size_t from_idx, size_t to_idx) { 
static_cast(dense_data)[to_idx] = static_cast(sparse_data)[from_idx]; }); break; } case 2: { - status = CopySparseData(n_sparse_elements, indices, model_path, dims, + status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements, [sparse_data, dense_data](size_t from_idx, size_t to_idx) { const auto* src = static_cast(sparse_data) + from_idx; auto* dst = static_cast(dense_data) + to_idx; @@ -1843,7 +2046,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT break; } case 4: { - status = CopySparseData(n_sparse_elements, indices, model_path, dims, + status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements, [sparse_data, dense_data](size_t from_idx, size_t to_idx) { const auto* src = static_cast(sparse_data) + from_idx; auto* dst = static_cast(dense_data) + to_idx; @@ -1853,7 +2056,7 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT break; } case 8: { - status = CopySparseData(n_sparse_elements, indices, model_path, dims, + status = CopySparseData(name, nnz_elements, indices, model_path, dense_dims, dense_elements, [sparse_data, dense_data](size_t from_idx, size_t to_idx) { const auto* src = static_cast(sparse_data) + from_idx; auto* dst = static_cast(dense_data) + to_idx; @@ -2073,11 +2276,14 @@ template common::Status GetSizeInBytesFromTensorProto<0>(const ONNX_NAMESPACE::T break; \ } -#define CASE_UNPACK_4BIT_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PAIR_FUN) \ +// Sub-byte types (2-bit and 4-bit) are stored in a packed format. +// This unpacking code is shared for INT4, UINT4, FLOAT4E2M1, INT2, and UINT2. +// CALC_PACKED_UNITS_FUN specifies the function to calculate packed byte count from element count. 
+#define CASE_UNPACK_SUBBYTE_TYPE(TYPE, ELEMENT_TYPE, DATA_SIZE, CALC_PACKED_UNITS_FUN) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##TYPE: { \ TensorShape tensor_shape = GetTensorShapeFromTensorProto(initializer); \ size_t element_count = static_cast(tensor_shape.Size()); \ - size_t packed_element_count = ELEMENT_TYPE::CALC_PAIR_FUN(element_count); \ + size_t packed_element_count = ELEMENT_TYPE::CALC_PACKED_UNITS_FUN(element_count); \ unpacked_tensor.resize(packed_element_count * sizeof(ELEMENT_TYPE)); \ return onnxruntime::utils::UnpackTensor(initializer, \ initializer.has_raw_data() ? initializer.raw_data().data() : nullptr, \ @@ -2120,11 +2326,13 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, CASE_UNPACK(FLOAT8E5M2, onnxruntime::Float8E5M2, int32_data_size); CASE_UNPACK(FLOAT8E5M2FNUZ, onnxruntime::Float8E5M2FNUZ, int32_data_size); #endif - CASE_UNPACK_4BIT_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs); - CASE_UNPACK_4BIT_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs); + CASE_UNPACK_SUBBYTE_TYPE(INT4, Int4x2, int32_data_size, CalcNumInt4Pairs); + CASE_UNPACK_SUBBYTE_TYPE(UINT4, UInt4x2, int32_data_size, CalcNumInt4Pairs); + CASE_UNPACK_SUBBYTE_TYPE(INT2, Int2x4, int32_data_size, CalcNumInt2Quads); + CASE_UNPACK_SUBBYTE_TYPE(UINT2, UInt2x4, int32_data_size, CalcNumInt2Quads); #if !defined(DISABLE_FLOAT4_TYPES) - CASE_UNPACK_4BIT_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs); + CASE_UNPACK_SUBBYTE_TYPE(FLOAT4E2M1, Float4E2M1x2, int32_data_size, CalcNumFloat4Pairs); #endif default: diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 8c9f64e9fbb9f..941cd9af34b61 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -249,10 +249,14 @@ common::Status ConstantNodeProtoToTensorProto(const ONNX_NAMESPACE::NodeProto& n void MakeCpuTensorCopy(const Tensor& src_tensor, 
Tensor& dst_tensor); #if !defined(DISABLE_SPARSE_TENSORS) -// Convert a SparseTensorProto to a dense TensorProto -// If the SparseTensorProto contains external data then it loads the data and converts to dense tensor proto -// The resulting TensorProto will contain the data as raw data. -// model_path is used for constructing full path for external_data +/// +// The function supports only COO format with 1D or 2D indices. Values shape is expected to be 1D. +// The function does not support sparse tensors of other formats like CSR/CSC. +/// +/// +/// model path is only used if there are references to external data. +/// The resulting dense tensor proto. +/// Status common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse, const std::filesystem::path& model_path, ONNX_NAMESPACE::TensorProto& dense); @@ -522,16 +526,19 @@ Status TensorProtoWithExternalDataToTensorProto( ONNX_NAMESPACE::TensorProto& new_tensor_proto); /// -/// The functions will make sure the 'location' specified in the external data is under the 'base_dir'. +/// Validates if the external data path is under the model directory. +/// If the model is a symlink, it checks against both the logical model directory (base_dir) +/// and the real/canonical directory of the model. /// If the `base_dir` is empty, the function only ensures that `location` is not an absolute path. 
/// -/// model location directory -/// location is a string retrieved from TensorProto external data that is not -/// an in-memory tag -/// The function will fail if the resolved full path is not under the model directory -/// or one of the subdirectories +/// Logical model location directory +/// Location string retrieved from TensorProto external data +/// Optional path to the model file, used for canonical path validation if base_dir check fails +/// The function will fail if the resolved full path is not under the logical model directory +/// nor the real directory of the model path Status ValidateExternalDataPath(const std::filesystem::path& base_dir, - const std::filesystem::path& location); + const std::filesystem::path& location, + const std::filesystem::path& model_path = {}); #endif // !defined(SHARED_PROVIDER) diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 4b4c483ba1202..5eed13ec1073c 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -227,6 +227,16 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2; +} + +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; +} + #if !defined(DISABLE_FLOAT4_TYPES) template <> constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 1864bfde31d22..d30c7cd74a76a 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -368,7 +368,7 @@ Status EpValueInfo::GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) c producer_info.output_index = 0; if (graph_ == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unable to 
get producer node for OrtValueInfo '", name_, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_FOUND, "Unable to get producer node for OrtValueInfo '", name_, "' that is not owned by a OrtGraph."); } @@ -379,7 +379,15 @@ Status EpValueInfo::GetProducerInfo(OrtValueInfo::ProducerInfo& producer_info) c const EpNode* ep_node = graph_->GetNode(node->Index()); if (ep_node == nullptr) { - return Status::OK(); // Node is not in this GraphViewer + producer_info.node = nullptr; + producer_info.output_index = 0; +#if !defined(ORT_MINIMAL_BUILD) + const auto& logger = graph_->GetGraphViewer().GetGraph().GetLogger(); + LOGS(logger, WARNING) << "Unable to get producer node for OrtValueInfo '" + << name_ + << "' that is not owned by an OrtGraph."; +#endif // !defined(ORT_MINIMAL_BUILD) + return Status::OK(); } size_t output_index = 0; @@ -543,6 +551,9 @@ void EpGraph::IndexToEpNodeMap::Resize(NodeIndex min_node_index, NodeIndex max_n } EpNode* EpGraph::IndexToEpNodeMap::GetEpNode(NodeIndex node_index) const { + if (node_index < min_node_index_ || node_index > (min_node_index_ + nodes_.size() - 1)) { + return nullptr; + } size_t i = node_index - min_node_index_; assert(i < nodes_.size()); return nodes_[i]; @@ -566,10 +577,10 @@ EpGraph::EpGraph(std::unique_ptr graph_viewer, owned_indexed_sub_graph_(std::move(indexed_sub_graph)) {} // Static class function to create a std::unique_ptr. -Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { +Status EpGraph::Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result, bool create_parent_node) { auto ep_graph = std::make_unique(graph_viewer, PrivateTag{}); - return CreateImpl(std::move(ep_graph), graph_viewer, result); + return CreateImpl(std::move(ep_graph), graph_viewer, result, create_parent_node); } // Static class function to create a std::unique_ptr. 
@@ -584,7 +595,8 @@ Status EpGraph::Create(std::unique_ptr src_graph_viewer, return CreateImpl(std::move(ep_graph), graph_viewer, result); } -Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result) { +Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, + /*out*/ std::unique_ptr& result, bool create_parent_node) { AllocatorPtr initializer_allocator = CPUAllocator::DefaultInstance(); std::unordered_map> value_infos_map; @@ -687,6 +699,9 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& } } + std::unique_ptr ep_parent_node = nullptr; + std::unordered_map> parent_node_value_infos_map; + // If this is a subgraph, add the OrtValueInfo and OrtValue objects that come from the outer scope. // Wait until we have already processed OrtValueInfos consumed and produced by nodes so that we only add // outer OrtValueInfo/OrtValue if they are actually used by the nodes in this GraphViewer. @@ -694,6 +709,20 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& gsl::not_null parent_graph = graph_viewer.GetGraph().ParentGraph(); gsl::not_null parent_node = graph_viewer.ParentNode(); + // If the subgraph of a control-flow op is created before its parent node (for example, when constructing + // the graph during ORT's GetCapability() in a bottom-up manner), the parent node must also be created. + if (create_parent_node) { + std::unique_ptr ep_node = nullptr; + + // At this point, the EpGraph that contains the parent node hasn't been created yet. + // It's not needed to create that EpGraph here, so just pass nullptr. + ORT_RETURN_IF_ERROR(EpNode::Create(*parent_node, /*ep_graph*/ nullptr, parent_node_value_infos_map, ep_node)); + + // Note: Calling ep_parent_node.GetGraph() will return nullptr because + // ep_parent_node was created without an associated EpGraph pointer. 
+ ep_parent_node = std::move(ep_node); + } + for (gsl::not_null implicit_node_arg : parent_node->ImplicitInputDefs()) { const std::string& implicit_name = implicit_node_arg->Name(); auto value_info_iter = value_infos_map.find(implicit_name); @@ -741,6 +770,9 @@ Status EpGraph::CreateImpl(std::unique_ptr ep_graph, const GraphViewer& ep_graph->outer_scope_initializer_values_ = std::move(outer_scope_initializer_values); ep_graph->inputs_ = std::move(graph_input_value_infos); ep_graph->outputs_ = std::move(graph_output_value_infos); + ep_graph->parent_node_owned_ = std::move(ep_parent_node); + ep_graph->parent_node_ = ep_graph->parent_node_owned_ ? ep_graph->parent_node_owned_.get() : nullptr; + ep_graph->parent_node_value_infos_map_ = std::move(parent_node_value_infos_map); result = std::move(ep_graph); @@ -873,10 +905,15 @@ Status EpGraph::GetNodes(gsl::span dst) const { Status EpGraph::GetParentNode(const OrtNode*& result) const { result = parent_node_ != nullptr ? parent_node_->ToExternal() : nullptr; + return Status::OK(); } -void EpGraph::SetParentNode(const EpNode* node) { parent_node_ = node; } +void EpGraph::SetParentNode(const EpNode* node) { + parent_node_ = node; + parent_node_owned_ = nullptr; + parent_node_value_infos_map_.clear(); +} const GraphViewer& EpGraph::GetGraphViewer() const { return graph_viewer_; } diff --git a/onnxruntime/core/graph/ep_api_types.h b/onnxruntime/core/graph/ep_api_types.h index e003f02a79a2d..e5747f2b4c2fe 100644 --- a/onnxruntime/core/graph/ep_api_types.h +++ b/onnxruntime/core/graph/ep_api_types.h @@ -191,6 +191,9 @@ struct EpNode : public OrtNode { const char** opt_attribute_names) const override; // Gets this node's parent graph, which is the graph that directly contains this node. 
+ // Note: This call may return NULL if this node is obtained by calling GetParentNode() + // on an EpGraph that is a subgraph of a control-flow op, and the parent graph has not been created yet, + // for example during ORT's GetCapability() when processing the innermost subgraph. Status GetGraph(const OrtGraph*& parent_graph) const override; // @@ -269,8 +272,16 @@ struct EpGraph : public OrtGraph { /// /// /// + /// If the `graph_viewer` is a subgraph of a control flow op, + /// e.g. Loop/If/Scan op, and `create_parent_node` is set to true, + /// then `result` EpGraph will create and own parent node's EpNode + /// instance. It's mainly used in EP's GetCapability() as it's + /// a bottom-up approach where inner-most subgraph will be constructed + /// first and by the time its parent node/graph hasn't be constructed yet. /// - static Status Create(const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + static Status Create(const GraphViewer& graph_viewer, + /*out*/ std::unique_ptr& result, + bool create_parent_node = false); /// /// Creates an instance of EpGraph, which wraps a GraphViewer. @@ -364,17 +375,27 @@ struct EpGraph : public OrtGraph { private: /// /// The real implementation of creating an EpGraph instance. - /// Please use one of the above 'Create' functions that internally call this function, and avoid calling this function directly. + /// Please use one of the above 'Create' functions that internally call this function, + /// and avoid calling this function directly. 
/// /// /// /// + /// /// - static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, /*out*/ std::unique_ptr& result); + static Status CreateImpl(std::unique_ptr ep_graph, const GraphViewer& graph_viewer, + /*out*/ std::unique_ptr& result, bool create_parent_node = false); const GraphViewer& graph_viewer_; + + // Hold the parent node created and owned by this graph + std::unique_ptr parent_node_owned_ = nullptr; + + // Holds either a pointer to a parent node not owned by this graph, a pointer to parent_node_owned_, or nullptr. const EpNode* parent_node_ = nullptr; + std::unordered_map> parent_node_value_infos_map_; + std::unique_ptr owned_graph_viewer_ = nullptr; std::unique_ptr owned_indexed_sub_graph_ = nullptr; diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index dd3eb59b7fafb..c41dc0b288930 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -3771,7 +3771,7 @@ Status Graph::ConvertInitializersIntoOrtValues() { std::unique_ptr external_data_info; ORT_RETURN_IF_ERROR(onnxruntime::ExternalDataInfo::Create(tensor_proto.external_data(), external_data_info)); const auto& location = external_data_info->GetRelPath(); - auto st = utils::ValidateExternalDataPath(model_dir, location); + auto st = utils::ValidateExternalDataPath(model_dir, location, model_path); if (!st.IsOK()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "External data path validation failed for initializer: ", tensor_proto.name(), diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bf88a0556683d..db7ec288001f9 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -840,7 +840,7 @@ enum MLAS_CONV_ALGORITHM { MlasConvAlgorithmGemmDirect, MlasConvAlgorithmExpandThenGemm, MlasConvAlgorithmExpandThenGemmSegmented, -#if defined(MLAS_TARGET_WASM_SCALAR) +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) MlasConvAlgorithmDepthwise, #endif 
}; @@ -1126,6 +1126,16 @@ MlasEltwiseAdd( size_t N ); +template +void +MLASCALL +MlasEltwiseMul( + const T* left, + const T* right, + T* output, + size_t N + ); + template void MLASCALL @@ -1232,7 +1242,8 @@ MlasNchwcConv( float* Output, const MLAS_ACTIVATION* Activation, bool ZeroMode, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + bool UseBf16 = false ); void @@ -1955,6 +1966,7 @@ struct MLAS_SBGEMM_DATA_PARAMS { const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ + bool ZeroMode = true; /**< true: C = A*B, false: C += A*B */ }; /** @@ -2115,14 +2127,3 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); - -#if defined(USE_KLEIDIAI) -/** - * @brief Function to override the packing mechanism decision if kleidi ai is included - * @param enable enable kleidiai packing (allow or disallow depending on true/false) - * @return -*/ -void -MLASCALL -MlasGemmBatchPackUseKleidi(bool enable); -#endif diff --git a/onnxruntime/core/mlas/inc/mlas_q4.h b/onnxruntime/core/mlas/inc/mlas_q4.h index 69f0435615079..d60e5b0164fe8 100644 --- a/onnxruntime/core/mlas/inc/mlas_q4.h +++ b/onnxruntime/core/mlas/inc/mlas_q4.h @@ -57,10 +57,10 @@ MlasQ4GemmPackBSize( * * @param QType type of block quantization * @param PackedBuf destination buffer - * @param FpData the pointer to fp32 matrix - * @param N the number of columns of matrix B. - * @param K the number of rows of matrix B. - * @param ldb leading dimension of B + * @param FpData the pointer to fp32 matrix, with shape [K, N]. + * @param N the number of columns of matrix B (Output Channels). + * @param K the number of rows of matrix B (Input Channels). 
+ * @param ldb leading dimension of FpData (usually N) */ void MLASCALL diff --git a/onnxruntime/core/mlas/inc/mlas_qnbit.h b/onnxruntime/core/mlas/inc/mlas_qnbit.h index fc3c0b6016ced..39df8cf4e9a34 100644 --- a/onnxruntime/core/mlas/inc/mlas_qnbit.h +++ b/onnxruntime/core/mlas/inc/mlas_qnbit.h @@ -27,11 +27,11 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ - HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ - BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ - SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ - HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ } MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** @@ -41,13 +41,13 @@ typedef enum { */ template struct MLAS_QNBIT_GEMM_DATA_PARAMS { - const T* A = nullptr; ///< address of A (float32/16 matrix) - size_t lda = 0; ///< leading dimension of A - const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) - const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block - const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* A = nullptr; ///< address of A (float32/16 matrix) + size_t lda = 0; ///< leading dimension of A + const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) + const std::byte* PackedQuantBData = nullptr; /// 
address of packed quantized B data + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block /// /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively @@ -58,10 +58,10 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { /// This input is to be used only when A is quantized to uint8. /// const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; - - const T* Bias = nullptr; ///< optional address of Bias, vector size N - T* C = nullptr; ///< address of result matrix - size_t ldc = 0; ///< leading dimension of C + + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix + size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; @@ -232,3 +232,124 @@ MlasQNBitGemmScalesPacked( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool HasZeroPoint ); + +/** + * @brief Determines whether the Lut (Lookup Table) GEMM optimization path is available. + * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @return true if Lut GEMM is available for the given parameters + */ +bool MLASCALL +MlasIsLutGemmAvailable( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +); + +/** + * @brief Initializes kernel configuration for Lut GEMM. 
+ * + * @param[in] M row size of output matrix + * @param[in] N column size of matrix B + * @param[in] nbits quantized value bit width + * @param[in] block_size number of quantized values per block + * @param[in] has_zero_point whether zero points are provided + */ +void MLASCALL +MlasInitLutGemmKernelConfig( + size_t M, + size_t N, + size_t nbits, + size_t block_size, + bool has_zero_point +); + +/** + * @brief Clears the cached LUT GEMM kernel configuration. + * Call this when the model dimensions change or to reset state between operations. + * Primarily used in testing scenarios to ensure clean state between test runs. + */ +void MLASCALL +MlasClearLutGemmKernelConfig(); + +/** + * @brief Gets the total size in bytes of the prepacked buffer for Lut GEMM. + * This buffer contains packed quantized B data followed by packed scales and zero points. + * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @return Total size in bytes of the prepacked buffer + */ +size_t MLASCALL +MlasLutGemmPackedSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +); + +/** + * @brief Packs quantized B data and/or scales/zero points into a buffer for Lut GEMM. + * If QuantBScale is nullptr, only packs B data. If QuantBData is nullptr, only packs scales. 
+ * + * @param[in] N column size of matrix B + * @param[in] K row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 2 means 2 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] QuantBData quantized B data (nullptr to skip B packing) + * @param[in] QuantBScale quantized B scales (nullptr to skip scale packing) + * @param[in] QuantBZeroPoint quantized B zero points (nullptr if HasZeroPoint is false) + * @param[out] PackedBuf output buffer (must be at least MlasLutGemmPackedSize bytes) + * @param[in] ThreadPool thread pool for parallel packing + */ +void MLASCALL +MlasLutGemmPack( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + std::byte* PackedBuf, + MLAS_THREADPOOL* ThreadPool +); + +/** + * @brief Executes TMAC compute using Lut (Lookup Table) based GEMM. + * + * This function handles generating the look up tables and accumulating the matmul results. + * Results will be stored in C. 
+ * + * @param[in] A activation matrix + * @param[in] BlkLen number of quantized values per block + * @param[in] PackedBuf packed buffer containing weights and scales/zp (from MlasLutGemmPack) + * @param[out] C output matrix + * @param[in] K inner dimension + * @param[in] M batch size (number of rows in activation) + * @param[in] N column size of matrix B + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] threadpool thread pool for parallel computation + */ +void MLASCALL +MlasLutGemm( + const void* A, + size_t BlkLen, + const void* PackedBuf, + void* C, + size_t K, + size_t M, + size_t N, + bool HasZeroPoint, + MLAS_THREADPOOL* threadpool +); diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index 9518134631f2d..0e5cb12012f7a 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -698,13 +698,13 @@ Return Value: const size_t OutputGroupSize = FilterCount * OutputSize; const size_t FilterGroupSize = FilterCount * K; + const float* input = WorkBlock->Input + BatchGroupStart * InputGroupSize; + float* output = WorkBlock->Output + BatchGroupStart * OutputGroupSize; + for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) { size_t group = bg % GroupCount; - - const float* input = WorkBlock->Input + bg * InputGroupSize; const float* filter = WorkBlock->Filter + group * FilterGroupSize; - float* output = WorkBlock->Output + bg * OutputGroupSize; // // Invoke the non-threaded GEMM directly with the input tensor. 
@@ -726,6 +726,9 @@ Return Value: MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize); + + input += InputGroupSize; + output += OutputGroupSize; } } @@ -805,6 +808,90 @@ Return Value: } } +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) + +void +MlasDepthwiseThreaded( + void* Context, + ptrdiff_t Index +) + +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + convolution operation. + + If using this, the entire convolution operation is parallelized on the + (batch size * group count) parameter and this routine has logic to + perform a specific thread's shard of the entire Convolution operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. + + Index - Supplies the current index of the threaded operation. + +Return Value: + + None. + +--*/ + +{ + + MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; + + const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; + + const size_t GroupCount = Parameters->GroupCount; + const size_t BatchGroupCount = Parameters->BatchCount * GroupCount; + + const size_t TargetThreadCount = WorkBlock->TargetThreadCount; + + const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount; + const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount; + + size_t BatchGroupStart; + size_t BatchGroupEnd; + + if (static_cast(Index) < BatchGroupCountExtra) { + BatchGroupStart = (BatchGroupCountPerThread + 1) * Index; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1; + } else { + BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread; + } + + const size_t FilterCount = Parameters->FilterCount; + const size_t OutputSize = Parameters->OutputSize; + const size_t K = Parameters->K; + + const size_t InputGroupSize = Parameters->InputChannels * 
Parameters->InputSize; + const size_t OutputGroupSize = FilterCount * OutputSize; + const size_t FilterGroupSize = FilterCount * K; + + for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) { + size_t group = bg % GroupCount; + + const float* input = WorkBlock->Input + bg * InputGroupSize; + const float* filter = WorkBlock->Filter + group * FilterGroupSize; + float* output = WorkBlock->Output + bg * OutputGroupSize; + const float* bias = WorkBlock->Bias; + if (bias != nullptr) { + bias += group * FilterCount; + } + + float* WorkingBuffer = WorkBlock->WorkingBuffer; + + MlasConvDepthwiseFloat_CHW(Parameters, input, filter, output, WorkingBuffer); + MlasActivation(Parameters->Activation, output, bias, FilterCount, OutputSize, OutputSize); + } +} + +#endif + inline bool MlasConvTryMultithread( @@ -985,7 +1072,7 @@ Return Value: return; } -#if defined(MLAS_TARGET_WASM_SCALAR) +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) if (Algorithm == MlasConvAlgorithmDepthwise) { // Fill the Working Buffer with Zero for use by the depthwise kernel. @@ -1019,6 +1106,35 @@ Return Value: return; } + +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) + + if (Algorithm == MlasConvAlgorithmDepthwise && ((BatchCount > 1) || (GroupCount > 1))) { + const size_t BatchGroupCount = BatchCount * GroupCount; + + ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (static_cast(TargetThreadCount) >= BatchGroupCount) { + TargetThreadCount = static_cast(BatchGroupCount); + } + + MLAS_CONV_WORK_BLOCK WorkBlock; + + WorkBlock.Parameters = Parameters; + WorkBlock.Input = Input; + WorkBlock.Filter = Filter; + WorkBlock.Bias = Bias; + WorkBlock.WorkingBuffer = WorkingBuffer; + WorkBlock.Output = Output; + WorkBlock.TargetThreadCount = TargetThreadCount; + + MlasExecuteThreaded(MlasDepthwiseThreaded, &WorkBlock, TargetThreadCount, ThreadPool); + + return; + } + +#endif + // // Iterate over each batch and group. 
// @@ -1082,7 +1198,7 @@ Return Value: break; } -#if defined(MLAS_TARGET_WASM_SCALAR) +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) case MlasConvAlgorithmDepthwise: { @@ -1337,17 +1453,26 @@ Return Value: } else { -#if defined(MLAS_TARGET_WASM_SCALAR) +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) - // Scalar direct conv for depthwise convolution. - // Currently only support 3x3 kernel with padding <=1 and dilations = 1. + // Scalar (WASM_SCALAR) / vectorized (ARM64) direct conv for depthwise convolution. + // Currently only support 3x3 kernel with padding <=1 and dilations = 1 + // and on ARM64, it is further restricted to strides = 1. // TODO: support more general depthwise convolution. + // On ARM64, only support stride = 1 for depthwise conv. + #if defined(MLAS_TARGET_ARM64) + bool depthwise_conv_stride_support_check = Parameters->StrideShape[0] == 1 && Parameters->StrideShape[1] == 1; + #else + bool depthwise_conv_stride_support_check = true; + #endif + if (Dimensions == 2 && Parameters->FilterCount == 1 && Parameters->InputChannels == 1 && Parameters->KernelShape[0] == 3 && Parameters->KernelShape[1] == 3 && Parameters->Padding[0] <= 1 && Parameters->Padding[1] <= 1 && Parameters->Padding[2] <= 1 && Parameters->Padding[3] <= 1 + && depthwise_conv_stride_support_check && Parameters->DilationShape[0] == 1 && Parameters->DilationShape[1] == 1) { *WorkingBufferSize = Parameters->InputShape[1] + 2; @@ -1411,8 +1536,8 @@ Return Value: if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) { - size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K, - Parameters->FilterCount * Parameters->OutputSize, + size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K, + Parameters->FilterCount * Parameters->OutputSize, static_cast(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)}); TargetThreadCount = MaximumThreadCount; if (static_cast(TargetThreadCount) >= 
Parameters->BatchCount * Parameters->GroupCount) { diff --git a/onnxruntime/core/mlas/lib/eltwise.cpp b/onnxruntime/core/mlas/lib/eltwise.cpp index f63d71b40bfbb..82457deb811a2 100644 --- a/onnxruntime/core/mlas/lib/eltwise.cpp +++ b/onnxruntime/core/mlas/lib/eltwise.cpp @@ -53,6 +53,38 @@ MlasEltwiseAdd( } } +template <> +void +MLASCALL +MlasEltwiseMul( + const float* left, + const float* right, + float* output, + size_t N +) { + while (N > 0) { + if (N >= 4) { + MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left); + MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right); + + MLAS_FLOAT32X4 ResultVec = MlasMultiplyFloat32x4(LeftVec, RightVec); + + MlasStoreFloat32x4(output, ResultVec); + + left += 4; + right += 4; + output += 4; + N -= 4; + } else { + *output = *left * *right; + + left += 1; + right += 1; + output += 1; + N -= 1; + } + } +} template <> void diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp index 87184bf8bb3cf..06f9d97b872c7 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.cpp @@ -20,6 +20,13 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa.h" #include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#if defined(ENABLE_QMX_KERNELS) +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa.h" +#endif // ENABLE_QMX_KERNELS + const 
kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod = {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, @@ -122,6 +129,60 @@ const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_sme2 = kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa, kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa}; +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_sme = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa, + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme_mopa}; + +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_sme2 = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + 
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa, + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa}; + +#if defined(ENABLE_QMX_KERNELS) +const kai_matmul_clamp_f32_f32p_f32p_ukernel sgemm_gemm_qmx = + {kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_dst_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_get_dst_size_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa, + kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa}; + +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm_qmx = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa, + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa}; +#endif // ENABLE_QMX_KERNELS + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() { if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm; @@ -142,7 +203,17 @@ const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel() { if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { return sgemm_gemm_sme2; } else { +#if defined(ENABLE_QMX_KERNELS) + if (ArmKleidiAI::vendor_name.compare("Qualcomm") == 0) + { + KLEIDIAI_KERNEL_LOG("SGEMM: Using QMX Kernel"); + return sgemm_gemm_qmx; + } else { + return sgemm_gemm_sme; + } +#else return sgemm_gemm_sme; +#endif // ENABLE_QMX_KERNELS } } @@ -153,3 +224,21 @@ const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel() { return sgemm_gemv_sme; } } + +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel& GetKleidiAIQGemmUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()) { + return qgemm_gemm_sme2; + } else { +#if defined(ENABLE_QMX_KERNELS) + if (ArmKleidiAI::vendor_name.compare("Qualcomm") == 0) + { + KLEIDIAI_KERNEL_LOG("QGEMM: Using QMX Kernel"); + return qgemm_gemm_qmx; + } else { + return qgemm_gemm_sme; + } +#else + return qgemm_gemm_sme; +#endif // ENABLE_QMX_KERNELS + } +} diff --git a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h index e69c72329d64b..7bd8959b0b5bd 100644 --- a/onnxruntime/core/mlas/lib/kai_ukernel_interface.h +++ b/onnxruntime/core/mlas/lib/kai_ukernel_interface.h @@ -12,8 +12,12 @@ #include "kai/ukernels/matmul/matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p_interface.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h" + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel(); const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel(); const kai_matmul_clamp_f32_f32p_f32p_ukernel& GetKleidiAISGemmUKernel(); const kai_matmul_clamp_f32_f32_f32p_ukernel& GetKleidiAISGemvUKernel(); + +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel& GetKleidiAIQGemmUKernel(); \ No newline at end of file diff --git a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp index 487e1533f5967..5f9d121232a27 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/convolve_kleidiai.cpp @@ -16,6 +16,10 @@ #include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" #include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" #include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" +#if defined(ENABLE_QMX_KERNELS) +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa.h" +#endif // ENABLE_QMX_KERNELS + // Right-hand-side (weights) cache key struct RhsCacheKey { @@ -391,6 +395,12 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i auto lhs_ptrs = std::shared_ptr(new const void*[lhs_ptrs_k * lhs_ptrs_m], std::default_delete()); + // Initialize all padding entries. For partial tiles (m < m_step), + // the kai LHS packing kernel may still read pointer entries beyond the logically + // filled 'm' positions. Leaving these uninitialized can cause non-deterministic + // reads and corrupt packed LHS data. 
+ auto lhs_ptrs_ = lhs_ptrs.get(); + std::fill(lhs_ptrs_, lhs_ptrs_ + (lhs_ptrs_k * lhs_ptrs_m), reinterpret_cast(&pad_ptr[0])); auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1); auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1); @@ -426,7 +436,6 @@ static std::shared_ptr LhsPtrFill(const size_t ci, const size_t i }; size_t m_{0}; - auto lhs_ptrs_ = lhs_ptrs.get(); for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) { for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) { size_t k_{0}; @@ -456,7 +465,23 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s // figure out how many blocks needed to correctly fill padding padsize = ((ci + padsize - 1) / padsize) * padsize; } - static std::vectorpad_ptr(padsize, 0.f); + + // pad_ptr must be at least 'ci' floats for padding pixels. + // Using a thread_local grow-only buffer to avoid cross-thread interference and ensure sizing is correct. + thread_local std::vector pad_ptr; + const float* old_pad_ptr = pad_ptr.data(); + bool has_pad_ptr_changed = false; + + if (pad_ptr.size() < padsize) { + pad_ptr.resize(padsize, 0.f); + if (pad_ptr.data() != old_pad_ptr) { + has_pad_ptr_changed = true; + } + } else { + // Ensure any previously-used region remains zeroed (grow-only means it should already be zeros, + // but keep this explicit for safety). + std::fill(pad_ptr.begin(), pad_ptr.end(), 0.f); + } LhsCacheKey key = { ci, ih, iw, @@ -477,6 +502,16 @@ static std::unique_ptr LhsPackImageDataSme(const size_t ci, const s // Cache of computed lhs ptr offsets. thread_local to prevent interference from parallel sessions. thread_local std::unordered_map> lhs_ptrs_cache; + if (has_pad_ptr_changed) + { + // If the pad buffer was resized and a re-allocation has occurred, the cached lhs ptrs are invalid as they + // would be referencing the old pad buffer. + // See discussion in https://github.com/microsoft/onnxruntime/pull/27214. 
+ // TODO(hasesh / JonathanC-ARM): A better approach would be to include the pad buffer address in the cache key + // or any other approach that would reduce unnecessary cache invalidations. + lhs_ptrs_cache.clear(); + } + std::shared_ptr lhs_ptrs; if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) { lhs_ptrs = found->second; @@ -596,11 +631,29 @@ static void ConvolveSme(const size_t co, //channels out -std::numeric_limits::max(), std::numeric_limits::max() ); } else { - KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); - kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( - TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() - ); + #if defined(ENABLE_QMX_KERNELS) + if (ArmKleidiAI::vendor_name.compare("Qualcomm") == 0) + { + KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_qmx_mopa( + TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } + else { + KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( + TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } + #else + KLEIDIAI_KERNEL_LOG("kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa" << " M=" << TileSizeM << " N=" << TileSizeN << " 
k_chunk_count=" << (d_kh * d_kw) << " k_chunk_length=" << ci); + kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa( + TileSizeM, TileSizeN, d_kh * d_kw, ci, ATile, BTile, CTile, co * sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + #endif // ENABLE_QMX_KERNELS } }); diff --git a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h index ca81b9fa426ee..4c088e8660874 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h +++ b/onnxruntime/core/mlas/lib/kleidiai/mlasi_kleidiai.h @@ -53,6 +53,8 @@ namespace ArmKleidiAI { // By default we should try for SME2 first before falling back to SME. inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2(); +inline const bool UseSME = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME(); +inline const std::string_view vendor_name = MLAS_CPUIDINFO::GetCPUIDInfo().GetCPUVendor(); // Buffer packing routines. // @@ -105,14 +107,14 @@ MlasGemmBatch( size_t MLASCALL -MlasDynamicQgemmPackBSize( +MlasDynamicQGemmPackBSize( size_t N, size_t K ); void MLASCALL -MlasDynamicQgemmPackB( +MlasDynamicQGemmPackB( size_t N, size_t K, const int8_t* B, diff --git a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp index 1d682b372e2f5..b6a23735bd131 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/qgemm_kleidiai.cpp @@ -10,25 +10,40 @@ #include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" #include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" -#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" +#include "kai_ukernel_interface.h" +#if defined(ENABLE_QMX_KERNELS) +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_qmx_mopa.h" +#endif // ENABLE_QMX_KERNELS #include "mlasi_kleidiai.h" -//Matmul with float output of dynamic quantized A and symmetric quantized B. +// Thread-local reusable buffers to reduce allocation overhead across tiles. +struct KaiTlsBuffersQgemm { + std::vector lhs_packed; + std::vector lhs_base_table; +}; +static thread_local KaiTlsBuffersQgemm g_kai_tls_qgemm; + +const kai_matmul_clamp_f32_qai8dxp_qsi8cxp_ukernel qgemm_gemm = GetKleidiAIQGemmUKernel(); + +// Matmul with float output of dynamic-quantized A and symmetric-quantized B. size_t MLASCALL -ArmKleidiAI::MlasDynamicQgemmPackBSize( +ArmKleidiAI::MlasDynamicQGemmPackBSize( size_t N, size_t K ) { - //Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use - auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + // Degenerate shapes: there is nothing to pack. + if (N == 0 || K == 0) { + return 0; + } - //regardless of kernel variant use neon packing variant + auto nr = qgemm_gemm.get_nr(); + auto kr = qgemm_gemm.get_kr(); + auto sr = qgemm_gemm.get_sr(); + + // Regardless of kernel variant, use the NEON packing variant. 
KLEIDIAI_KERNEL_LOG("kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon Groups=1" << " N="<< N << " K=" << K << " nr=" << nr << " kr=" << kr << " sr=" << sr); return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); @@ -36,7 +51,7 @@ ArmKleidiAI::MlasDynamicQgemmPackBSize( void MLASCALL -ArmKleidiAI::MlasDynamicQgemmPackB( +ArmKleidiAI::MlasDynamicQGemmPackB( size_t N, size_t K, const int8_t* B, @@ -44,10 +59,14 @@ ArmKleidiAI::MlasDynamicQgemmPackB( const float* Bias, void* PackedB ) { - // Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use - auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + // Degenerate shapes: nothing to pack. Avoid calling into packers that may not tolerate K==0. + if (N == 0 || K == 0) { + return; + } + + auto nr = qgemm_gemm.get_nr(); + auto kr = qgemm_gemm.get_kr(); + auto sr = qgemm_gemm.get_sr(); // y - float output // scale_factor_lhs - lhs scaling factor @@ -57,9 +76,9 @@ ArmKleidiAI::MlasDynamicQgemmPackB( // lhs_zp - lhs zero point // y = (1/(scale_factor_lhs * scale_factor_rhs) * sum( (lhs_q + lhs_zp)*rhs_q )) + bias - // rhs packing requires lhs_zp because it will perform lhs_zp*rhs_q during rhs packing - // because lhs quantization is hidden from us, by lhs quant packing, we don't have a value for lhs_zp it is - // lhs dynamic quantization + // RHS packing requires lhs_zp because it will perform lhs_zp*rhs_q during RHS packing. + // Because LHS quantization is hidden from us by LHS quant packing, we don't have a value for lhs_zp. + // LHS uses dynamic quantization. 
kai_rhs_pack_qsi8cx_params params{ 1, // lhs_zp - set to 1 so it becomes sum((lhs_q + 1)*rhs_q )), @@ -67,7 +86,7 @@ ArmKleidiAI::MlasDynamicQgemmPackB( 1.f // it is not used }; - //regardless of kernel variant use neon packing variant + // Regardless of kernel variant, use the NEON packing variant. kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(1, N, K, nr, kr, sr, B, // N bias values Bias, @@ -80,42 +99,142 @@ MLASCALL ArmKleidiAI::MlasDynamicQGemmBatch( const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, - const size_t BatchN, + const size_t BatchSize, MLAS_THREADPOOL* ThreadPool ) { - for (auto b = BatchN; b > 0; --b,++DataParams) { - auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); - auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + const size_t mr = qgemm_gemm.get_mr(); + const size_t kr = qgemm_gemm.get_kr(); + const size_t sr = qgemm_gemm.get_sr(); - //TODO enable multi-threading for lhs packing and matmul - MLAS_UNREFERENCED_PARAMETER(ThreadPool); + size_t m_step = qgemm_gemm.get_m_step(); + size_t n_step = qgemm_gemm.get_n_step(); - //Dynamic Quantize A - lhs - auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); - std::byte* lhs = nullptr; - std::unique_ptr fallback; + if (BatchSize == 0 || Shape.M == 0 || Shape.N == 0 || Shape.K == 0) { + return; + } - if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { - lhs = static_cast(DataParams->Workspace); - } else { - fallback = std::make_unique(lhs_size); - lhs = fallback.get(); + // We are required to fail fast when we reach this stage as we will not be able + // to reverse the packing decision that was made for RHS. 
+ + if (DataParams == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires valid DataParams."); + } + + for (size_t batch_idx = 0; batch_idx < BatchSize; ++batch_idx) { + const auto& params = DataParams[batch_idx]; + + if (params.A == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires non-null A pointer."); + } + if (params.C == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires non-null C pointer."); + } + if (params.PackedB == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires non-null PackedB pointer."); } + const size_t lda = params.lda != 0 ? params.lda : Shape.K; + const size_t ldc = params.ldc != 0 ? params.ldc : Shape.N; + + if (lda < Shape.K) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires lda >= K."); + } + if (ldc < Shape.N) { + MLAS_THROW_EX(std::runtime_error, "Dynamic QGEMM requires ldc >= N."); + } + } + + // Dynamic-quantize A (LHS). + const size_t LhsPackedStride = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* LhsPackedData = nullptr; + + if (g_kai_tls_qgemm.lhs_packed.capacity() < LhsPackedStride * BatchSize) { + + g_kai_tls_qgemm.lhs_packed.reserve(LhsPackedStride * BatchSize); + } + g_kai_tls_qgemm.lhs_packed.resize(LhsPackedStride * BatchSize); + LhsPackedData = g_kai_tls_qgemm.lhs_packed.data(); + + // Per-batch table of LHS base pointers. + if (g_kai_tls_qgemm.lhs_base_table.capacity() < BatchSize) { + + g_kai_tls_qgemm.lhs_base_table.reserve(BatchSize); + } + g_kai_tls_qgemm.lhs_base_table.resize(BatchSize); + // Capture the shared batch table pointer so worker threads use the same backing storage. + const std::byte** tls_lhs_base = g_kai_tls_qgemm.lhs_base_table.data(); + // B batches require no packing. + // We have already decided the matmul variant we are using before having values for M, N, and K. 
+ MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) { + + std::byte* lhs = nullptr; + if (DataParams[batch_idx].Workspace && DataParams[batch_idx].WorkspaceSize >= LhsPackedStride) { + lhs = static_cast(DataParams[batch_idx].Workspace); + } else { + lhs = &(LhsPackedData[LhsPackedStride * batch_idx]); + } KLEIDIAI_KERNEL_LOG("kai_run_lhs_quant_pack_qai8dxp_f32" << " M="<< Shape.M << " K=" << Shape.K << " mr=" << mr << " kr=" << kr << " sr=" << sr << " m_idx_start=0"); - kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, - Shape.K*sizeof(float), lhs); - - KLEIDIAI_KERNEL_LOG("kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa"); - kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( - Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, - DataParams->C, - Shape.N * sizeof(float), - sizeof(float), - -std::numeric_limits::max(), std::numeric_limits::max() + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams[batch_idx].A, DataParams[batch_idx].lda*sizeof(float), lhs); + tls_lhs_base[batch_idx] = lhs; + }); + + // Tile iteration dimensions. + std::array dim; + dim[0] = BatchSize; // B + dim[1] = MlasDivRoundup(Shape.M, m_step); // M + dim[2] = MlasDivRoundup(Shape.N, n_step); // N + + // Minimize the kernel call count for the number of available threads. + auto RequiredTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]); + + // Scale required tiles over available tile processors. + dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]); + dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]); + + // Compute new step sizes. + m_step *= MlasDivRoundup(MlasDivRoundup(Shape.M, dim[1]), m_step); + n_step *= MlasDivRoundup(MlasDivRoundup(Shape.N, dim[2]), n_step); + + // Update tile iterations. 
+ dim[1] = MlasDivRoundup(Shape.M, m_step); + dim[2] = MlasDivRoundup(Shape.N, n_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) { + + // Compute B, M, N indices from the iteration index. + ptrdiff_t BIdx = tid / (dim[1] * dim[2]); + ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2]; + ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2]; + + // Get rhs tile, B + const size_t rhs_packed_offset = qgemm_gemm.get_rhs_packed_offset(NIdx * n_step, Shape.K); + + const std::byte* B_base = reinterpret_cast(DataParams[BIdx].PackedB); + auto BTile = reinterpret_cast(B_base + rhs_packed_offset); + + // Get lhs tile, A + const size_t lhs_packed_offset =qgemm_gemm.get_lhs_packed_offset(MIdx * m_step, Shape.K); + + const std::byte* A_base = tls_lhs_base[BIdx]; // LhsPackedData + LhsPackedStride * BIdx; OR DataParams[batch_idx].Workspace; + auto ATile = reinterpret_cast(A_base + lhs_packed_offset); + + auto TileSizeM = (MIdx + 1) * m_step > Shape.M ? (Shape.M - MIdx * m_step) : m_step; + auto TileSizeN = (NIdx + 1) * n_step > Shape.N ? 
(Shape.N - NIdx * n_step) : n_step; + + float* dst_tile = reinterpret_cast( + reinterpret_cast(DataParams[BIdx].C) + + MIdx * m_step * DataParams[BIdx].ldc * sizeof(float) + + NIdx * n_step * sizeof(float) ); - } + + qgemm_gemm.run_matmul( + TileSizeM, TileSizeN, Shape.K, ATile, BTile, + dst_tile, + DataParams[BIdx].ldc * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + }); } diff --git a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp index 250b5d076475d..618d52c7af661 100644 --- a/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp +++ b/onnxruntime/core/mlas/lib/kleidiai/sgemm_kleidiai.cpp @@ -16,6 +16,9 @@ #include "mlasi_kleidiai.h" #include "kai_ukernel_interface.h" +#if defined(ENABLE_QMX_KERNELS) +#include "kai/ukernels/matmul/matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_qmx_mopa.h" +#endif // ENABLE_QMX_KERNELS // Thread-local reusable buffers to reduce allocation overhead across tiles. struct KaiTlsBuffers { @@ -145,9 +148,9 @@ ArmKleidiAI::MlasGemvBatch( if (M != 1 && N != 1) { return false; } - + const bool m_path = (M == 1); - + // We cannot support cases where N == 1 and B is already packed. // When both are 1, we route through the M-path, so this naturally doesn't trigger. if (!m_path && Data->BIsPacked) { @@ -165,15 +168,15 @@ ArmKleidiAI::MlasGemvBatch( // - M-path: LHS is A, stride = lda // - N-path: LHS is B, stride = ldb size_t lhs_ld = m_path ? Data[b].lda : Data[b].ldb; - + const float* rhs_base = m_path ? static_cast(Data[b].B) : static_cast(Data[b].A); - const float* lhs_base = m_path ? static_cast(Data[b].A) + const float* lhs_base = m_path ? 
static_cast(Data[b].A) : static_cast(Data[b].B); // Prepare packed RHS if needed const void* rhs_packed_ptr = nullptr; - + // The if branch can only be taken in cases where we are dealing with M == 1 // We previously reject any prepacked B where N == 1 // In cases where N == 1 we Pack A Matrix as the RHS using tb = CBlasTrans diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index ad62cccbfb9c7..ac7528853c596 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -851,17 +851,9 @@ bool MLAS_THREADPOOL* ThreadPool ); -typedef void (MLASCALL MLAS_GEMM_BATCH)( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t M, - size_t N, - size_t K, - const MLAS_SGEMM_DATA_PARAMS* Data, - size_t BatchSize, - MLAS_THREADPOOL* ThreadPool); - -typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( +typedef +bool +(MLASCALL MLAS_SGEMM_BATCH_OVERRIDE)( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t M, @@ -871,19 +863,17 @@ typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( size_t BatchSize, MLAS_THREADPOOL* ThreadPool); -typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, - size_t N, - size_t K); - -typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)( +typedef +size_t +(MLASCALL MLAS_SGEMM_PACK_B_SIZE_OVERRIDE)( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K); -typedef void (MLASCALL MLAS_GEMM_PACK_B)( +typedef +bool +(MLASCALL MLAS_SGEMM_PACK_B_OVERRIDE)( CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, @@ -892,13 +882,28 @@ typedef void (MLASCALL MLAS_GEMM_PACK_B)( size_t ldb, void* PackedB); -typedef bool (MLASCALL MLAS_GEMM_PACK_B_OVERRIDE)( - CBLAS_TRANSPOSE TransA, - CBLAS_TRANSPOSE TransB, +typedef +void +(MLASCALL MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE)( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool); + +typedef 
+size_t +(MLASCALL MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE)( + size_t N, + size_t K); + +typedef +void +(MLASCALL MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE)( size_t N, size_t K, - const float* B, - size_t ldb, + const int8_t* B, + const float* Scales, + const float* Bias, void* PackedB); extern "C" { @@ -967,6 +972,9 @@ extern "C" { MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; +#if defined(__aarch64__) && defined(__linux__) + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseBf16KernelNeon; +#endif MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; @@ -1238,6 +1246,10 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; +struct MLAS_QNBIT_LUT_GEMM_DISPATCH; + +extern const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2; + // // Rotary embedding dispatch structure. 
// @@ -1326,10 +1338,15 @@ struct MLAS_PLATFORM { bool Avx512Supported_ = false; bool ArmNeonIsQuantActivationsUnsigned = false; - // Mlas overrides initialisation - MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; - MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; - MLAS_GEMM_PACK_B_OVERRIDE* MlasGemmPackBOverride = nullptr; + // MLAS SGemm overrides + MLAS_SGEMM_BATCH_OVERRIDE* MlasSGemmBatchOverride = nullptr; + MLAS_SGEMM_PACK_B_SIZE_OVERRIDE* MlasSGemmPackBSizeOverride = nullptr; + MLAS_SGEMM_PACK_B_OVERRIDE* MlasSGemmPackBOverride = nullptr; + // MLAS Dynamic QGemm overrides + MLAS_DYNAMIC_QGEMM_BATCH_OVERRIDE* MlasDynamicQGemmBatchOverride = nullptr; + MLAS_DYNAMIC_QGEMM_PACK_B_SIZE_OVERRIDE* MlasDynamicQGemmPackBSizeOverride = nullptr; + MLAS_DYNAMIC_QGEMM_PACK_B_OVERRIDE* MlasDynamicQGemmPackBOverride = nullptr; + // MLAS Conv overrides MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; @@ -1368,6 +1385,9 @@ struct MLAS_PLATFORM { MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; +#if defined(__aarch64__) && defined(__linux__) + MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseBf16Kernel; +#endif MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; uint32_t NchwcBlockSize; #endif @@ -1443,6 +1463,7 @@ struct MLAS_PLATFORM { const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_LUT_GEMM_DISPATCH* LutGenKernel{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; @@ -1601,7 +1622,8 @@ MlasFp32FromBits( #pragma warning(pop) #endif -#if defined(MLAS_TARGET_WASM_SCALAR) +#if defined(MLAS_TARGET_WASM_SCALAR) || defined(MLAS_TARGET_ARM64) + void MLASCALL diff --git a/onnxruntime/core/mlas/lib/platform.cpp 
b/onnxruntime/core/mlas/lib/platform.cpp index f30c49220b7ef..6ebd6be068b12 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -421,6 +421,8 @@ Return Value: this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; this->RopeDispatch = &MlasRopeDispatchAvx2; + // TODO(vraspar): check if this really goes here or if there are other platform reqs that we need to fulfill + this->LutGenKernel = &MlasLutGenKernelAvx2; // // Check if the processor supports Hybrid core architecture. @@ -573,6 +575,9 @@ Return Value: this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; +#if defined(__aarch64__) && defined(__linux__) + this->ConvPointwiseBf16Kernel = MlasConvPointwiseBf16KernelNeon; +#endif this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; @@ -605,9 +610,12 @@ Return Value: #if defined(USE_KLEIDIAI) if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ - this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; - this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; - this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasSGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; + this->MlasSGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; + this->MlasSGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasDynamicQGemmBatchOverride = ArmKleidiAI::MlasDynamicQGemmBatch; + this->MlasDynamicQGemmPackBSizeOverride = ArmKleidiAI::MlasDynamicQGemmPackBSize; + this->MlasDynamicQGemmPackBOverride = ArmKleidiAI::MlasDynamicQGemmPackB; this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; this->MlasConvOverride = ArmKleidiAI::MlasConv; } diff --git 
a/onnxruntime/core/mlas/lib/q4_dq.cpp b/onnxruntime/core/mlas/lib/q4_dq.cpp index c543770ee22d8..fbbf4005ae4a5 100644 --- a/onnxruntime/core/mlas/lib/q4_dq.cpp +++ b/onnxruntime/core/mlas/lib/q4_dq.cpp @@ -545,7 +545,7 @@ struct BlockwiseQuantizer { } } - for (int32_t j = c; j < c_end; ++j) { + for (int32_t j = c; j < c_end; ++j) { // this does not work if j runs more then 1 because zp_bytes is indexed by i. const int32_t meta_c = j / QuantBlk::kColumn; for (int32_t i = r; i < r_end; i += kPackSize) { for (int l = 0; l < kPackSize && i + l < r_end; l++) { @@ -656,19 +656,35 @@ struct BlockwiseQuantizer { * @tparam signed_quant quantized type is signed */ template -struct BlockwiseQDQQuantizer; - -template -struct BlockwiseQDQQuantizer { +struct BlockwiseQDQQuantizer { static MLAS_FORCEINLINE uint8_t GetElem(uint8_t val, int32_t idx) { - return (val >> (idx << 2)) & 0xF; + if constexpr (qbits == 2) { + return (val >> (idx << 1)) & 0x3; + } else if constexpr (qbits == 4) { + return (val >> (idx << 2)) & 0xF; + } } static MLAS_FORCEINLINE uint8_t SetElem(uint8_t val, int32_t idx, uint8_t dst) { - auto shift = idx << 2; - return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + if constexpr (qbits == 2) { + auto shift = idx << 1; + return ((val & 0x3) << shift) | (dst & (~(0x3 << shift))); + } else if constexpr (qbits == 4) { + auto shift = idx << 2; + return ((val & 0xF) << shift) | (dst & (~(0xF << shift))); + } + } + + template + static MLAS_FORCEINLINE uint8_t Pack(uint8_t v0, uint8_t v1, uint8_t v2, uint8_t v3) + { + if constexpr (add2) { + return ((v0 & 0x3) ^ 2) | (((v1 & 0x3) ^ 2) << 2) | (((v2 & 0x3) ^ 2) << 4) | (((v3 & 0x3) ^ 2) << 6); + } else { + return (v0 & 0x3) | ((v1 & 0x3) << 2) | ((v2 & 0x3) << 4) | ((v3 & 0x3) << 6); + } } template @@ -1436,8 +1452,7 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); -template -void +template void MlasBlockwiseQuantMetaShape( int block_size, bool columnwise, @@ -1445,7 +1460,7 @@ MlasBlockwiseQuantMetaShape( int 
columns, int& meta_rows, int& meta_cols - ); +); template void @@ -1513,8 +1528,7 @@ MlasBlockwiseQuantizedShape( int& q_cols ); -template -void +template void MlasBlockwiseQuantizedShape( int block_size, bool columnwise, @@ -1524,7 +1538,7 @@ MlasBlockwiseQuantizedShape( int& q_cols ); - template +template void MlasBlockwiseQuantizedShape( int block_size, @@ -2016,6 +2030,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const float* src, + float* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( const MLAS_FP16* src, @@ -2029,6 +2056,19 @@ MlasQDQQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template bool +MlasQDQQuantizeBlockwise( + const MLAS_FP16* src, + MLAS_FP16* scales, + uint8_t* zero_points, + uint8_t* dst, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( @@ -2055,6 +2095,36 @@ MlasQDQTransposeBlockwiseQuantized( } } +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasQDQTransposeBlockwiseQuantized( + const uint8_t* src_weights, + const float* src_scales, + const uint8_t* src_zero_points, + uint8_t* dst_weights, + float* dst_scales, + uint8_t* dst_zero_points, + bool columnwise, + int rows, + int columns, + int quant_block_size, + MLAS_THREADPOOL* thread_pool +); + template void MlasQDQTransposeBlockwiseQuantized( const uint8_t* src_weights, diff --git a/onnxruntime/core/mlas/lib/qgemm.cpp b/onnxruntime/core/mlas/lib/qgemm.cpp index 
cb5bd09daeb87..5f6a8f8394470 100644 --- a/onnxruntime/core/mlas/lib/qgemm.cpp +++ b/onnxruntime/core/mlas/lib/qgemm.cpp @@ -206,7 +206,7 @@ MLASCALL MlasIsDynamicQGemmAvailable() { #if defined(USE_KLEIDIAI) - return ArmKleidiAI::UseSME2; + return (ArmKleidiAI::UseSME2 || ArmKleidiAI::UseSME); #else return false; #endif @@ -224,7 +224,9 @@ MlasDynamicQGemmBatch ( #if defined(USE_KLEIDIAI) //No fallback - ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + if (GetMlasPlatform().MlasDynamicQGemmBatchOverride != nullptr) { + GetMlasPlatform().MlasDynamicQGemmBatchOverride(Shape, DataParams, BatchN, ThreadPool); + } #endif MLAS_UNREFERENCED_PARAMETER(Shape); @@ -348,8 +350,9 @@ MlasDynamicQgemmPackBSize( size_t bytes = 0; #if defined(USE_KLEIDIAI) //No fallback available - //TODO: Insert Override - bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); + if (GetMlasPlatform().MlasDynamicQGemmPackBSizeOverride != nullptr) { + bytes = GetMlasPlatform().MlasDynamicQGemmPackBSizeOverride(N, K); + } #endif MLAS_UNREFERENCED_PARAMETER(N); @@ -442,7 +445,9 @@ MlasDynamicQgemmPackB( #if defined(USE_KLEIDIAI) //No fallback - ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); + if (GetMlasPlatform().MlasDynamicQGemmPackBOverride != nullptr) { + GetMlasPlatform().MlasDynamicQGemmPackBOverride(N, K, B, Scales, Bias, PackedB); + } #endif MLAS_UNREFERENCED_PARAMETER(N); diff --git a/onnxruntime/core/mlas/lib/qlutgemm.cpp b/onnxruntime/core/mlas/lib/qlutgemm.cpp new file mode 100644 index 0000000000000..32c72342b4803 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qlutgemm.cpp @@ -0,0 +1,718 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qlutgemm.cpp + +Abstract: + + This module implements kernel functions for generating lookup tables (LUT) + and computing matrix multiplication for the T-MAC GEMM optimization strategy. 
+ + It provides functionality to pack quantized weight data, compute LUT scales + and biases, and perform efficient quantized GEMM operations using lookup + table based computation. + +--*/ +#include "qlutgemm.h" + +#include +#include +#include +#include +#include +#include +#include + +/** + * Global cache for T-MAC kernel parameters, indexed by configuration. + * This map and its associated mutex ensure thread-safe parameter management + * across concurrent MLAS calls. + */ +static std::unordered_map tmac_kernel_configs; +static std::mutex tmac_kernel_configs_mutex; + +static std::string +GetTmacKey(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) +{ + // Generate a unique cache key based on the GEMM and quantization configuration. + return std::to_string(M) + "_" + std::to_string(N) + "_" + std::to_string(nbits) + "_" + + std::to_string(block_size) + "_" + (has_zero_point ? "1" : "0"); +} + +MlasTMACKernelParams +MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) +{ + std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point); + std::lock_guard lock(tmac_kernel_configs_mutex); + auto it = tmac_kernel_configs.find(key); + if (it != tmac_kernel_configs.end()) { + return it->second; + } + MLAS_THROW_EX(std::runtime_error, "T-MAC kernel parameters not initialized for key: " + key); +} + +void MLASCALL +MlasClearLutGemmKernelConfig() +{ + std::lock_guard lock(tmac_kernel_configs_mutex); + tmac_kernel_configs.clear(); +} + +void MLASCALL +MlasInitLutGemmKernelConfig(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point) +{ + std::string key = GetTmacKey(M, N, nbits, block_size, has_zero_point); + { + std::lock_guard lock(tmac_kernel_configs_mutex); + if (tmac_kernel_configs.find(key) != tmac_kernel_configs.end()) { + return; + } + } + + MlasTMACKernelParams params; + params.g = 4; + params.ngroups_per_elem = 8 / params.g; + params.simd_n_in = 16; + 
params.simd_n_out = 8; + params.chunk_n = 8; + + params.bits = nbits; + params.q_group_size = block_size; + + if (block_size % 64 == 0) { + params.act_group_size = 64; + } else if (block_size % 32 == 0) { + params.act_group_size = 32; + } else { + // throw error + MLAS_THROW_EX(std::runtime_error, "Unsupported activation group size"); + } + params.actk = params.act_group_size / params.g; + + // search space + std::vector bms; + if (nbits == 1 || nbits == 2 || nbits == 4) { + bms = {256, 512, 1024, 2048, 320, 640, 1280}; + } else if (nbits == 3) { + bms = {192, 384, 576, 758}; + } + + std::vector kfactors = {8, 16}; + + // TODO(vraspar): add profile based policy + size_t threads = static_cast(std::thread::hardware_concurrency()); + + float smallest_penalty = 1e9f; + params.bm = bms[0]; + for (size_t bm : bms) { + if (M % (bm / nbits) != 0 || bm % nbits != 0) { + continue; + } + size_t num_tiles = M / (bm / nbits); + size_t num_groups = (num_tiles + threads - 1) / threads; + float penalty = 0.1f * static_cast(num_groups) + + (static_cast(num_groups) - 1.0f * static_cast(num_tiles) / static_cast(threads)) / + static_cast(num_groups); + if (penalty < smallest_penalty) { + smallest_penalty = penalty; + params.bm = bm; + } + } + + size_t largest_kfactor = 0; + params.kfactor = kfactors[0]; + for (size_t kfactor : kfactors) { + if ((kfactor < params.actk) || (kfactor * params.g > params.q_group_size)) { + continue; + } + if (kfactor > largest_kfactor) { + largest_kfactor = kfactor; + params.kfactor = kfactor; + } + } + + params.n_tiles_num = M * params.bits / params.bm; + params.has_scale = true; // TODO(vraspar): TMAC supports only scale for now + params.has_zero_point = has_zero_point; + params.one_scale = false; // TODO(vraspar): support one scale case for bitnet + + { + std::lock_guard lock(tmac_kernel_configs_mutex); + tmac_kernel_configs[key] = params; + } + return; +} + +// Internal helper: calculates packed quantized B data size +static size_t 
+LutGemmPackQuantBDataSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t PackedQuantBDataSize = (N * BlkBitWidth) * (K / tmac_params.g / tmac_params.ngroups_per_elem); + return PackedQuantBDataSize; +} + +// Internal helper: packs quantized B data +static void +LutGemmPackQuantBData( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + // decompose W into w1,... w_bits create temp buffer buf2 of size N * bits * (K/g) + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t bits = tmac_params.bits; + const size_t g = tmac_params.g; + const size_t ngroups_per_elem = tmac_params.ngroups_per_elem; + const size_t simd_n_in = tmac_params.simd_n_in; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t kfactor = tmac_params.kfactor; + + assert(BlkLen % g == 0); + assert((BlkLen / g) % kfactor == 0); + + const size_t mgroup = ngroups_per_elem * simd_n_in; // 32 + assert(bm % mgroup == 0); + assert(bm % bits == 0); + + std::unique_ptr buf(new uint8_t[N * bits * (K / g)]); + memset(buf.get(), 0, N * bits * (K / g)); + + const size_t Iterations = N; // we parallelize over N, TODO:: tune if needed + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + size_t im = static_cast(tid); + for (size_t ik = 0; ik < K; ++ik) { + size_t idx = (im * K + ik); + size_t num_elem_per_byte = 8 / bits; + size_t elem_idx = idx % num_elem_per_byte; + + uint8_t v = ((const uint8_t*)QuantBDataBegin)[idx / num_elem_per_byte] >> (elem_idx * bits); + + for (size_t ib = 0; ib < bits; ++ib) { + size_t new_ik = ik / g; + size_t shft_left = ik % g; + buf[im 
* bits * K / g + ib * K / g + new_ik] += static_cast(((v >> ib) & 1) << shft_left); + } + } + } + ); + + // Now buf contains the bit planes grouped by g along K + // Next, we need to do a multi-reshape/transpose into the final layout + + const size_t c0_fac2 = K / g; + const size_t c0_fac1 = simd_n_out * c0_fac2; + const size_t c0_fac0 = bits * c0_fac1; + + const size_t c1_nb2 = K / g; + const size_t c1_nb1 = simd_n_in * c1_nb2; + const size_t c1_nb0 = ngroups_per_elem * c1_nb1; + const size_t c1_fac2 = K / g; + const size_t c1_fac1 = ngroups_per_elem * c1_fac2; + const size_t c1_fac0 = simd_n_in * c1_fac1; + + const size_t c2_nb4 = kfactor; + const size_t c2_nb3 = K / g / kfactor * c2_nb4; + const size_t c2_nb2 = ngroups_per_elem * c2_nb3; + const size_t c2_nb1 = simd_n_in * c2_nb2; + const size_t c2_nb0 = bm / mgroup * c2_nb1; + const size_t c2_fac3 = simd_n_in * ngroups_per_elem; + const size_t c2_fac2 = kfactor * c2_fac3; + const size_t c2_fac1 = bm / mgroup * c2_fac2; + const size_t c2_fac0 = K / g / kfactor * c2_fac1; + + const size_t PackedQuantBDataSize = (N * bits) * (K / g / ngroups_per_elem); + memset(PackedQuantBDataBegin, 0, PackedQuantBDataSize); // TODO: is this needed? + + // NOTE: The second packing loop is intentionally serialized to avoid data races. + // T-MAC packs multiple output features (N) into a single byte if ngroups_per_elem > 1. + // Parallelizing this across N would lead to concurrent bit-plane updates on the same memory location. 
+ for (size_t im = 0; im < Iterations; im++) { + for (size_t ib = 0; ib < bits; ib++) { + for (size_t ik = 0; ik < K / g; ik++) { + // w = w.reshape(M // bits // simd_n_out, simd_n_out, bits, K // g).transpose(0, 2, 1, 3) + size_t new_im = im / simd_n_out; + size_t new_isno = im % simd_n_out; + size_t new_ib = ib; + size_t new_ik = ik; + size_t new_idx = new_im * c0_fac0 + new_ib * c0_fac1 + new_isno * c0_fac2 + new_ik; + + // w = w.reshape(M // mgroup, ngroups_per_elem, simd_n_in, K // g).transpose(0, 2, 1, 3) + new_im = new_idx / c1_nb0; + size_t new_ing = (new_idx % c1_nb0) / c1_nb1; + size_t new_isni = (new_idx % c1_nb1) / c1_nb2; + new_ik = (new_idx % c1_nb2); + new_idx = new_im * c1_fac0 + new_isni * c1_fac1 + new_ing * c1_fac2 + new_ik; + + // # 0 1 2 3 4 5 + // w = w.reshape(M // bm, bm // mgroup, simd_n_in, ngroups_per_elem, K // g // kfactor, kfactor).transpose(0, 4, 1, 5, 2, 3) + new_im = new_idx / c2_nb0; + size_t new_ibm = (new_idx % c2_nb0) / c2_nb1; + new_isni = (new_idx % c2_nb1) / c2_nb2; + new_ing = (new_idx % c2_nb2) / c2_nb3; + new_ik = (new_idx % c2_nb3) / c2_nb4; + size_t new_ikf = (new_idx % c2_nb4); + new_idx = new_im * c2_fac0 + + new_ik * c2_fac1 + + new_ibm * c2_fac2 + + new_ikf * c2_fac3 + + new_isni * ngroups_per_elem + + new_ing; + new_idx = new_idx / ngroups_per_elem; + size_t buf_idx = im * bits * K / g + ib * K / g + ik; + uint8_t buf_val = buf[buf_idx]; + + // w = sum([(w[:, :, :, :, :, ng] << (ng * g)) for ng in range(ngroups_per_elem)]) + PackedQuantBDataBegin[new_idx] = static_cast( + static_cast(PackedQuantBDataBegin[new_idx]) + + (buf_val << (new_ing * g)) + ); + } + } + } +} + +// Internal helper: calculates packed scales and zero points size in floats +static size_t +LutPackScalesAndZeroPointsSize( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint +) +{ + // TODO(vraspar): support one scale case + if (HasZeroPoint) { + return N * K / BlkLen * 2; + } else { + return N * K / BlkLen; + } +} + +// Internal helper: 
packs scales and zero points +static void +LutPackScalesAndZeroPoints( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + float* PackedQuantBZPBegin, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint +) +{ + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const size_t bits = tmac_params.bits; + const size_t simd_n_out = tmac_params.simd_n_out; + const size_t bm = tmac_params.bm; + const size_t num_elem_per_byte = 8 / bits; + + // ZP array is column-major packed, with per-column alignment to byte boundary + const size_t row_blks = K / BlkLen; // number of blocks per column + const size_t zp_bytes_per_col = (row_blks + num_elem_per_byte - 1) / num_elem_per_byte; + + for (size_t im = 0; im < N; im += 1) { + for (size_t ik = 0; ik < K; ik += BlkLen) { + size_t idx = (im * K + ik) / BlkLen; // linear block index for scale (scale is NOT packed) + float scale = QuantBScale[idx]; + float zp = 0.0f; + if (HasZeroPoint) { + size_t blk_in_col = ik / BlkLen; // block index within column + size_t zp_byte_idx = im * zp_bytes_per_col + blk_in_col / num_elem_per_byte; + size_t elem_idx = blk_in_col % num_elem_per_byte; + uint8_t v = (QuantBZeroPoint[zp_byte_idx] >> (elem_idx * bits)) & ((1 << bits) - 1); + + // The LUT kernel assumes weights are centered around the midpoint (2 for 2-bit). + // Thus, need to correct for the actual ZP relative to the midpoint. 
+ + int midpoint = 1 << (bits - 1); // 2 for 2-bit + zp = static_cast(static_cast(v) - midpoint) * scale; + } + + // TODO(vraspar): fix when k < BlkLen and nb1 is 0 + size_t nb1 = K / BlkLen; + size_t nb0 = bm / bits * nb1; + + size_t new_im, new_ibm, new_ik; + if (nb1 == 0) { + new_im = 0; + new_ibm = 0; + new_ik = 0; + + } else { + new_im = idx / nb0; + new_ibm = (idx % nb0) / nb1; + new_ik = (idx % nb1); + } + + if (HasZeroPoint) { + size_t new_isimd = new_ibm % simd_n_out; + size_t new_idx_outer = new_im * bm / bits * K / BlkLen / simd_n_out + new_ik * bm / bits / simd_n_out + new_ibm / simd_n_out; + size_t new_idx_scale = new_idx_outer * (simd_n_out * 2) + new_isimd; + size_t new_idx_zero = new_idx_outer * (simd_n_out * 2) + simd_n_out + new_isimd; + + PackedQuantBZPBegin[new_idx_scale] = scale; + PackedQuantBZPBegin[new_idx_zero] = zp; + } else { + size_t new_idx = new_im * bm / bits * K / BlkLen + new_ik * bm / bits + new_ibm; + PackedQuantBZPBegin[new_idx] = scale; + } + } + } +} + +// Internal helper: calculates the offset to scales in the packed buffer +static size_t +LutGemmPackedScalesOffset( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + constexpr size_t kAlignment = 64; // Cache line alignment + size_t packed_b_size = LutGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + return ((packed_b_size + kAlignment - 1) / kAlignment) * kAlignment; +} + +size_t MLASCALL +MlasLutGemmPackedSize( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint +) +{ + // Get packed B size (aligned) + size_t aligned_b_size = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + + // Get packed scales/zp size (in floats, convert to bytes) + size_t packed_scales_count = LutPackScalesAndZeroPointsSize(N, K, BlkLen, HasZeroPoint); + size_t packed_scales_bytes = packed_scales_count * sizeof(float); + + return aligned_b_size + packed_scales_bytes; +} + +void MLASCALL 
+MlasLutGemmPack( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + bool HasZeroPoint, + const std::byte* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + std::byte* PackedBuf, + MLAS_THREADPOOL* ThreadPool +) +{ + // Pack B data if provided + if (QuantBData != nullptr) { + LutGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, HasZeroPoint, QuantBData, PackedBuf, ThreadPool); + } + + // Pack scales/zero points if scales are provided + if (QuantBScale != nullptr) { + size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + float* scales_dest = reinterpret_cast(PackedBuf + scales_offset); + LutPackScalesAndZeroPoints(N, K, BlkBitWidth, BlkLen, HasZeroPoint, scales_dest, QuantBScale, QuantBZeroPoint); + } +} + +bool MLASCALL +MlasIsLutGemmAvailable( + size_t N, + size_t K, + size_t BlkBitWidth, + size_t BlkLen +) +{ + const auto* lut_kernel = GetMlasPlatform().LutGenKernel; + if (lut_kernel == nullptr || lut_kernel->GenerateLUT == nullptr || lut_kernel->ComputeGemm == nullptr) { + return false; + } + + // currently only 2-bit is supported + if (BlkBitWidth != 2 || BlkLen == 0 || (BlkLen % 32) != 0) { + return false; + } + + if (K % 32 != 0) { + return false; + } + + size_t n_div = 0; + switch (BlkBitWidth) { + case 1: + n_div = 256; + break; + case 2: + n_div = 128; + break; + case 3: + n_div = 64; + break; + case 4: + n_div = 32; + break; + default: + return false; + } + + if (N % n_div != 0) { + return false; + } + return true; +} + +size_t +CalculateLutBufferSize(size_t n, size_t k, size_t m, const MlasTMACKernelParams& tmac_params) +{ + MLAS_UNREFERENCED_PARAMETER(n); + const size_t lut_scales_size = k / tmac_params.act_group_size; + + // The AVX2 kernel (g=4) expects 16 entries (16 bytes) per group of 4 activations. + // This effectively requires 4 bytes per activation in the K dimension. 
+ size_t lut_size_bytes = m * k * 4; + size_t scales_size_bytes = m * lut_scales_size * sizeof(float); + size_t biases_size_bytes = m * lut_scales_size * sizeof(float); + + return lut_size_bytes + scales_size_bytes + biases_size_bytes + 256; // + alignment/safety padding +} + +void MLASCALL +MlasLutGemm( + const void* A, + size_t BlkLen, + const void* PackedBuf, // Packed buffer containing weights followed by scales/zp + void* C, + size_t K, + size_t M, // batch size (number of rows in activation) + size_t N, + bool HasZeroPoint, + MLAS_THREADPOOL* threadpool +) +{ + // adapted from ggml_backend_tmac_mul_mat + const auto* Dispatch = GetMlasPlatform().LutGenKernel; + // This should be ensured by calling MlasIsLutGemmAvailable() before MlasLutGemm() + assert(Dispatch && Dispatch->GenerateLUT && "TMAC not supported in this configuration."); + + // Calculate scales offset from packed buffer + // TODO(vraspar): support other bitwidths + constexpr size_t BlkBitWidth = 2; + size_t scales_offset = LutGemmPackedScalesOffset(N, K, BlkBitWidth, BlkLen, HasZeroPoint); + const auto* QuantBData = PackedBuf; + const auto* QuantBScale = reinterpret_cast( + static_cast(PackedBuf) + scales_offset + ); + + /** TODO(vraspar): The biases_float and scales float values don't make sense + * FP 16 + * QLUT K(ne10) x M(ne11) x 4 bytes + * Scales: lut_scales_size * M * 2 bytes + * Biases: lut_scales_size * M * 2 bytes + * Needs FP 16 conversion Buffer: max(K, N) * M * 2 bytes + * + * FP 32 + * QLUT K x M x 4 bytes + * Scales: lut_scales_size * M * 4 bytes + * Biases: lut_scales_size * M * 4 bytes + * + * Currently, we only support FP32, add FP16 support later which requires conversion buffer + * + * LUT Buffer for FP32 : K * M * 4 * sizeof(uint8_t) bytes + lut_scale_size * m * 2 * sizeof(float) bytes + allignment + * + */ + + // n_tiles_num = m * bits / bm; + + // TODO(vraspar): support other bitwidths + // For T-MAC, kernel properties (bm, n_tiles_num) are primarily driven by the number of 
output features (N). + // Initialization during packing (LutGemmPackQuantBDataSize) uses N as the major dimension, + // so we must match that here to ensure consistent weight tiling. + MlasInitLutGemmKernelConfig(N, K, 2, BlkLen, HasZeroPoint); + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(N, K, 2, BlkLen, HasZeroPoint); + const size_t lut_scales_size = K / tmac_params.act_group_size; + const size_t lut_size_bytes = static_cast(M) * static_cast(K) * 4; + size_t lut_buffer_size = CalculateLutBufferSize(N, K, M, tmac_params); + + // make buffer of lut_buffer_size bytes + // TODO(vraspar): other way to do it + auto lut_buffer = std::make_unique(lut_buffer_size); + memset(lut_buffer.get(), 0, lut_buffer_size); + + int8_t* qlut = reinterpret_cast(lut_buffer.get()); + float* lut_scales = reinterpret_cast(qlut + lut_size_bytes); // after lut + float* lut_biases = reinterpret_cast(lut_scales + lut_scales_size * M); // after scales + + const auto* a_float = reinterpret_cast(A); // Activation data + + // const int num_groups = static_cast(K / BlkLen); + + // Iterate over M (batch dimension) + // Each iteration processes one row of the activation matrix. + // NOTE: This loop is intentionally serialized. Previous attempts to parallelize + // using MlasTrySimpleParallel caused flaky test failures (race conditions) + // when M > 1 (e.g., Batch32 case). Since GenerateLUT is lightweight, + // serial execution ensures correctness with negligible performance impact. + // TODO(vraspar): Ideally we have to do block parallelism here + + for (size_t ine11 = 0; ine11 < static_cast(M); ine11++) { + const size_t row_offset = ine11 * K; + // Call the LUT generation kernel for this activation row. + // We use a 4-byte stride (per activation) for the LUT entries to satisfy + // the memory layout requirements of the computation kernel. 
+ const size_t lut_offset = ine11 * K * 4; + const size_t scale_bias_offset = ine11 * lut_scales_size; + + Dispatch->GenerateLUT( + const_cast(a_float + row_offset), // Input activation for this row + qlut + lut_offset, // Output LUT for this row + lut_scales + scale_bias_offset, // Scales for this row + lut_biases + scale_bias_offset, // Biases for this row + M, + K, + N, + tmac_params.act_group_size, + tmac_params.act_group_size * 4 + ); + } + + // all relevant LUT's have been generated + // equivalent of lut_mul_mat's ggml_backend_tmac_mul_mat function ggml_barrier line + + const size_t n_tiles_num = tmac_params.n_tiles_num; + assert(N % n_tiles_num == 0); + + const size_t bits = tmac_params.bits; + + // Pre-calculate sizes for offset calculations + const size_t w_size = N * K * bits / 8; + const size_t w_chunk_size = w_size / n_tiles_num; + + // TODO: fix the below 4 + // Matrix multiplication: Output[N×M] = QuantBData[N×K] × Weights[K×M] + const size_t OutputRows = N; // Number of output features + const size_t OutputCols = M; // Batch size + + const size_t ChunkSize0 = N / n_tiles_num; + const size_t ChunkSize1 = tmac_params.chunk_n; // process one batch item at a time + + // In llama.cpp terminology (note the swap!): + // ne0 = M (output features, called "n" in llama.cpp) + // ne1 = N (batch size, called "m" in llama.cpp) + + // Calculate number of chunks in each dimension + const size_t nchunk0 = (OutputRows + ChunkSize0 - 1) / ChunkSize0; // Should equal NumTiles + const size_t nchunk1 = (OutputCols + ChunkSize1 - 1) / ChunkSize1; + const size_t total_chunks = nchunk0 * nchunk1; + + // TODO(vraspar): support one_scale case + // Determine weight-scale layout. These should be provided by the caller or inferred from the packed weights. + // For now we default to per-group symmetric quantization (no zero-point, not one-scale). 
+ + const size_t scales_size_total = LutPackScalesAndZeroPointsSize( + static_cast(N), + static_cast(K), + BlkLen, + tmac_params.has_zero_point + ); + + // Per-tile scales size = total scales size divided evenly across tiles. + // If one_scale is true we do not advance the scales pointer per tile, so set per tile size to 0 + size_t scales_size_per_tile = 0; + + if (scales_size_total % n_tiles_num != 0) { + // Sanity: scales should partition evenly across tiles. If they don't, choose floor division + // and document that callers must layout scales accordingly. + // Prefer to error loudly in debug builds. + fprintf(stderr, "Warning: scales_size_total=%zu is not divisible by n_tiles_num=%zu; using floor division.\n", scales_size_total, n_tiles_num); + } + scales_size_per_tile = scales_size_total / n_tiles_num; + + // Note: when one_scale == true, callers should pass a pointer to a single scale value (scales_offset=0 will be used) + + // Cast to appropriate types + const auto* packed_weights = reinterpret_cast(QuantBData); + float* act_output = reinterpret_cast(C); + + // Parallelize over the 2D chunk grid + MlasTrySimpleParallel( + threadpool, + total_chunks, + [&](ptrdiff_t current_chunk) { + // Decompose linear chunk index into 2D coordinates + const size_t ith0 = current_chunk % nchunk0; // Chunk in dimension 0 (output rows) + const size_t ith1 = current_chunk / nchunk0; // Chunk in dimension 1 (batch) + + // Calculate ranges for this chunk + const size_t ir0_start = ChunkSize0 * ith0; + const size_t ir0_end = std::min(ir0_start + ChunkSize0, OutputRows); + + const size_t ir1_start = ChunkSize1 * ith1; + const size_t ir1_end = std::min(ir1_start + ChunkSize1, OutputCols); + + // Process all tiles in dimension 0 for this chunk + for (size_t ichunk0 = ir0_start / ChunkSize0; ichunk0 < ir0_end / ChunkSize0; ichunk0++) { + // Calculate weight offsets + const size_t w_offset = ichunk0 * w_chunk_size; + const size_t scales_offset = ichunk0 * scales_size_per_tile; + + // 
Process all batch items in this chunk + for (size_t ine11 = ir1_start; ine11 < ir1_end; ine11++) { + // Calculate LUT offsets with 4-byte stride (per activation) for consistent access. + const size_t qlut_offset = K * ine11 * 4; + const size_t lut_scales_offset = lut_scales_size * ine11; + + // Calculate output offset + const size_t dst_offset = OutputRows * ine11 + ichunk0 * ChunkSize0; + + // Call the dispatch function to compute this tile. + // We pass one batch item at a time (M=1) and ChunkSize0 output features. + // TotalN is passed specifically to allow the kernel to find the correct + // parameters (bm, tiles) used during weight packing. + Dispatch->ComputeGemm( + packed_weights + w_offset, // Weight tile + QuantBScale + scales_offset, // Weight scales for this tile + qlut + qlut_offset, // LUT for this batch row + lut_scales + lut_scales_offset, // LUT scales + lut_biases + lut_scales_offset, // LUT biases + act_output + dst_offset, // Output location + static_cast(K), // K dimension + static_cast(1), // M dimension (batch size = 1) + static_cast(ir0_end - ir0_start), // N dimension (output features in chunk) + static_cast(N), // TotalN (total output features in weights) + BlkLen, // Weight quantization group size + HasZeroPoint // Whether zero points are used + ); + } + } + } + ); +} diff --git a/onnxruntime/core/mlas/lib/qlutgemm.h b/onnxruntime/core/mlas/lib/qlutgemm.h new file mode 100644 index 0000000000000..0a733199ea2e8 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qlutgemm.h @@ -0,0 +1,90 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qlutgemm.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing LUT-based GEMM. 
+--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +/** + * @brief Parameters for TMAC kernel + */ +struct MlasTMACKernelParams { + size_t g; + size_t ngroups_per_elem; + size_t q_group_size; + size_t act_group_size; + + size_t kfactor; + size_t bits; + size_t actk; + size_t bm; + size_t simd_n_in; + size_t simd_n_out; + size_t chunk_n; + size_t n_tiles_num; + + bool has_scale; + bool has_zero_point; + bool one_scale; +}; + +/** + * Retrieves the T-MAC kernel configuration for a given GEMM problem. + * Returns the parameters by value to ensure thread-safety across concurrent calls. + */ +MlasTMACKernelParams +MlasGetLutGemmKernelParams(size_t M, size_t N, size_t nbits, size_t block_size, bool has_zero_point); + +typedef void(MLAS_QNBIT_GEMM_LUT_GEN)( + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size, + size_t lut_stride // Stride (in bytes) between consecutive LUT entries along the batch dimension. +); + +typedef void(MLAS_QNBIT_LUT_GEMM_COMPUTE)( + const uint8_t* A, + const float* Scales, + const int8_t* LUT, + const float* LUT_Scales, + const float* LUT_Biases, + float* C, + int K, + int M, // Batch size (current activation rows). + int N, // Number of output features to compute in this tile/chunk. + int TotalN, // Total number of output features in the weights (used for parameter mapping). + size_t BlkLen, + bool HasZeroPoint +); + +// +// Kernel dispatch structure. +// +// NOTE: This name must match the forward declaration in mlasi.h: +// struct MLAS_QNBIT_LUT_GEMM_DISPATCH; +// Keep it minimal for now; extend with function pointers as kernels are added. +struct MLAS_QNBIT_LUT_GEMM_DISPATCH { + // Intentionally empty placeholder; add members as needed. 
+ MLAS_QNBIT_GEMM_LUT_GEN* GenerateLUT = nullptr; + + MLAS_QNBIT_LUT_GEMM_COMPUTE* ComputeGemm = nullptr; +}; diff --git a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp index 3a93723fc3b52..e611009733fbf 100644 --- a/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp +++ b/onnxruntime/core/mlas/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -150,8 +150,8 @@ RopeKernel_Fp16_Impl( if (i + 15 < dim) { float16x8_t x0 = MlasLoadFloat16x8(input + i); float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); - float16x8_t sin_val = MlasLoadFloat16x8(sin + i); - float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i / 2); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i / 2); for (; i + 31 < dim; i += 16) { float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -163,8 +163,8 @@ RopeKernel_Fp16_Impl( MlasStoreFloat16x8(output + i + 8, y1); x0 = MlasLoadFloat16x8(input + i + 16); x1 = MlasLoadFloat16x8(input + i + 24); - sin_val = MlasLoadFloat16x8(sin + i + 16); - cos_val = MlasLoadFloat16x8(cos + i + 16); + sin_val = MlasLoadFloat16x8(sin + (i + 16) / 2); + cos_val = MlasLoadFloat16x8(cos + (i + 16) / 2); } float16x8_t real = vuzp1q_f16(x0, x1); float16x8_t imag = vuzp2q_f16(x0, x1); @@ -181,8 +181,8 @@ RopeKernel_Fp16_Impl( float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); float16x4_t real = vuzp1_f16(x0, x1); float16x4_t imag = vuzp2_f16(x0, x1); - float16x4_t sin_val = MlasLoadFloat16x4(sin + i); - float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i / 2); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i / 2); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); float16x4_t y0 = vzip1_f16(real_out, imag_out); @@ -201,12 +201,12 @@ RopeKernel_Fp16_Impl( imag = 
MlasLoadLaneFloat16x4<1>(input + i + 3, imag); real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); - cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i / 2 + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i / 2 + 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -224,10 +224,10 @@ RopeKernel_Fp16_Impl( imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); - cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i / 2 + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i / 2 + 1, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); @@ -241,8 +241,8 @@ 
RopeKernel_Fp16_Impl( float16x4_t cos_val = MlasZeroFloat16x4(); real = MlasLoadLaneFloat16x4<0>(input + i, real); imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); - sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); - cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i / 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i / 2, cos_val); float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); MlasStoreLaneFloat16x4<0>(output + i, real_out); diff --git a/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp new file mode 100644 index 0000000000000..f41b380b2a071 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbconv_kernel_neon.cpp @@ -0,0 +1,110 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sbconv_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision convolution kernels for ARM NEON. + +--*/ + +#if defined(MLAS_USE_ARM_NEON_NCHWC) && defined(__linux__) + +#include "mlasi.h" +#include "sconv_nchwc_kernel_neon.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// +// BF16 Pointwise (1x1) Convolution Kernel using SBGEMM. 
+// +void MLASCALL +MlasConvPointwiseBf16KernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags +) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + // SBGEMM only adds bias when ZeroMode=true. When accumulating (ZeroMode=false), + // pre-add bias to existing output before the GEMM operations. 
+ if (BiasAddition && AccumulateOutput) { + for (size_t f = 0; f < FilterCount; f++) { + float* output = Output + f * OutputStrideElements; + const float32x4_t b0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t b1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t b2 = MlasLoadFloat32x4(&Bias[f * BlockSize + 8]); + const float32x4_t b3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + for (size_t i = 0; i < OutputCount; i++) { + MlasStoreFloat32x4(&output[i * BlockSize], MlasAddFloat32x4(b0, MlasLoadFloat32x4(&output[i * BlockSize]))); + MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasAddFloat32x4(b1, MlasLoadFloat32x4(&output[i * BlockSize + 4]))); + MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasAddFloat32x4(b2, MlasLoadFloat32x4(&output[i * BlockSize + 8]))); + MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasAddFloat32x4(b3, MlasLoadFloat32x4(&output[i * BlockSize + 12]))); + } + } + } + + // Build SBGEMM params for all (filter, input_channel) combinations. + // FilterCount <= 4, InputChannels <= 8, so max 32 elements. + // Bias is set on all elements but SBGEMM only uses it when ZeroMode=true. + MLAS_SBGEMM_DATA_PARAMS gemm_params[32]; + + size_t idx = 0; + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + for (size_t ic = 0; ic < InputChannels; ic++, idx++) { + gemm_params[idx].A = Input + ic * InputStrideElements; + gemm_params[idx].B = filter + ic * BlockSize * BlockSize; + gemm_params[idx].C = output; + gemm_params[idx].lda = StrideWidthElements; + gemm_params[idx].ldb = BlockSize; + gemm_params[idx].ldc = BlockSize; + gemm_params[idx].Bias = BiasAddition ? 
(Bias + f * BlockSize) : nullptr; + gemm_params[idx].AIsfp32 = true; + gemm_params[idx].BIsfp32 = true; + gemm_params[idx].ZeroMode = (ic == 0) && !AccumulateOutput; + gemm_params[idx].OutputProcessor = nullptr; + } + } + + MlasSBGemmBatch(OutputCount, BlockSize, BlockSize, idx, gemm_params, nullptr); + + if (ReluActivation) { + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + for (size_t f = 0; f < FilterCount; f++) { + float* output = Output + f * OutputStrideElements; + for (size_t i = 0; i < OutputCount; i++) { + MlasStoreFloat32x4(&output[i * BlockSize], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 4], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 4]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 8], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 8]), ZeroVector)); + MlasStoreFloat32x4(&output[i * BlockSize + 12], MlasMaximumFloat32x4(MlasLoadFloat32x4(&output[i * BlockSize + 12]), ZeroVector)); + } + } + } +} + +#endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h index de7fd72fad45a..5415cb3dc4406 100644 --- a/onnxruntime/core/mlas/lib/sbgemm.h +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -112,7 +112,7 @@ MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, template MLAS_FORCEINLINE void -MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode) { constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; size_t PackedStrideN = Strides.N; @@ -131,7 +131,7 @@ 
MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size // size_t CountK; for (size_t k = 0; k < K; k += CountK) { - bool ZeroMode = (k == 0); + bool ZeroMode = (k == 0) && InitialZeroMode; CountK = std::min(K - k, PackedStrideK); const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; @@ -148,7 +148,7 @@ MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size template void -MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor, bool InitialZeroMode) { // // Compute the strides to step through slices of the input matrices. @@ -201,7 +201,7 @@ MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_ const float* pbias = ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart - bool ZeroMode = (k == 0); + bool ZeroMode = (k == 0) && InitialZeroMode; MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? 
pbias : nullptr, ZeroMode); } if (PostProcessor != nullptr) { @@ -249,16 +249,17 @@ MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const float* A = (const float*)DataParams->A + RangeStartM * lda; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* bias = DataParams->Bias; + const bool zeroMode = DataParams->ZeroMode; if (!DataParams->BIsfp32) { MlasSBGemmPackedOperation( RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, - lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode ); } else { const size_t ldb = DataParams->ldb; const float* B = (const float*)DataParams->B + RangeStartN; - MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor, zeroMode); } } diff --git a/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp new file mode 100644 index 0000000000000..14b6b30c85bda --- /dev/null +++ b/onnxruntime/core/mlas/lib/sconv_nchw_kernel_neon.cpp @@ -0,0 +1,297 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv_nchw_kernel_neon.cpp + +Abstract: + + This module implements the single precision NCHW convolution kernels for ARM NEON. 
+ +--*/ + + +#include "mlasi.h" +#include + +MLAS_FORCEINLINE float DepthwiseSampleValue( + const float* row, + ptrdiff_t col, + size_t width +) +{ + if (row == nullptr || col < 0 || col >= static_cast(width)) { + return 0.0f; + } + return row[col]; +} + +MLAS_FORCEINLINE float DepthwiseAccumulateRowScalar( + float acc, + const float* row, + size_t base, + float w0, + float w1, + float w2 +) +{ + if (row == nullptr) { + return acc; + } + + acc += row[base] * w0; + acc += row[base + 1] * w1; + acc += row[base + 2] * w2; + return acc; +} + +MLAS_FORCEINLINE void DepthwiseAccumulateRowVector( + float32x4_t& acc, + const float* row, + size_t base, + float w0, + float w1, + float w2 +) +{ + if (row == nullptr) { + return; + } + + const float* r = row + base; + const float32x4_t c0 = MlasLoadFloat32x4(r); + const float32x4_t c1 = MlasLoadFloat32x4(r + 1); + const float32x4_t c2 = MlasLoadFloat32x4(r + 2); + + acc = MlasMultiplyAddFloat32x4(c0, w0, acc); + acc = MlasMultiplyAddFloat32x4(c1, w1, acc); + acc = MlasMultiplyAddFloat32x4(c2, w2, acc); +} + +MLAS_FORCEINLINE float DepthwiseComputeEdge( + const float* row0, + const float* row1, + const float* row2, + ptrdiff_t iw, + size_t width, + const float w00, + const float w01, + const float w02, + const float w10, + const float w11, + const float w12, + const float w20, + const float w21, + const float w22 +) +{ + float acc = 0.0f; + const ptrdiff_t c0 = iw; + const ptrdiff_t c1 = iw + 1; + const ptrdiff_t c2 = iw + 2; + + acc += DepthwiseSampleValue(row0, c0, width) * w00; + acc += DepthwiseSampleValue(row0, c1, width) * w01; + acc += DepthwiseSampleValue(row0, c2, width) * w02; + acc += DepthwiseSampleValue(row1, c0, width) * w10; + acc += DepthwiseSampleValue(row1, c1, width) * w11; + acc += DepthwiseSampleValue(row1, c2, width) * w12; + acc += DepthwiseSampleValue(row2, c0, width) * w20; + acc += DepthwiseSampleValue(row2, c1, width) * w21; + acc += DepthwiseSampleValue(row2, c2, width) * w22; + + return acc; +} + 
+static void DepthwiseConv3x3Stride1PadLe1Neon( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + float* Output +) +{ + const size_t H = Parameters->InputShape[0]; + const size_t W = Parameters->InputShape[1]; + const size_t out_rows = Parameters->OutputShape[0]; + const size_t out_cols = Parameters->OutputShape[1]; + + const size_t pad_top = Parameters->Padding[0]; + const size_t pad_left = Parameters->Padding[1]; + const size_t pad_right = Parameters->Padding[3]; + + const float beta = Parameters->Beta; + const bool accumulate_output = beta != 0.0f; + + const float w00 = Filter[0]; + const float w01 = Filter[1]; + const float w02 = Filter[2]; + const float w10 = Filter[3]; + const float w11 = Filter[4]; + const float w12 = Filter[5]; + const float w20 = Filter[6]; + const float w21 = Filter[7]; + const float w22 = Filter[8]; + + for (size_t oh = 0; oh < out_rows; ++oh) { + const ptrdiff_t ih = static_cast(oh) - static_cast(pad_top); + + const ptrdiff_t row0_index = ih; + const ptrdiff_t row1_index = ih + 1; + const ptrdiff_t row2_index = ih + 2; + + const float* row0 = nullptr; + const float* row1 = nullptr; + const float* row2 = nullptr; + + if (row0_index >= 0 && row0_index < static_cast(H)) { + row0 = Input + static_cast(row0_index) * W; + } + if (row1_index >= 0 && row1_index < static_cast(H)) { + row1 = Input + static_cast(row1_index) * W; + } + if (row2_index >= 0 && row2_index < static_cast(H)) { + row2 = Input + static_cast(row2_index) * W; + } + + float* out_row = Output + oh * out_cols; + size_t ow = 0; + + if (pad_left && ow < out_cols) { + const ptrdiff_t iw = static_cast(ow) - static_cast(pad_left); + float acc = DepthwiseComputeEdge( + row0, row1, row2, iw, W, + w00, w01, w02, w10, w11, w12, w20, w21, w22 + ); + if (accumulate_output) { + acc += beta * out_row[ow]; + } + out_row[ow++] = acc; + } + + size_t interior_cols = 0; + if (out_cols > pad_left + pad_right) { + interior_cols = out_cols - pad_left - 
pad_right; + } + + size_t processed = 0; + while (processed + 4 <= interior_cols) { + const ptrdiff_t iw = static_cast(ow) - static_cast(pad_left); + if ((iw + 5) >= static_cast(W)) { + break; + } + + const size_t base = static_cast(iw); + float32x4_t acc = MlasZeroFloat32x4(); + + DepthwiseAccumulateRowVector(acc, row0, base, w00, w01, w02); + DepthwiseAccumulateRowVector(acc, row1, base, w10, w11, w12); + DepthwiseAccumulateRowVector(acc, row2, base, w20, w21, w22); + + if (accumulate_output) { + const float32x4_t prev = MlasLoadFloat32x4(out_row + ow); + acc = MlasMultiplyAddFloat32x4(prev, beta, acc); + } + + MlasStoreFloat32x4(out_row + ow, acc); + ow += 4; + processed += 4; + } + + for (; processed < interior_cols; ++processed) { + const ptrdiff_t iw = static_cast(ow) - static_cast(pad_left); + const size_t base = static_cast(iw); + + float acc = 0.0f; + acc = DepthwiseAccumulateRowScalar(acc, row0, base, w00, w01, w02); + acc = DepthwiseAccumulateRowScalar(acc, row1, base, w10, w11, w12); + acc = DepthwiseAccumulateRowScalar(acc, row2, base, w20, w21, w22); + + if (accumulate_output) { + acc += beta * out_row[ow]; + } + out_row[ow++] = acc; + } + + if (pad_right && ow < out_cols) { + const ptrdiff_t iw = static_cast(ow) - static_cast(pad_left); + float acc = DepthwiseComputeEdge( + row0, row1, row2, iw, W, + w00, w01, w02, w10, w11, w12, w20, w21, w22 + ); + if (accumulate_output) { + acc += beta * out_row[ow]; + } + out_row[ow++] = acc; + } + } +} + +static +void +MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + float* Output + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute convolution on one channel input with one filter channel. + +Arguments: + + Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc. + + Input - input channel data start. 
Input is NCHW, so this pointer points to single H x W image data. + + Filter - Whole filters are of F x CpG x FH x FW, this filter points to single FH x FW filter data. + + Output - whole output are of N x F x OH x OW. This pointer points to single OH x OW output image data. + +--*/ +{ + DepthwiseConv3x3Stride1PadLe1Neon(Parameters, Input, Filter, Output); +} + +void MlasConvDepthwiseFloat_CHW( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + float* Output, + const float* Zeros + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute depthwise convolution for one filter channel on one input channel. + +Arguments: + + Parameters - conv parameters calculated based on conv parameters like padding, strides, dilations, etc. + + Input - input channel data start. Input is NCHW, so this pointer point to single H x W image data. + + Filter - Whole filters are of F x CpG x FH x FW, this filter points to single FH x FW filter data. + + Output - whole output are of N x F x OH x OW. This pointer point to single OH x OW output image data. + + Zeros - Point to working buffer where all 0.0f are filled. + +Note: + No checking here as it is inner loop. Logic in generating Parameters controls the check. + + Currently only support 2d kernel 3x3 with strides=1, dilations=1, pads<=1. + Will add general case and more special case if needed later. 
+ +--*/ +{ + MLAS_UNREFERENCED_PARAMETER(Zeros); + MlasConv2dSingleChannel_CHW_Kernel3x3_Pad01_Dilation1(Parameters, Input, Filter, Output); +} diff --git a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp similarity index 99% rename from onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp rename to onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp index 0b6538eb06379..745258080810a 100644 --- a/onnxruntime/core/mlas/lib/sconv_kernel_neon.cpp +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.cpp @@ -6,18 +6,18 @@ Licensed under the MIT License. Module Name: - sconv_kernel_neon.cpp + sconv_nchwc_kernel_neon.cpp Abstract: - This module implements the single precision convolution kernels for ARM NEON. + This module implements the single precision NCHWC convolution kernels for ARM NEON. --*/ #if defined(MLAS_USE_ARM_NEON_NCHWC) #include "mlasi.h" -#include "sconv.h" +#include "sconv_nchwc_kernel_neon.h" constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; diff --git a/onnxruntime/core/mlas/lib/sconv.h b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h similarity index 96% rename from onnxruntime/core/mlas/lib/sconv.h rename to onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h index 99b2ad3130adf..10bee4b19766b 100644 --- a/onnxruntime/core/mlas/lib/sconv.h +++ b/onnxruntime/core/mlas/lib/sconv_nchwc_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. 
Module Name: - sconv.h + sconv_nchwc_kernel_neon.h Abstract: diff --git a/onnxruntime/core/mlas/lib/sgemm.cpp b/onnxruntime/core/mlas/lib/sgemm.cpp index 02e38b6ef432e..7117f20b82ce5 100644 --- a/onnxruntime/core/mlas/lib/sgemm.cpp +++ b/onnxruntime/core/mlas/lib/sgemm.cpp @@ -1573,10 +1573,10 @@ MlasGemmBatch( ) { // Override - if(GetMlasPlatform().MlasGemmBatchOverride != nullptr && + if(GetMlasPlatform().MlasSGemmBatchOverride != nullptr && // TODO: Remove once KAI supports transposing for A TransA != CBLAS_TRANSPOSE::CblasTrans && - GetMlasPlatform().MlasGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ + GetMlasPlatform().MlasSGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ return; } // @@ -1671,12 +1671,12 @@ Return Value: // // KleidiAI or other override #if defined(USE_KLEIDIAI) - if (GetMlasPlatform().MlasGemmPackBSizeOverride != nullptr && + if (GetMlasPlatform().MlasSGemmPackBSizeOverride != nullptr && // TODO: Remove once KAI supports transposing for A TransA != CBLAS_TRANSPOSE::CblasTrans) { size_t bytes_required; //TODO pass status by reference to indicate success/fail - bytes_required = GetMlasPlatform().MlasGemmPackBSizeOverride(TransA, TransB, N, K); + bytes_required = GetMlasPlatform().MlasSGemmPackBSizeOverride(TransA, TransB, N, K); if (bytes_required != 0){// If ArmKleidiAI::MlasGemmPackBSize ran to completion return bytes_required; } @@ -1738,10 +1738,10 @@ Return Value: --*/ { #if defined(USE_KLEIDIAI) - if (GetMlasPlatform().MlasGemmPackBOverride != nullptr && + if (GetMlasPlatform().MlasSGemmPackBOverride != nullptr && // TODO: Remove once KAI supports transposing for A TransA != CBLAS_TRANSPOSE::CblasTrans && - GetMlasPlatform().MlasGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + GetMlasPlatform().MlasSGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ return; } #endif diff --git a/onnxruntime/core/mlas/lib/snchwc.cpp b/onnxruntime/core/mlas/lib/snchwc.cpp index 
6f3423a792509..505246841087c 100644 --- a/onnxruntime/core/mlas/lib/snchwc.cpp +++ b/onnxruntime/core/mlas/lib/snchwc.cpp @@ -53,6 +53,7 @@ struct MLAS_NCHWC_CONV_WORK_BLOCK : MLAS_NCHWC_WORK_BLOCK float* Output; size_t GroupCount; bool ZeroMode; + bool UseBf16; }; // @@ -881,6 +882,11 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM #if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; +#if defined(__aarch64__) && defined(__linux__) + if (WorkBlock->UseBf16) { + Kernel = GetMlasPlatform().ConvPointwiseBf16Kernel; + } +#endif #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; #endif @@ -1224,7 +1230,8 @@ MlasNchwcConv( float* Output, const MLAS_ACTIVATION* Activation, bool ZeroMode, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + const bool UseBf16 ) /*++ @@ -1269,6 +1276,8 @@ Routine Description: ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. + UseBf16 - Supplies true to use BF16 for convolutions on supported platforms. + Return Value: None. @@ -1288,6 +1297,7 @@ Return Value: WorkBlock.Bias = Bias; WorkBlock.Activation = Activation; WorkBlock.ZeroMode = ZeroMode; + WorkBlock.UseBf16 = UseBf16; // // Capture the generic shape parameters to the work block. 
diff --git a/onnxruntime/core/mlas/lib/spool_kernel_neon.cpp b/onnxruntime/core/mlas/lib/spool_nchwc_kernel_neon.cpp similarity index 100% rename from onnxruntime/core/mlas/lib/spool_kernel_neon.cpp rename to onnxruntime/core/mlas/lib/spool_nchwc_kernel_neon.cpp diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp new file mode 100644 index 0000000000000..7e4df13423be2 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.cpp @@ -0,0 +1,737 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_lut_kernel_avx2.cpp + +Abstract: + + This module implements x64 AVX2 kernel functions for LUT-based quantized + n-bit integer matrix multiplication. + + It provides optimized AVX2 implementations for lookup table generation, + GEMM computation, and related operations on quantized weight and activation + matrices. + +--*/ + +#include +#include +#include +// AVX2 intrinsics +#include + +#include "qlutgemm.h" +#include "qnbitgemm.h" +#include "sqnbitgemm_q8_block.h" + +static inline float +_mm256_addv_ps(const __m256 v) +{ + __m128 res = _mm256_extractf128_ps(v, 1); + res = _mm_add_ps(res, _mm256_castps256_ps128(v)); + res = _mm_add_ps(res, _mm_movehl_ps(res, res)); + res = _mm_add_ss(res, _mm_movehdup_ps(res)); + return _mm_cvtss_f32(res); +} + +// Conditional pragma unroll for compiler compatibility +#if defined(__INTEL_COMPILER) || defined(__clang__) +#define PRAGMA_UNROLL _Pragma("unroll") +#else +#define PRAGMA_UNROLL +#endif + +// Helper macros for extracting and widening vectors +#define extract_low_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v)) +#define extract_high_epi8_epi16(v) _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v, 1)) +#define extract_low_epi16_epi32(v) _mm256_cvtepi16_epi32(_mm256_castsi256_si128(v)) +#define extract_high_epi16_epi32(v) 
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(v, 1)) + +// Template classes for accumulation +template +struct SignedHalvingAdder { + SignedHalvingAdder adder; + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k < N / 2) { + adder.push(v, k); + if (k == N / 2 - 1) { + lhs = adder.get(); + } + } else { + adder.push(v, k - N / 2); + if (k == N - 1) { + lhs = _mm256_avg_epu8(lhs, adder.get()); + } + } + } + + inline __m256i get() + { + return lhs; + } + + inline __m256i get_low() + { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() + { + return extract_high_epi8_epi16(lhs); + } +}; + +template <> +struct SignedHalvingAdder<2> { + __m256i lhs = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k == 0) { + lhs = v; + } else { + lhs = _mm256_avg_epu8(lhs, v); + } + } + + inline __m256i get() + { + return lhs; + } + + inline __m256i get_low() + { + return extract_low_epi8_epi16(lhs); + } + + inline __m256i get_high() + { + return extract_high_epi8_epi16(lhs); + } +}; + +template +struct SignedWideningAdder { + __m256i lhs_low = _mm256_setzero_si256(); + __m256i lhs_high = _mm256_setzero_si256(); + + inline void push(__m256i v, int k) + { + if (k == 0) { + lhs_low = extract_low_epi8_epi16(v); + lhs_high = extract_high_epi8_epi16(v); + } else { + lhs_low = _mm256_add_epi16(lhs_low, extract_low_epi8_epi16(v)); + lhs_high = _mm256_add_epi16(lhs_high, extract_high_epi8_epi16(v)); + } + } + + inline __m256i get_low() + { + return lhs_low; + } + + inline __m256i get_high() + { + return lhs_high; + } +}; + +template +using SignedAdder = typename std::conditional, SignedWideningAdder>::type; + +// Template for computing log2 at compile time +template +struct mylog2 { + enum { + value = 1 + mylog2::value + }; +}; + +template <> +struct mylog2<0> { + enum { + value = -1 + }; +}; + +// Template for computing bias scale at compile time +template +constexpr int +get_bias_scale() +{ + // The bias scale 
will be added to the first bit + // 15 = (1/2 + 1 + 2 + 4) / (1/2) + // 7 = (1/2 + 1 + 2) / (1/2) + // 3 = (1/2 + 1) / (1/2) + // 1 = (1/2) / (1/2) + // if constexpr (bits == 4) { + // return 15; + // } else if constexpr (bits == 3) { + // return 7; + // } else if constexpr (bits == 2) { + // return 3; + // } else if constexpr (bits == 1) { + // return 1; + // } else { + // return 0; + // } + return 3; +} + +static inline void +MlasAvx2LoaduDeinterleave32Ps(const float* src, __m256& v0, __m256& v1, __m256& v2, __m256& v3) +{ + // Process 32 activations contiguously using loadu + shuffle. + // This allows us to mix neighbors (src[4i], src[4i+1], src[4i+2], src[4i+3]) across lanes, + // which matches the T-MAC weight packing. + // We use loadu + shuffle instead of gather to avoid potential issues with gather + // on some hardware and ensure deterministic behavior. + __m256 vec_b0 = _mm256_loadu_ps(src + 0); + __m256 vec_b1 = _mm256_loadu_ps(src + 8); + __m256 vec_b2 = _mm256_loadu_ps(src + 16); + __m256 vec_b3 = _mm256_loadu_ps(src + 24); + + __m256 t0 = _mm256_unpacklo_ps(vec_b0, vec_b1); + __m256 t1 = _mm256_unpackhi_ps(vec_b0, vec_b1); + __m256 t2 = _mm256_unpacklo_ps(vec_b2, vec_b3); + __m256 t3 = _mm256_unpackhi_ps(vec_b2, vec_b3); + + __m256 u0 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2))); + __m256 u1 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t0), _mm256_castps_pd(t2))); + __m256 u2 = _mm256_castpd_ps(_mm256_unpacklo_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3))); + __m256 u3 = _mm256_castpd_ps(_mm256_unpackhi_pd(_mm256_castps_pd(t1), _mm256_castps_pd(t3))); + + const __m256i perm_idx = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7); + v0 = _mm256_permutevar8x32_ps(u0, perm_idx); + v1 = _mm256_permutevar8x32_ps(u1, perm_idx); + v2 = _mm256_permutevar8x32_ps(u2, perm_idx); + v3 = _mm256_permutevar8x32_ps(u3, perm_idx); +} + +void +partial_max_g4_int8_k8(float* lut_scales, const float* b) +{ + __m256 vec_b0, vec_b1, 
vec_b2, vec_b3; + MlasAvx2LoaduDeinterleave32Ps(b, vec_b0, vec_b1, vec_b2, vec_b3); + + const __m256 vec_sign = _mm256_set1_ps(-0.0f); + __m256 vec_babs0 = _mm256_andnot_ps(vec_sign, vec_b0); + __m256 vec_babs1 = _mm256_andnot_ps(vec_sign, vec_b1); + __m256 vec_babs2 = _mm256_andnot_ps(vec_sign, vec_b2); + __m256 vec_babs3 = _mm256_andnot_ps(vec_sign, vec_b3); + + // The upper bound for the LUT values (mixtures of 4 activations) is the sum + // of their absolute values. + __m256 abssum = _mm256_add_ps(_mm256_add_ps(vec_babs0, vec_babs1), _mm256_add_ps(vec_babs2, vec_babs3)); + + // Reduce max across lanes to find the global maximum sum in this chunk. + __m128 max4 = _mm_max_ps(_mm256_extractf128_ps(abssum, 1), _mm256_castps256_ps128(abssum)); + max4 = _mm_max_ps(max4, _mm_movehl_ps(max4, max4)); + max4 = _mm_max_ss(max4, _mm_movehdup_ps(max4)); + float scales = _mm_cvtss_f32(max4) / 127; + *lut_scales = std::max(*lut_scales, scales); +} + +// Current implementation requires (K * 4) == act_group_size and K >= 8 +// s0 = -1, s1 = 1 +// TODO: loop K +inline void +lut_ctor_g4_int8_impl( + int32_t act_k, + int8_t* qlut, + const float* b, + float* lut_scales, + float* lut_biases +) +{ + __m256 vec_lut[16]; + float biases = 0.0f; + float scales = *lut_scales; + float t_scales = scales ? 
1.0f / scales : 0.0f; + + for (int k = 0; k < act_k / 32; ++k) { + const float* b_chunk = b + k * 32; + __m256 vec_b0, vec_b1, vec_b2, vec_b3; + MlasAvx2LoaduDeinterleave32Ps(b_chunk, vec_b0, vec_b1, vec_b2, vec_b3); + + PRAGMA_UNROLL + for (int g = 1; g < 16; g += 2) { + vec_lut[g] = vec_b0; + if (g & 0b0010) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b1); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b1); + } + if (g & 0b0100) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b2); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b2); + } + if (g & 0b1000) { + vec_lut[g] = _mm256_add_ps(vec_lut[g], vec_b3); + } else { + vec_lut[g] = _mm256_sub_ps(vec_lut[g], vec_b3); + } + } + PRAGMA_UNROLL + for (int g = 0; g < 16; g += 2) { + // vec_lut[g] = -vec_lut[15 - g]; + const __m256 neg_mask = _mm256_set1_ps(-0.0f); // all lanes have sign bit set + vec_lut[g] = _mm256_xor_ps(vec_lut[15 - g], neg_mask); + } + + biases += _mm256_addv_ps(vec_lut[0]); + + PRAGMA_UNROLL + for (int g = 0; g < 16; ++g) { + vec_lut[g] = _mm256_mul_ps(vec_lut[g], _mm256_set1_ps(t_scales)); + } + + __m256i vec_qlut[4]; + const __m256i shuf = _mm256_setr_epi8(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15); + PRAGMA_UNROLL + for (int g = 0; g < 4; g += 1) { + __m256i i0 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 0], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i1 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 1], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i2 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 2], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m256i i3 = _mm256_cvtps_epi32(_mm256_round_ps(vec_lut[g * 4 + 3], _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + + i0 = _mm256_packs_epi32(i0, i1); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32(i2, i3); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // 
Convert int16 to int8 + i0 = _mm256_packs_epi16(i0, i2); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + vec_qlut[g] = _mm256_shuffle_epi8(i0, shuf); // 0, 8, 16, 24, 1, 9, 17, 25, 2, 10, 18, 26, 3, 11, 19, 27, 4, 12, 20, 28, 5, 13, 21, 29, 6, 14, 22, 30, 7, 15, 23, 31 + } + + int32_t* qlut_i32 = reinterpret_cast(qlut); + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 0 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 0); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 1 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 1); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 2 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 2); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 3 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 3); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 4 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 4); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 5 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 5); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 6 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 6); + } + PRAGMA_UNROLL + for (int g = 0; g < 4; ++g) { + qlut_i32[k * 32 + 7 * 4 + g] = _mm256_extract_epi32(vec_qlut[g], 7); + } + } + + *lut_scales = scales; + *lut_biases = biases; +} + +// based on lut_ctor_g4_int8_impl +void +GenerateLUT_avx2( + const float* b, + int8_t* qlut, + float* lut_scales, + float* lut_biases, + size_t M, + size_t K, + size_t N, + size_t act_group_size, + size_t lut_stride +) +{ + (void)M; // silence unused parameter warning + (void)N; // silence unused parameter warning + // TODO: handle bitnet here + const int32_t kk_outer_max = static_cast(K / act_group_size); + const int32_t ags_div32 = static_cast(act_group_size / 32); + + for (int32_t kk_outer = 0; kk_outer < kk_outer_max; ++kk_outer) { + // compute 
partial max - directly reset scale to 0.0 + lut_scales[kk_outer] = 0.0f; // partial max reset + for (int32_t k_outer = 0; k_outer < ags_div32; ++k_outer) { + partial_max_g4_int8_k8(&lut_scales[kk_outer], &b[(kk_outer * act_group_size) + (k_outer * 32)]); + } + } + + for (int32_t k_outer_1 = 0; k_outer_1 < kk_outer_max; ++k_outer_1) { + // Use the explicit lut_stride provided by the dispatch/caller to ensure + // consistent memory layout between construction and compute paths. + lut_ctor_g4_int8_impl(static_cast(act_group_size), (&(qlut[(k_outer_1 * lut_stride)])), (&(b[(k_outer_1 * act_group_size)])), (&(lut_scales[k_outer_1])), (&(lut_biases[k_outer_1]))); + } +} + +inline void +tbl_g4_int8_float_gather_bit2_impl(int32_t m, float* C_global, float* CBits, float* C) +{ + constexpr int32_t bits = 2; + + int32_t m_c_outer_max = m / 32; + for (int32_t m_c_outer = 0; m_c_outer < m_c_outer_max; ++m_c_outer) { + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + PRAGMA_UNROLL + for (int32_t m_c_inner = 0; m_c_inner < 32; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (float)5.000000e-01f) + (CBits[cse_var_2 + bit_offset_1]); + } + } + + // Handle tail cases where m is not a multiple of 32. + // This ensures C_global is fully initialized for all m elements. 
+ int32_t m_tail = m % 32; + if (m_tail > 0) { + int32_t m_c_outer = m_c_outer_max; + int32_t cse_var_2 = (m_c_outer * 32 * bits); + int32_t cse_var_1 = (m_c_outer * 32); + for (int32_t m_c_inner = 0; m_c_inner < m_tail; ++m_c_inner) { + int32_t bit_offset_0 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8); + int32_t bit_offset_1 = (m_c_inner / 8) * 8 * bits + (m_c_inner % 8) + 8; + C_global[cse_var_1 + m_c_inner] = (CBits[cse_var_2 + bit_offset_0] * (float)5.000000e-01f) + (CBits[cse_var_2 + bit_offset_1]); + } + } + + for (int32_t m_inner_outer = 0; m_inner_outer < m_c_outer_max; ++m_inner_outer) { + PRAGMA_UNROLL + for (int32_t m_inner = 0; m_inner < 32; ++m_inner) { + int offset = m_inner_outer * 32 + m_inner; + C[offset] = C_global[offset]; + } + } + + // Transfer the remaining tail results from C_global to the final output matrix C. + // This is necessary when m is not a multiple of 32, ensuring all output features + // are correctly written to the destination buffer. + if (m_tail > 0) { + int offset_base = m_c_outer_max * 32; + for (int32_t m_inner = 0; m_inner < m_tail; ++m_inner) { + int offset = offset_base + m_inner; + C[offset] = C_global[offset]; + } + } +} + +// When FastAggregation is enabled, FastAggregationK = ActK +// zero_points is merged into scales to maintain API +template +inline int32_t +tbl_g4_int8_float_update_impl(int32_t m, float* c, const int8_t* lut, const uint8_t* a, const float* scales, const float* lut_scales, const float* lut_biases) +{ + const __m128i vec_mask = _mm_set1_epi8(0x0f); + __m128i vec_lut[K]; + + PRAGMA_UNROLL + for (int k = 0; k < K; k++) { + vec_lut[k] = _mm_loadu_si128(reinterpret_cast(lut + k * 16)); + } + + SignedAdder adder; + for (int i = 0; i < m / 2; i += 16) { + __m256 vec_c0 = _mm256_setzero_ps(); + __m256 vec_c1 = _mm256_setzero_ps(); + __m256 vec_c2 = _mm256_setzero_ps(); + __m256 vec_c3 = _mm256_setzero_ps(); + + float partial_sum = -0.0f; + PRAGMA_UNROLL + for (int kk = 0; kk < K; kk += ActK) { + 
PRAGMA_UNROLL + for (int k = 0; k < ActK; k++) { + // (M // bm, KK / K / 4, bm / 16 / 2, K * 16) + __m128i vec_as = _mm_loadu_si128(reinterpret_cast(a + i * K + (kk + k) * 16)); + __m128i vec_a_bot = _mm_and_si128(vec_as, vec_mask); + __m128i vec_a_top = _mm_and_si128(_mm_srli_epi16(vec_as, 4), vec_mask); + + __m256i vec_lut_ = _mm256_set_m128i(vec_lut[kk + k], vec_lut[kk + k]); + __m256i vec_a = _mm256_set_m128i(vec_a_top, vec_a_bot); + __m256i vec_v = _mm256_shuffle_epi8(vec_lut_, vec_a); + adder.push(vec_v, k); + } + + __m256 vec_v_low_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_low())); + __m256 vec_v_low_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_low())); + __m256 vec_v_high_low = _mm256_cvtepi32_ps(extract_low_epi16_epi32(adder.get_high())); + __m256 vec_v_high_high = _mm256_cvtepi32_ps(extract_high_epi16_epi32(adder.get_high())); + + float lut_s = lut_scales[kk / (ActK * 4)]; + float lut_b = lut_biases[kk / (ActK * 4)]; + + partial_sum += lut_b; + + if (FastAggregation) { + lut_s = lut_s * ActK; + lut_b -= lut_s * (mylog2::value / 4 * get_bias_scale()); + } + +#define lut_fma(vs, ib) \ + ((ib) % Bits) ? 
(_mm256_mul_ps((vs), _mm256_set1_ps(lut_s))) \ + : (_mm256_fmadd_ps((vs), _mm256_set1_ps(lut_s), _mm256_set1_ps(lut_b))) + if (kk == 0) { + vec_c0 = lut_fma(vec_v_low_low, (i / 4)); + vec_c1 = lut_fma(vec_v_low_high, (i / 4 + 1)); + vec_c2 = lut_fma(vec_v_high_low, (i / 4 + 2)); + vec_c3 = lut_fma(vec_v_high_high, (i / 4 + 3)); + } else { + vec_c0 = _mm256_add_ps(vec_c0, lut_fma(vec_v_low_low, (i / 4))); + vec_c1 = _mm256_add_ps(vec_c1, lut_fma(vec_v_low_high, (i / 4 + 1))); + vec_c2 = _mm256_add_ps(vec_c2, lut_fma(vec_v_high_low, (i / 4 + 2))); + vec_c3 = _mm256_add_ps(vec_c3, lut_fma(vec_v_high_high, (i / 4 + 3))); + } +#undef lut_fma + } + + if (ZeroPoint) { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 16); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16); + vec_c0 = _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2)); + vec_c1 = _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8)); + vec_c2 = _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16)); + vec_c3 = _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24)); + __m256 vec_z0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 16 + 8); + __m256 vec_z1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 16 + 8); + __m256 vec_z2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 16 + 8); + __m256 vec_z3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 16 + 8); + partial_sum *= 2; +#define add_zero(cs, zs, ib) \ + ((ib) % Bits) ? 
((cs)) \ + : (_mm256_fmadd_ps((zs), _mm256_set1_ps(partial_sum), (cs))) + _mm256_storeu_ps(c + i * 2, add_zero(vec_c0, vec_z0, (i / 4))); + _mm256_storeu_ps(c + i * 2 + 8, add_zero(vec_c1, vec_z1, (i / 4 + 1))); + _mm256_storeu_ps(c + i * 2 + 16, add_zero(vec_c2, vec_z2, (i / 4 + 2))); + _mm256_storeu_ps(c + i * 2 + 24, add_zero(vec_c3, vec_z3, (i / 4 + 3))); +#undef add_zero + } else if (OneScale) { + float single_scale = scales[0]; + __m256 vec_s = _mm256_set1_ps(single_scale); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s, _mm256_loadu_ps(c + i * 2 + 24))); + } else { + __m256 vec_s0 = _mm256_loadu_ps(scales + ((i / 4) / Bits) * 8); + __m256 vec_s1 = _mm256_loadu_ps(scales + ((i / 4 + 1) / Bits) * 8); + __m256 vec_s2 = _mm256_loadu_ps(scales + ((i / 4 + 2) / Bits) * 8); + __m256 vec_s3 = _mm256_loadu_ps(scales + ((i / 4 + 3) / Bits) * 8); + _mm256_storeu_ps(c + i * 2, _mm256_fmadd_ps(vec_c0, vec_s0, _mm256_loadu_ps(c + i * 2))); + _mm256_storeu_ps(c + i * 2 + 8, _mm256_fmadd_ps(vec_c1, vec_s1, _mm256_loadu_ps(c + i * 2 + 8))); + _mm256_storeu_ps(c + i * 2 + 16, _mm256_fmadd_ps(vec_c2, vec_s2, _mm256_loadu_ps(c + i * 2 + 16))); + _mm256_storeu_ps(c + i * 2 + 24, _mm256_fmadd_ps(vec_c3, vec_s3, _mm256_loadu_ps(c + i * 2 + 24))); + } + } + + return 0; +} + +int32_t +tbl_int32_reset(int32_t m, int32_t* c) +{ + memset(c, 0, m * sizeof(int32_t)); + return 0; +} + +// based on qgemm_lut_int8_g4 +// Simplified version with hardcoded configuration for 2-bit quantization +void +TMACComputeGemm_avx2( + const uint8_t* A, // Quantized packed weights + const float* Scales, // Weight scales (and optionally zero-points) + const int8_t* LUT, // Pre-computed quantized lookup 
table + const float* LUT_Scales, // LUT scales from activation quantization + const float* LUT_Biases, // LUT biases from activation quantization + float* C, // Output buffer + int K, + int M, + int N, + int TotalN, + size_t BlkLen, // Weight quantization group size (q_group_size) + bool HasZeroPoint +) +{ + // Validate batch size (M) + // For now, TMAC AVX2 kernel processes one batch row at a time. + if (M != 1) { + MLAS_THROW_EX(std::runtime_error, "M > 1 is not supported yet in TMAC AVX2 kernel"); + } + + // get kernel config using the total output features (TotalN) + // This matches the parameters used during weight packing. + const MlasTMACKernelParams& tmac_params = MlasGetLutGemmKernelParams(TotalN, K, 2, BlkLen, HasZeroPoint); + + // ==================== CONFIGURATION ==================== + // Fixed parameters for this kernel implementation + bool has_zero_point = tmac_params.has_zero_point; // Whether weights have zero-points (interleaved with scales) + bool one_scale = tmac_params.one_scale; // Whether using single global scale for all weights + + const int32_t bits = static_cast(tmac_params.bits); // 2-bit quantization + const int32_t g = static_cast(tmac_params.g); // Packing group size + const int32_t ngroups_per_elem = static_cast(tmac_params.ngroups_per_elem); // 8 / g = 2 + const int32_t kfactor = static_cast(tmac_params.kfactor); // K-dimension blocking factor + + const bool has_scale = tmac_params.has_scale; // Always use weight scales + + // Parameters derived from inputs + const int32_t q_group_size = static_cast(tmac_params.q_group_size); // Weight quant group size + const int32_t act_group_size = static_cast(tmac_params.act_group_size); // Activation group size (same as weight) + const int32_t actk = static_cast(tmac_params.actk); // CRITICAL: = 16 for BlkLen=64, NOT BlkLen! + + const int32_t bm = static_cast(tmac_params.bm); + // m is the number of output features this kernel tile produces. 
+ // We clamp m by N (the number of features in the current chunk) to ensure + // we don't read or write past the tile boundary during the gather phase. + int32_t m_full = bm / bits; + int32_t m = std::min(m_full, N); + + // Validate configuration + assert(bm % bits == 0); + assert(K % (kfactor * g) == 0); + assert(BlkLen % g == 0); + + // Validate configuration + assert(bm % bits == 0); + assert(K % (kfactor * g) == 0); + assert(BlkLen % g == 0); + + // ==================== ALLOCATE BUFFERS ==================== + // Use float for now (can be changed to _Float16 if needed) + + float* CBits = new float[bm]; + float* C_global = new float[m]; + + // Explicitly zero-initialize accumulation buffers to ensure determinism. + memset(CBits, 0, bm * sizeof(float)); + memset(C_global, 0, m * sizeof(float)); + + // ==================== CALCULATE LOOP PARAMETERS ==================== + const int32_t k_outer_max = K / (kfactor * g); + const int32_t scale_gs = q_group_size / (kfactor * g); + + // Calculate bit shift for scale indexing + int32_t scale_idx_shfr = 0; + if (scale_gs == 1) { + scale_idx_shfr = 0; + } else if (scale_gs == 2) { + scale_idx_shfr = 1; + } else if (scale_gs == 4) { + scale_idx_shfr = 2; + } else if (scale_gs == 8) { + scale_idx_shfr = 3; + } else { + MLAS_THROW_EX(std::runtime_error, + ("Unsupported scale_gs=" + std::to_string(scale_gs) + + " (q_group_size=" + std::to_string(q_group_size) + + ", kfactor=" + std::to_string(kfactor) + + ", g=" + std::to_string(g) + "). Expected {1,2,4,8}.").c_str()); + } + + // ==================== MAIN COMPUTATION LOOP ==================== + for (int32_t k_outer = 0; k_outer < k_outer_max; k_outer++) { + // Calculate pointers for this K-outer iteration + const uint8_t* a = A + k_outer * bm * kfactor / ngroups_per_elem; + + // Calculate scales pointer based on configuration + const float* scales = one_scale ? reinterpret_cast(Scales) : // Single global scale + (has_zero_point ? 
reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m * 2 : // Scale + zero_point pairs + reinterpret_cast(Scales) + (k_outer >> scale_idx_shfr) * m); // Scales only + + // Calculate LUT pointers + const int8_t* lut = reinterpret_cast(LUT) + k_outer * kfactor * (1 << g); // 2^g = 16 for g=4 + const float* lut_scales = reinterpret_cast(LUT_Scales) + + (k_outer * kfactor * g / act_group_size); + const float* lut_biases = reinterpret_cast(LUT_Biases) + + (k_outer * kfactor * g / act_group_size); + + // Select appropriate kernel template based on configuration + // For standard 2-bit, kfactor=16, BlkLen=64: actk = 64/4 = 16 + if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 16 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } + // actk == 8 variants (for BlkLen=32) + else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 16 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } + // kfactor == 8 variants + else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && 
has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && !one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else if (has_scale && kfactor == 8 && bits == 2 && actk == 8 && !has_zero_point && one_scale) { + tbl_g4_int8_float_update_impl( + static_cast(bm), CBits, lut, a, scales, lut_scales, lut_biases + ); + } else { + // No matching kernel template found + MLAS_THROW_EX(std::runtime_error, "No matching kernel found for T-MAC GEMM"); + } + } + + // ==================== GATHER RESULTS ==================== + // Gather bit-plane results into final output + // Only support 2-bit in this implementation + // TODO(vraspar): extend to other bit-widths + tbl_g4_int8_float_gather_bit2_impl(m, C_global, CBits, C); + + // ==================== CLEANUP ==================== + delete[] C_global; + delete[] CBits; +} + +// Kernel dispatch structure definition. + +const MLAS_QNBIT_LUT_GEMM_DISPATCH MlasLutGenKernelAvx2 = []() { + MLAS_QNBIT_LUT_GEMM_DISPATCH d; + d.GenerateLUT = GenerateLUT_avx2; + d.ComputeGemm = TMACComputeGemm_avx2; + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h new file mode 100644 index 0000000000000..e66eec6fd67ea --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_lut_kernel_avx2.h @@ -0,0 +1,43 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_lut_kernel_avx2.h + +Abstract: + + This module implements x64 AVX2 kernel functions for LUT-based n-bit + quantized integer matrix multiplication. 
+--*/ + +#pragma once +#include "qnbitgemm.h" + +void +GenerateLUT_avx2( + int32_t group_size, + int8_t lut, + const float* b, + float* scales, + float* biases, + int K +); + +void +TMACComputeGemm_avx2( + const void* A, + const void* a_scales, + const void* LUT, + const void* LUT_Scales, + const void* LUT_Biases, + void* C, + int bm, + int K, + int M, + int N, + size_t BlkLen +); diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a1859b9d7071b..faeee1abd07fc 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -81,7 +81,7 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& return true; } -bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& /*logger*/) const { +bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const { if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "Clip", {1, 6, 11, 12, 13}) || !graph_utils::IsSupportedProvider(node, {kCpuExecutionProvider}) || !optimizer_utils::CheckOutputEdges(graph, node, 1)) { @@ -95,6 +95,10 @@ bool ClipQuantFusion::SatisfyCondition(const Graph& graph, const Node& node, con return false; } + if (!graph_utils::CanRemoveNode(graph, node, logger)) { + return false; + } + return true; } diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 078fbe8ed0478..05b337d9933fb 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -758,7 +758,14 @@ bool BatchNormalizationNodeGroupSelector::Check(const GraphViewer& graph_viewer, const Node* 
redundant_clip_node, const std::vector& dq_nodes, const std::vector& q_nodes) const { - if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, 3)) { + // BatchNormalization has 5 inputs: x, scale, bias, mean, var. + // Require DQ on x and scale (indices 0,1). mean, var may optionally have DQ. + const int num_dq_nodes = gsl::narrow_cast(dq_nodes.size()); + if (num_dq_nodes < 3 || num_dq_nodes > 5) { + return false; + } + + if (!CheckQDQNodes(graph_viewer, node, redundant_clip_node, dq_nodes, q_nodes, num_dq_nodes)) { return false; } diff --git a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc index 655364357999a..3727ac0918115 100644 --- a/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/skip_layer_norm_fusion.cc @@ -254,6 +254,21 @@ Status SkipLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_le continue; } + // SkipLayerNormalization kernel requires gamma and beta to be 1D. + // Skip fusion if gamma or beta have more than 1 dimension. + const NodeArg* gamma_arg = ln_node.MutableInputDefs()[1]; + const TensorShapeProto* gamma_shape = gamma_arg->Shape(); + if (gamma_shape != nullptr && gamma_shape->dim_size() != 1) { + continue; + } + if (ln_node.MutableInputDefs().size() > 2) { + const NodeArg* beta_arg = ln_node.MutableInputDefs()[2]; + const TensorShapeProto* beta_shape = beta_arg->Shape(); + if (beta_shape != nullptr && beta_shape->dim_size() != 1) { + continue; + } + } + NodeArg beta_place_holder("", nullptr); // Get the inputs for the new SkipLayerNormalization node. 
diff --git a/onnxruntime/core/platform/linux/device_discovery.cc b/onnxruntime/core/platform/linux/device_discovery.cc index e9c45a6966ef8..db6ac73996863 100644 --- a/onnxruntime/core/platform/linux/device_discovery.cc +++ b/onnxruntime/core/platform/linux/device_discovery.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include #include "core/common/common.h" @@ -114,6 +115,28 @@ std::optional IsGpuDiscrete(uint16_t vendor_id, uint16_t device_id) { return std::nullopt; } +Status GetPciBusId(const std::filesystem::path& sysfs_path, std::optional& pci_bus_id) { + constexpr const char* regex_pattern{R"([0-9a-f]+:[0-9a-f]+:[0-9a-f]+[.][0-9a-f]+)"}; + static const std::regex pci_bus_id_regex(regex_pattern); + + std::error_code error_code; + auto pci_bus_id_path = std::filesystem::canonical(sysfs_path / "device", error_code); // resolves symlink to PCI bus id, e.g. 0000:65:00.0 + ORT_RETURN_IF_ERROR(ErrorCodeToStatus(error_code)); + + auto pci_bus_id_filename = pci_bus_id_path.filename(); + if (std::regex_match(pci_bus_id_filename.string(), pci_bus_id_regex)) { + pci_bus_id = pci_bus_id_filename.string(); + } else { + pci_bus_id = {}; + LOGS_DEFAULT(WARNING) << MakeString("Skipping pci_bus_id for PCI path at \"", + pci_bus_id_path.string(), + "\" because filename \"", pci_bus_id_filename, "\" dit not match expected pattern of ", + regex_pattern); + }; + + return Status::OK(); +} + Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevice& gpu_device_out) { OrtHardwareDevice gpu_device{}; const auto& sysfs_path = path_info.path; @@ -140,6 +163,12 @@ Status GetGpuDeviceFromSysfs(const GpuSysfsPathInfo& path_info, OrtHardwareDevic gpu_device.metadata.Add("Discrete", (*is_gpu_discrete ? 
"1" : "0")); } + std::optional pci_bus_id; + ORT_RETURN_IF_ERROR(GetPciBusId(sysfs_path, pci_bus_id)); + if (pci_bus_id) { + gpu_device.metadata.Add("pci_bus_id", std::move(*pci_bus_id)); + } + gpu_device.type = OrtHardwareDeviceType_GPU; gpu_device_out = std::move(gpu_device); diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 9b71f4ba2ebec..6d5a400be703b 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -3,6 +3,10 @@ #include "core/platform/windows/telemetry.h" #include +#include +#include +#include +#include #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -51,6 +55,80 @@ TRACELOGGING_DEFINE_PROVIDER(telemetry_provider_handle, "Microsoft.ML.ONNXRuntim // {3a26b1ff-7484-7484-7484-15261f42614d} (0x3a26b1ff, 0x7484, 0x7484, 0x74, 0x84, 0x15, 0x26, 0x1f, 0x42, 0x61, 0x4d), TraceLoggingOptionMicrosoftTelemetry()); + +std::string ConvertWideStringToUtf8(const std::wstring& wide) { + if (wide.empty()) + return {}; + + const UINT code_page = CP_UTF8; + const DWORD flags = 0; + LPCWCH const src = wide.data(); + const int src_len = static_cast(wide.size()); + int utf8_length = ::WideCharToMultiByte(code_page, flags, src, src_len, nullptr, 0, nullptr, nullptr); + if (utf8_length == 0) + return {}; + + std::string utf8(utf8_length, '\0'); + if (::WideCharToMultiByte(code_page, flags, src, src_len, utf8.data(), utf8_length, nullptr, nullptr) == 0) + return {}; + + return utf8; +} + +std::string GetServiceNamesForCurrentProcess() { + static std::once_flag once_flag; + static std::string service_names; + + std::call_once(once_flag, [] { + SC_HANDLE service_manager = ::OpenSCManagerW(nullptr, nullptr, SC_MANAGER_ENUMERATE_SERVICE); + if (service_manager == nullptr) + return; + + DWORD bytes_needed = 0; + DWORD services_returned = 0; + DWORD resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, 
SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, nullptr, 0, &bytes_needed, + &services_returned, &resume_handle, nullptr) && + ::GetLastError() != ERROR_MORE_DATA) { + ::CloseServiceHandle(service_manager); + return; + } + + if (bytes_needed == 0) { + ::CloseServiceHandle(service_manager); + return; + } + + std::vector buffer(bytes_needed); + auto* services = reinterpret_cast(buffer.data()); + services_returned = 0; + resume_handle = 0; + if (!::EnumServicesStatusExW(service_manager, SC_ENUM_PROCESS_INFO, SERVICE_WIN32, SERVICE_ACTIVE, reinterpret_cast(services), + bytes_needed, &bytes_needed, &services_returned, &resume_handle, nullptr)) { + ::CloseServiceHandle(service_manager); + return; + } + + DWORD current_pid = ::GetCurrentProcessId(); + std::wstring aggregated; + bool first = true; + for (DWORD i = 0; i < services_returned; ++i) { + if (services[i].ServiceStatusProcess.dwProcessId == current_pid) { + if (!first) { + aggregated.push_back(L','); + } + aggregated.append(services[i].lpServiceName); + first = false; + } + } + + ::CloseServiceHandle(service_manager); + + service_names = ConvertWideStringToUtf8(aggregated); + }); + + return service_names; +} } // namespace #ifdef _MSC_VER @@ -178,6 +256,7 @@ void WindowsTelemetry::LogProcessInfo() const { #if BUILD_INBOX isRedist = false; #endif + const std::string service_names = GetServiceNamesForCurrentProcess(); TraceLoggingWrite(telemetry_provider_handle, "ProcessInfo", TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), @@ -189,7 +268,8 @@ void WindowsTelemetry::LogProcessInfo() const { TraceLoggingString(ORT_VERSION, "runtimeVersion"), TraceLoggingBool(IsDebuggerPresent(), "isDebuggerAttached"), TraceLoggingBool(isRedist, "isRedist"), - TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName"), + TraceLoggingString(service_names.c_str(), "serviceNames")); process_info_logged = true; } @@ -204,7 +284,8 @@ void 
WindowsTelemetry::LogSessionCreationStart(uint32_t session_id) const { TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), TraceLoggingUInt32(session_id, "sessionId"), - TraceLoggingLevel(WINEVENT_LEVEL_INFO)); + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogEvaluationStop(uint32_t session_id) const { @@ -278,6 +359,7 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio execution_provider_string += i; } + const std::string service_names = GetServiceNamesForCurrentProcess(); // Difference is MeasureEvent & isCaptureState, but keep in sync otherwise if (!captureState) { TraceLoggingWrite(telemetry_provider_handle, @@ -304,7 +386,9 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { TraceLoggingWrite(telemetry_provider_handle, "SessionCreation_CaptureState", @@ -330,7 +414,9 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio TraceLoggingString(model_weight_hash.c_str(), "modelWeightHash"), TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"), + TraceLoggingString(service_names.c_str(), "serviceNames"), + 
TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } } @@ -419,7 +505,8 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), - TraceLoggingInt32(line, "line")); + TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #else TraceLoggingWrite(telemetry_provider_handle, "RuntimeError", @@ -435,7 +522,8 @@ void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status TraceLoggingString(status.ErrorMessage().c_str(), "errorMessage"), TraceLoggingString(file, "file"), TraceLoggingString(function, "function"), - TraceLoggingInt32(line, "line")); + TraceLoggingInt32(line, "line"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); #endif } @@ -465,7 +553,8 @@ void WindowsTelemetry::LogRuntimePerf(uint32_t session_id, uint32_t total_runs_s TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingUInt32(total_runs_since_last, "totalRuns"), TraceLoggingInt64(total_run_duration_since_last, "totalRunDuration"), - TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize")); + TraceLoggingString(total_duration_per_batch_size.c_str(), "totalRunDurationPerBatchSize"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogExecutionProviderEvent(LUID* adapterLuid) const { @@ -541,7 +630,8 @@ void WindowsTelemetry::LogAutoEpSelection(uint32_t session_id, const std::string TraceLoggingUInt32(session_id, "sessionId"), TraceLoggingString(selection_policy.c_str(), "selectionPolicy"), TraceLoggingString(requested_execution_provider_string.c_str(), "requestedExecutionProviderIds"), - TraceLoggingString(available_execution_provider_string.c_str(), "availableExecutionProviderIds")); + TraceLoggingString(available_execution_provider_string.c_str(), 
"availableExecutionProviderIds"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const std::string& provider_options_string, bool captureState) const { @@ -560,7 +650,8 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(provider_id.c_str(), "providerId"), - TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + TraceLoggingString(provider_options_string.c_str(), "providerOptions"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } else { TraceLoggingWrite(telemetry_provider_handle, "ProviderOptions_CaptureState", @@ -572,7 +663,8 @@ void WindowsTelemetry::LogProviderOptions(const std::string& provider_id, const // Telemetry info TraceLoggingUInt8(0, "schemaVersion"), TraceLoggingString(provider_id.c_str(), "providerId"), - TraceLoggingString(provider_options_string.c_str(), "providerOptions")); + TraceLoggingString(provider_options_string.c_str(), "providerOptions"), + TraceLoggingString(ORT_CALLER_FRAMEWORK, "frameworkName")); } } diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index db96089f7d053..0ad8d1d4fef4d 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -1439,6 +1439,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, int16_t, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Int4x2, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, UInt4x2, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, 
Int2x4, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, UInt2x4, DequantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Float8E4M3FNUZ, DequantizeLinear); @@ -1451,6 +1453,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, int16_t, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Int4x2, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, UInt4x2, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Int2x4, QuantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, UInt2x4, QuantizeLinear); #if !defined(DISABLE_FLOAT8_TYPES) class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Float8E4M3FN, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 25, Float8E4M3FNUZ, QuantizeLinear); @@ -3557,6 +3561,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { DequantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, @@ -3579,6 +3587,10 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { QuantizeLinear)>, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, #if !defined(DISABLE_FLOAT8_TYPES) BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc b/onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc index af67419f4fb91..60ebf862e1601 100644 --- 
a/onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc +++ b/onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc @@ -73,10 +73,10 @@ common::Status ArrayFeatureExtractorOp::Compute(OpKernelContext* context) con } for (int64_t i = 0; i < num_indices; ++i) { - if (y_data[i] >= stride) { + if (y_data[i] < 0 || y_data[i] >= stride) { return ORT_MAKE_STATUS( ONNXRUNTIME, INVALID_ARGUMENT, - "Invalid Y argument: index is out of range: Y[", i, "] (", y_data[i], ") >=", stride); + "Invalid Y argument: index is out of range: Y[", i, "] (", y_data[i], ") must be in [0, ", stride, ")"); } } diff --git a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h index e26eae19b8fd4..c1a51c802f9d8 100644 --- a/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h +++ b/onnxruntime/core/providers/cpu/quantization/matmul_integer_base.h @@ -1,6 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include +#include +#include +#include + +#include "core/common/cpuid_info.h" #include "core/framework/op_kernel.h" #include "core/mlas/inc/mlas.h" #include "core/providers/common.h" @@ -20,6 +27,11 @@ class MatMulIntegerBase : public OpKernel { // only pack Matrix B if (input_idx == GetBIdx()) { +#if defined(USE_KLEIDIAI) + if (TryKleidiaiDynamicPrePack(tensor, input_idx, alloc, is_packed, prepacked_weights)) { + return Status::OK(); + } +#endif // Only handle the common case of a 2D weight matrix. Additional matrices // could be handled by stacking the packed buffers. 
b_shape_ = tensor.Shape(); @@ -89,6 +101,248 @@ class MatMulIntegerBase : public OpKernel { return false; } + virtual int GetBScaleIdx() const { + return -1; + } + + virtual int GetBZeroPointIdx() const { + return -1; + } + + virtual int GetBiasIdx() const { + return -1; + } + + virtual bool SupportsKleidiaiDynamicQuant() const { + return false; + } + + bool can_use_dynamic_quant_mlas_{false}; + +#if defined(USE_KLEIDIAI) + struct KleidiaiDynamicPackContext { + const Tensor* scale{nullptr}; + const Tensor* bias{nullptr}; + const uint8_t* b_data{nullptr}; + size_t K{0}; + size_t N{0}; + std::optional transposed_buffer; + }; + /* + Helper method to pre-pack Matrix B using Arm® KleidiAI™ packing if eligible. + + Returns false if KleidiAI dynamic quantization is not supported or the index of the input tensor is not input B's index. + If these checks pass, prepares a dynamic quantization pack context and calls PrepareKleidiaiDynamicPack for further policies. + If those policies also satisfy, it calls the helper to execute the pre-packing in KleidiAI context. + Returns true if pre-packing was performed and false otherwise. + */ + bool TryKleidiaiDynamicPrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* prepacked_weights) { + if (!SupportsKleidiaiDynamicQuant() || input_idx != GetBIdx()) { + return false; + } + + KleidiaiDynamicPackContext ctx; + if (!PrepareKleidiaiDynamicPack(tensor, alloc, ctx)) { + return false; + } + + return ExecuteKleidiaiDynamicPack(ctx, alloc, is_packed, prepacked_weights); + } + /* + Helper method to determine if Arm® KleidiAI™ dynamic quantization pre-packing policies are satisfied. + + Checks for the presence of the constant input tensor B, symmetry on the zero point and validity of the scales. + Also checks if the shape of the tensor B is supported by KleidiAI and if bias tensor is also a constant input. + Makes B transposition if necessary. 
+ Sets can_use_dynamic_quant_mlas_ flag accordingly and returns true if all policies are satisfied. + */ + bool PrepareKleidiaiDynamicPack(const Tensor& tensor, + AllocatorPtr alloc, + KleidiaiDynamicPackContext& ctx) { + can_use_dynamic_quant_mlas_ = false; + dynamic_quant_mlas_bias_data_was_packed_ = false; + + ctx.scale = GetConstantInputTensor(GetBScaleIdx()); + if (ctx.scale == nullptr) { + return false; + } + + if (!IsZeroPointSymmetric()) { + return false; + } + + if (!AreScalesValid(*ctx.scale)) { + return false; + } + + if (!IsBShapeSupportedForDynamicQuant(tensor.Shape())) { + return false; + } + + ctx.bias = GetConstantInputTensor(GetBiasIdx()); + + ctx.K = static_cast(b_shape_[0]); + ctx.N = static_cast(b_shape_[1]); + ctx.b_data = static_cast(tensor.DataRaw()); + + if (IsBTransposed()) { + std::swap(ctx.K, ctx.N); + ctx.b_data = quantization::TransPoseInputData(ctx.b_data, ctx.transposed_buffer, alloc, ctx.N, ctx.K); + } + + // KleidiAI dynamic-qgemm packing is not expected to handle degenerate shapes. + // If K==0 there is nothing to reduce over, and the RHS packer may dereference invalid memory. + if (ctx.K == 0 || ctx.N == 0) { + return false; + } + + if (ctx.bias != nullptr) { + dynamic_quant_mlas_bias_data_was_packed_ = true; + } + + can_use_dynamic_quant_mlas_ = true; + return true; + } + /* + Helper method to execute Arm® KleidiAI™ dynamic quantization pre-packing. + + If can_use_dynamic_quant_mlas_ flag was true from previous policy controls then it checks the packed + RHS matrix size in bytes and allocates the packed buffer. If the size is 0 returns false. + It then assigns the scale and bias data accordingly and calls the packing function. + It caches this pre-packed buffer as Mlas does. 
+ */ + bool ExecuteKleidiaiDynamicPack(const KleidiaiDynamicPackContext& ctx, + AllocatorPtr alloc, + bool& is_packed, + PrePackedWeights* prepacked_weights) { + if (!can_use_dynamic_quant_mlas_) { + return false; + } + + is_packed = false; + + const size_t packed_b_size = MlasDynamicQgemmPackBSize(ctx.N, ctx.K); + if (packed_b_size == 0) { + can_use_dynamic_quant_mlas_ = false; + return false; + } + + packed_b_ = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + memset(packed_b_.get(), 0, packed_b_size); + + const auto scales = static_cast(ctx.scale->Shape().Size()) == ctx.N + ? std::vector(&ctx.scale->Data()[0], + &ctx.scale->Data()[ctx.N]) + : std::vector(ctx.N, ctx.scale->Data()[0]); + + const auto biases = ctx.bias != nullptr + ? std::vector(&ctx.bias->Data()[0], + &ctx.bias->Data()[ctx.N]) + : std::vector(ctx.N, 0.f); + + MlasDynamicQgemmPackB(ctx.N, ctx.K, reinterpret_cast(ctx.b_data), + scales.data(), biases.data(), packed_b_.get()); + + if (prepacked_weights != nullptr) { + prepacked_weights->buffers_.push_back(std::move(packed_b_)); + prepacked_weights->buffer_sizes_.push_back(packed_b_size); + } + + is_packed = true; + return true; + } + /* + Helper for checking the zero points tensor of the input. Arm® KleidiAI™ supports symmetric zero points. + + This helper method checks if the zero point tensor, if it's present in the inputs with its index, it checks the data type either uint8_t or int8_t. + It also checks if all the zero point values are zeros. If not, sets the can_use_dynamic_quant_mlas_ flag to false. + If zero point tensor is not present, it sets the flag true as symmetric zero point is assumed. + Returns the flag. 
+ */ + bool IsZeroPointSymmetric() { + const Tensor* b_zp_constant_tensor = GetConstantInputTensor(GetBZeroPointIdx()); + if (b_zp_constant_tensor != nullptr) { + assert(b_zp_constant_tensor->IsDataType() || b_zp_constant_tensor->IsDataType()); + const auto* zp_bytes = static_cast(b_zp_constant_tensor->DataRaw()); + const size_t zp_size_in_bytes = b_zp_constant_tensor->SizeInBytes(); + return std::none_of(zp_bytes, zp_bytes + zp_size_in_bytes, + [](std::byte v) { return v != std::byte{0}; }); + } + + const auto input_defs = Info().node().InputDefs(); + const int b_zp_idx = GetBZeroPointIdx(); + const bool b_zp_input_exists = b_zp_idx >= 0 && + static_cast(b_zp_idx) < input_defs.size() && + input_defs[b_zp_idx]->Exists(); + return !b_zp_input_exists; + } + /* + Helper method to check the validity of the scales tensor for Arm® KleidiAI™ dynamic quantization. + Scales are invalid and can_use_dynamic_quant_mlas_ flag is false returns if the float scales are non-finite or non-positive. + Otherwise can_use_dynamic_quant_mlas_ flag returned true. + */ + bool AreScalesValid(const Tensor& b_scale_tensor) { + const auto bs = b_scale_tensor.DataAsSpan(); + const bool has_invalid = + std::any_of(bs.begin(), bs.end(), + [](float s) { return !std::isfinite(s) || s <= 0.0f; }); + + return !has_invalid; + } + /* + Helper to promote a 1D tensor to 2D, for Arm® KleidiAI™ dynamic quantization, if necessary. Returns false if the tensor rank is 0. + */ + bool PromoteBShapeIfNeeded() { + if (b_shape_.NumDimensions() == 0) { + return false; // rank-0 tensor is not supported + } + + if (b_shape_.NumDimensions() == 1) { + TensorShapeVector expanded{1, b_shape_[0]}; + b_shape_ = TensorShape(expanded); + } + + return true; + } + /* + Helper method to check the shape policy of the tensor B is passes for Arm® KleidiAI™ dynamic quantization. + The shape should be at least 2D and all the dimensions except the last two should be 1. 1D tensor is promoted to 2D. 
+ */ + bool IsBShapeSupportedForDynamicQuant(const TensorShape& tensor_shape) { + b_shape_ = tensor_shape; + if (!PromoteBShapeIfNeeded()) { + return false; + } + + for (size_t i = 0; i < (b_shape_.NumDimensions() - 2); ++i) { + if (b_shape_[i] != 1) { + return false; + } + } + b_shape_ = tensor_shape; + return true; + } + /* + Checks against the constant initialized tensor index and returns the constant tensor if present. + Returns nullptr if index is invalid or the tensor is not held by the kernel instance. + */ + const Tensor* GetConstantInputTensor(int input_idx) const { + if (input_idx < 0) { + return nullptr; + } + const OrtValue* ort_value = nullptr; + if (!Info().TryGetConstantInput(input_idx, &ort_value)) { + return nullptr; + } + + return &ort_value->Get(); + } + + bool dynamic_quant_mlas_bias_data_was_packed_{false}; +#endif + // Check if quantization parameter of B is supported. // It should be in one of the formats below: // 1. Scalar diff --git a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc index 0f93dd4f476ba..60a64af2e50fc 100644 --- a/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc +++ b/onnxruntime/core/providers/cpu/quantization/quantize_linear.cc @@ -168,6 +168,8 @@ REGISTER_DEQUANTIZELINEAR(uint16_t) REGISTER_DEQUANTIZELINEAR(int32_t) REGISTER_DEQUANTIZELINEAR(Int4x2) REGISTER_DEQUANTIZELINEAR(UInt4x2) +REGISTER_DEQUANTIZELINEAR(Int2x4) +REGISTER_DEQUANTIZELINEAR(UInt2x4) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_DEQUANTIZELINEAR(Float8E4M3FN) REGISTER_DEQUANTIZELINEAR(Float8E4M3FNUZ) @@ -309,7 +311,7 @@ ONNX_CPU_OPERATOR_TYPED_MS_KERNEL( } // namespace contrib #endif // !defined(DISABLE_CONTRIB_OPS) -template +template struct DequantizeLinearApply; // The dimensions before quantize axis and after quantize axis can be flattened. 
@@ -317,8 +319,8 @@ struct DequantizeLinearApply; // If the quantization happens on the first or last axis, the flattened tensor is // effectively rank-2. // For per tensor quantization, the tensor is effectively rank-1. -template -struct DequantizeLinearApply { +template +struct DequantizeLinearApply { /** * @brief Calculate per-tensor/layer or per-axis quantization of DequantizeLinear on the * flattened tensors. @@ -413,24 +415,26 @@ struct DequantizeLinearApply { } }; -template -struct DequantizeLinearApply { - // per-tensor/layer or per-axis quantization +template +struct DequantizeLinearApply { + // per-tensor/layer or per-axis quantization for sub-byte types void op(size_t M, size_t K, size_t N, const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; + constexpr size_t shift_bits = (elements_per_byte == 2) ? 1 : 2; // log2(elements_per_byte) + constexpr size_t mask = elements_per_byte - 1; // For modulo operation for (size_t m = 0; m < M; m++) { for (size_t bd = 0; bd < K; bd++) { - size_t bd_i = bd >> 1; /*bd / 2*/ - size_t bd_j = bd & 0x1; /*bd % 2*/ + size_t bd_i = bd >> shift_bits; // bd / elements_per_byte + size_t bd_j = bd & mask; // bd % elements_per_byte auto zp = zero_point ? 
static_cast(zero_point[bd_i].GetElem(bd_j)) : 0; auto sc = static_cast(scale[bd]); for (size_t bs = 0; bs < N; bs++) { - size_t input_i = input_index >> 1; - size_t input_j = input_index & 0x1; + size_t input_i = input_index >> shift_bits; + size_t input_j = input_index & mask; int32_t val = static_cast(input[input_i].GetElem(input_j)); *output++ = static_cast(static_cast(val - zp) * sc); input_index += 1; @@ -447,6 +451,8 @@ struct DequantizeLinearApply { const T* input, const OutT* scale, OutT* output, const T* zero_point, concurrency::ThreadPool* thread_pool) { ORT_UNUSED_PARAMETER(thread_pool); size_t input_index = 0; + constexpr size_t shift_bits = (elements_per_byte == 2) ? 1 : 2; // log2(elements_per_byte) + constexpr size_t mask = elements_per_byte - 1; // For modulo operation if (zero_point) { size_t zp_index = 0; @@ -456,10 +462,10 @@ struct DequantizeLinearApply { for (size_t qb = 0, qb_end = std::min(quant_block_size, K - bd); qb < qb_end; ++qb) { auto q_zp_index = zp_index; for (size_t bs = 0; bs < N; ++bs, ++input_index, ++q_zp_index) { - auto zp = static_cast(zero_point[q_zp_index >> 1].GetElem(q_zp_index & 0x1)); + auto zp = static_cast(zero_point[q_zp_index >> shift_bits].GetElem(q_zp_index & mask)); auto sc = static_cast(scale[bs]); - int32_t val = static_cast(input[input_index >> 1].GetElem(input_index & 0x1)); + int32_t val = static_cast(input[input_index >> shift_bits].GetElem(input_index & mask)); *output++ = static_cast(static_cast(val - zp) * sc); } } @@ -475,7 +481,7 @@ struct DequantizeLinearApply { for (size_t bs = 0; bs < N; ++bs, ++input_index) { auto sc = static_cast(scale[bs]); - int32_t val = static_cast(input[input_index >> 1].GetElem(input_index & 0x1)); + int32_t val = static_cast(input[input_index >> shift_bits].GetElem(input_index & mask)); *output++ = static_cast(static_cast(val) * sc); } } @@ -492,8 +498,8 @@ struct DequantizeLinearApply { #if !defined(DISABLE_FLOAT8_TYPES) #define DEQUANTIZE_LINEAR_APPLY_FLOAT8(T) \ - 
template \ - struct DequantizeLinearApply { \ + template \ + struct DequantizeLinearApply { \ /* Per-tensor/layer or per-axis quantization */ \ void op(size_t M, size_t K, size_t N, \ const T* input, const OutT* scale, OutT* output, const T*, concurrency::ThreadPool*) { \ @@ -561,38 +567,42 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { const auto to = x_scale.GetElementType(); const T* input = x.Data(); - constexpr bool is_4bit = boost::mp11::mp_contains, T>::value; + constexpr bool is_sub_byte = boost::mp11::mp_contains, T>::value; + // Determine elements_per_byte: Int4x2/UInt4x2 = 2, Int2x4/UInt2x4 = 4 + constexpr int elements_per_byte = + boost::mp11::mp_contains, T>::value ? 2 : boost::mp11::mp_contains, T>::value ? 4 + : 0; concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); if (to == ONNX_NAMESPACE::TensorProto::FLOAT) { const float* scale = x_scale.Data(); float* output = y.MutableData(); if (block_size_) { - DequantizeLinearApply().op(static_cast(process_block_count), - static_cast(broadcast_dim), - static_cast(process_block_size), - static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + input, scale, output, zero_point, thread_pool); } else { - DequantizeLinearApply().op(static_cast(process_block_count), - static_cast(broadcast_dim), - static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::FLOAT16) { const MLFloat16* scale = x_scale.Data(); MLFloat16* output = y.MutableData(); if (block_size_) { - DequantizeLinearApply().op(static_cast(process_block_count), - static_cast(broadcast_dim), - 
static_cast(process_block_size), - static_cast(block_size_), - input, scale, output, zero_point, thread_pool); + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + static_cast(block_size_), + input, scale, output, zero_point, thread_pool); } else { - DequantizeLinearApply().op(static_cast(process_block_count), - static_cast(broadcast_dim), - static_cast(process_block_size), - input, scale, output, zero_point, thread_pool); + DequantizeLinearApply().op(static_cast(process_block_count), + static_cast(broadcast_dim), + static_cast(process_block_size), + input, scale, output, zero_point, thread_pool); } } else if (to == ONNX_NAMESPACE::TensorProto::BFLOAT16) { ORT_THROW("DequantizeLinear into BFLOAT16 is not implemented yet."); @@ -611,7 +621,9 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { KernelDefBuilder() \ .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), \ QuantizeLinear); #define REGISTER_QUANTIZELINEAR_VERSIONED(T, start_version, end_version) \ @@ -623,7 +635,21 @@ Status DequantizeLinear::Compute(OpKernelContext* ctx) const { KernelDefBuilder() \ .TypeConstraint("T1", {DataTypeImpl::GetTensorType(), \ DataTypeImpl::GetTensorType()}) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + .TypeConstraint("T2", {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), \ + QuantizeLinear); + +#define REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(T, start_version, end_version) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + QuantizeLinear, \ + start_version, \ + end_version, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T1", 
{DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ QuantizeLinear); #define REGISTER_QUANTIZELINEAR_VERSIONED_PRE_19(T) \ @@ -654,6 +680,8 @@ REGISTER_QUANTIZELINEAR(int16_t) REGISTER_QUANTIZELINEAR(uint16_t) REGISTER_QUANTIZELINEAR(Int4x2) REGISTER_QUANTIZELINEAR(UInt4x2) +REGISTER_QUANTIZELINEAR(Int2x4) +REGISTER_QUANTIZELINEAR(UInt2x4) #if !defined(DISABLE_FLOAT8_TYPES) REGISTER_QUANTIZELINEAR(Float8E4M3FN) REGISTER_QUANTIZELINEAR(Float8E4M3FNUZ) @@ -690,28 +718,27 @@ REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 23, 23) #endif // Opset 21 added 16-bit and 4-bit int support to Q ops. -// TODO(adrianlizarraga): Support int4 and block quantization. -REGISTER_QUANTIZELINEAR_VERSIONED(int8_t, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(int16_t, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(uint16_t, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(Int4x2, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(UInt4x2, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(int8_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(uint8_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(int16_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(uint16_t, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Int4x2, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(UInt4x2, 21, 22) #if !defined(DISABLE_FLOAT8_TYPES) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FN, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2, 21, 22) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E4M3FN, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E4M3FNUZ, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E5M2, 21, 22) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E5M2FNUZ, 21, 22) #endif // Opset 19 added 8-bit floats to Q ops. 
-REGISTER_QUANTIZELINEAR_VERSIONED(int8_t, 19, 20) -REGISTER_QUANTIZELINEAR_VERSIONED(uint8_t, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(int8_t, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(uint8_t, 19, 20) #if !defined(DISABLE_FLOAT8_TYPES) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FN, 19, 20) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E4M3FNUZ, 19, 20) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2, 19, 20) -REGISTER_QUANTIZELINEAR_VERSIONED(Float8E5M2FNUZ, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E4M3FN, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E4M3FNUZ, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E5M2, 19, 20) +REGISTER_QUANTIZELINEAR_VERSIONED_PRE_23(Float8E5M2FNUZ, 19, 20) #endif // Before opset 19, Q only supported int8 and uint8. @@ -819,71 +846,85 @@ void ComputeLoop(OpKernelContext* ctx, const InT* input, const InT* scale, const } } -// Quantizes float32 to INT4 (in-place) using MLAS kernel. -#define DEFINE_COMPUTE_LOOP_FP32_TO_INT4(INT4_TYPE, QUANT_FUNC) \ - template <> \ - void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const INT4_TYPE* zero_point, \ - INT4_TYPE* output, int64_t M, int64_t K, int64_t N, bool saturate) { \ - ORT_UNUSED_PARAMETER(saturate); \ - size_t output_index = 0; \ - for (size_t m = 0; m < static_cast(M); m++) { \ - for (size_t bd = 0; bd < static_cast(K); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ - QUANT_FUNC(input, output, output_index, output_index + static_cast(N), \ - scale[bd], INT4_TYPE(zp, 0), ctx->GetOperatorThreadPool()); \ - input += N; \ - output_index += static_cast(N); \ - } \ - } \ - assert(output_index == static_cast(M * K * N)); \ +// Helper macros to create zero point with correct number of constructor arguments +#define CREATE_SUB_BYTE_ZP_2(TYPE, zp) TYPE(zp, 0) +#define CREATE_SUB_BYTE_ZP_4(TYPE, zp) TYPE(zp, 0, 0, 0) +#define CREATE_SUB_BYTE_ZP(TYPE, zp, ELEMENTS_PER_BYTE) CREATE_SUB_BYTE_ZP_##ELEMENTS_PER_BYTE(TYPE, zp) + +// Quantizes float32 to sub-byte types using MLAS kernel (4-bit) or generic quantization (2-bit). +#define DEFINE_COMPUTE_LOOP_FP32_TO_SUB_BYTE(SUB_BYTE_TYPE, QUANT_FUNC, ELEMENTS_PER_BYTE) \ + template <> \ + void ComputeLoop(OpKernelContext* ctx, const float* input, const float* scale, const SUB_BYTE_TYPE* zero_point, \ + SUB_BYTE_TYPE* output, int64_t M, int64_t K, int64_t N, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + size_t output_index = 0; \ + constexpr size_t shift_bits = (ELEMENTS_PER_BYTE == 2) ? 1 : 2; /* log2(ELEMENTS_PER_BYTE) */ \ + constexpr size_t mask = ELEMENTS_PER_BYTE - 1; /* For modulo operation */ \ + for (size_t m = 0; m < static_cast(M); m++) { \ + for (size_t bd = 0; bd < static_cast(K); bd++) { \ + size_t bd_i = bd >> shift_bits; /* bd / ELEMENTS_PER_BYTE */ \ + size_t bd_j = bd & mask; /* bd % ELEMENTS_PER_BYTE */ \ + SUB_BYTE_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ + QUANT_FUNC(input, output, output_index, output_index + static_cast(N), \ + scale[bd], CREATE_SUB_BYTE_ZP(SUB_BYTE_TYPE, zp, ELEMENTS_PER_BYTE), \ + ctx->GetOperatorThreadPool()); \ + input += N; \ + output_index += static_cast(N); \ + } \ + } \ + assert(output_index == static_cast(M * K * N)); \ } -DEFINE_COMPUTE_LOOP_FP32_TO_INT4(Int4x2, ParQuantizeLinearStdS4) -DEFINE_COMPUTE_LOOP_FP32_TO_INT4(UInt4x2, ParQuantizeLinearStdU4) - -// Defines functions to quantize MLFloat16 to INT4. -// This is not an efficient implementation: we allocate a buffer, quantize to INT8, and then copy/clamp/pack -// into output INT4 buffer. -#define DEFINE_COMPUTE_LOOP_FP16_TO_INT4(INT4_TYPE) \ - template <> \ - void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ - const INT4_TYPE* zero_point, INT4_TYPE* output, int64_t M, \ - int64_t K, int64_t N, bool saturate) { \ - ORT_UNUSED_PARAMETER(saturate); \ - \ - size_t total_size = static_cast(M * K * N); \ - auto tmp_buf = std::make_unique(total_size); \ - size_t tmp_buf_index = 0; \ - \ - for (size_t m = 0; m < static_cast(M); m++) { \ - for (size_t bd = 0; bd < static_cast(K); bd++) { \ - size_t bd_i = bd >> 1; /*bd / 2*/ \ - size_t bd_j = bd & 0x1; /*bd % 2*/ \ - INT4_TYPE::UnpackedType zp = zero_point ? 
zero_point[bd_i].GetElem(bd_j) : 0; \ - ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ - static_cast(N), scale[bd], \ - zp, ctx->GetOperatorThreadPool()); \ - input += N; \ - tmp_buf_index += static_cast(N); \ - } \ - } \ - \ - for (size_t i = 0; i < total_size; i++) { \ - tmp_buf[i] = std::min(INT4_TYPE::max_val, \ - std::max(INT4_TYPE::min_val, \ - tmp_buf[i])); \ - } \ - \ - size_t num_int4_pairs = (total_size + 1) / 2; \ - auto dst = gsl::make_span(output, num_int4_pairs); \ - auto src = gsl::make_span(tmp_buf.get(), total_size); \ - INT4_TYPE::Pack(dst, src); \ +DEFINE_COMPUTE_LOOP_FP32_TO_SUB_BYTE(Int4x2, ParQuantizeLinearStdS4, 2) +DEFINE_COMPUTE_LOOP_FP32_TO_SUB_BYTE(UInt4x2, ParQuantizeLinearStdU4, 2) +DEFINE_COMPUTE_LOOP_FP32_TO_SUB_BYTE(Int2x4, ParQuantizeLinearStdS2, 4) +DEFINE_COMPUTE_LOOP_FP32_TO_SUB_BYTE(UInt2x4, ParQuantizeLinearStdU2, 4) + +// Defines functions to quantize MLFloat16 to sub-byte types. +// This is not an efficient implementation: we allocate a buffer, quantize to the unpacked type, and then clamp/pack +// into output sub-byte buffer. +#define DEFINE_COMPUTE_LOOP_FP16_TO_SUB_BYTE(SUB_BYTE_TYPE, ELEMENTS_PER_BYTE) \ + template <> \ + void ComputeLoop(OpKernelContext * ctx, const MLFloat16* input, const MLFloat16* scale, \ + const SUB_BYTE_TYPE* zero_point, SUB_BYTE_TYPE* output, int64_t M, \ + int64_t K, int64_t N, bool saturate) { \ + ORT_UNUSED_PARAMETER(saturate); \ + \ + size_t total_size = static_cast(M * K * N); \ + auto tmp_buf = std::make_unique(total_size); \ + size_t tmp_buf_index = 0; \ + constexpr size_t shift_bits = (ELEMENTS_PER_BYTE == 2) ? 
1 : 2; /* log2(ELEMENTS_PER_BYTE) */ \ + constexpr size_t mask = ELEMENTS_PER_BYTE - 1; /* For modulo operation */ \ + \ + for (size_t m = 0; m < static_cast(M); m++) { \ + for (size_t bd = 0; bd < static_cast(K); bd++) { \ + size_t bd_i = bd >> shift_bits; /* bd / ELEMENTS_PER_BYTE */ \ + size_t bd_j = bd & mask; /* bd % ELEMENTS_PER_BYTE */ \ + SUB_BYTE_TYPE::UnpackedType zp = zero_point ? zero_point[bd_i].GetElem(bd_j) : 0; \ + ParQuantizeLinearStd(input, tmp_buf.get() + tmp_buf_index, \ + static_cast(N), scale[bd], \ + zp, ctx->GetOperatorThreadPool()); \ + input += N; \ + tmp_buf_index += static_cast(N); \ + } \ + } \ + \ + for (size_t i = 0; i < total_size; i++) { \ + tmp_buf[i] = std::min(SUB_BYTE_TYPE::max_val, \ + std::max(SUB_BYTE_TYPE::min_val, \ + tmp_buf[i])); \ + } \ + \ + size_t num_packed = (total_size + ELEMENTS_PER_BYTE - 1) / ELEMENTS_PER_BYTE; \ + auto dst = gsl::make_span(output, num_packed); \ + auto src = gsl::make_span(tmp_buf.get(), total_size); \ + SUB_BYTE_TYPE::Pack(dst, src); \ } -DEFINE_COMPUTE_LOOP_FP16_TO_INT4(Int4x2) -DEFINE_COMPUTE_LOOP_FP16_TO_INT4(UInt4x2) +DEFINE_COMPUTE_LOOP_FP16_TO_SUB_BYTE(Int4x2, 2) +DEFINE_COMPUTE_LOOP_FP16_TO_SUB_BYTE(UInt4x2, 2) +DEFINE_COMPUTE_LOOP_FP16_TO_SUB_BYTE(Int2x4, 4) +DEFINE_COMPUTE_LOOP_FP16_TO_SUB_BYTE(UInt2x4, 4) // formula is Y = X / Scale + ZeroPoint template @@ -904,7 +945,8 @@ Status QuantizeLinear::Compute(OpKernelContext* ctx) const { T* output = y.MutableData(); constexpr int output_type_group_ = - boost::mp11::mp_contains, T>::value ? 2 + boost::mp11::mp_contains, T>::value ? 2 + : boost::mp11::mp_contains, T>::value ? 3 #if !defined(DISABLE_FLOAT8_TYPES) : boost::mp11::mp_contains::value ? 
1 #endif diff --git a/onnxruntime/core/providers/cpu/tensor/cast_op.cc b/onnxruntime/core/providers/cpu/tensor/cast_op.cc index 767137a185cc2..156d581ff030b 100644 --- a/onnxruntime/core/providers/cpu/tensor/cast_op.cc +++ b/onnxruntime/core/providers/cpu/tensor/cast_op.cc @@ -14,6 +14,7 @@ #include "core/framework/data_types_internal.h" #include "core/framework/data_types.h" #include "core/framework/element_type_lists.h" +#include "core/framework/int2.h" #include "core/framework/op_kernel.h" #include "core/providers/cpu/tensor/utils.h" #include "core/providers/op_kernel_type_control.h" @@ -27,11 +28,22 @@ namespace onnxruntime { +namespace { +// Define a type list that extends AllIRv10 with INT2 types, but without Float4 +// Float4E2M1x2 doesn't support all the casting operations that other types do, +// so we don't include it here for the Cast operator. +using AllIRv10WithInt2 = + boost::mp11::mp_push_back< + element_type_lists::AllIRv10, + UInt2x4, + Int2x4>; +} // namespace + namespace op_kernel_type_control { // we're using one set of types for all opsets of Cast ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, - element_type_lists::AllIRv10); + AllIRv10WithInt2); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Input, 0, @@ -39,7 +51,7 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, - element_type_lists::AllIRv10); + AllIRv10WithInt2); ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, Cast, Output, 0, @@ -95,6 +107,26 @@ struct IsOrtInt4ConversionType { static constexpr bool value = IsOrtInt4NumericConversionType::value || std::is_same_v; }; +// INT2 type support helpers +template +using IsOrtInt2Type = boost::mp11::mp_contains, T>; + +// Types that Int2x4 and UInt2x4 convert to and from, apart from 
string. +template +struct IsOrtInt2NumericConversionType { + static constexpr bool value = + std::is_same_v || + IsStandardIntegerType::value || + std::is_floating_point_v || + IsOrtFloat16Type::value || + IsOrtFloat8Type::value; +}; + +template +struct IsOrtInt2ConversionType { + static constexpr bool value = IsOrtInt2NumericConversionType::value || std::is_same_v; +}; + // string cast helpers // Note: when C++17 is available, use functions @@ -333,6 +365,121 @@ struct ToInt4Converter::value && IsOrtInt2ConversionType::value>> +struct FromInt2Converter { + // The input 'val' can be either an int8_t value coming from Int2x4.GetElem(pos), + // or an uint8_t value coming from UInt2x4.GetElem(pos), where pos can be 0, 1, 2, or 3. + static DstType Convert(typename SrcType::UnpackedType val) { + if constexpr (IsOrtFloat16Type::value) { + return DstType(static_cast(val)); + } else if constexpr (IsOrtFloat8Type::value) { + return DstType(static_cast(val), true); + } else if constexpr (std::is_same_v) { + return val != 0; + } else if constexpr (std::is_same_v) { + // val has type (u)int8_t, so static_cast is required in order for std::to_string + // to interpret val as a number (1 -> "1"), instead of a char. + return std::to_string(static_cast(val)); + } else { + return static_cast(val); + } + } +}; + +// Helper for converting any source type to (U)Int2x4::UnpackedType values (int8_t and uint8_t). +template ::value && IsOrtInt2Type::value>> +struct ToInt2Converter { + static typename DstType::UnpackedType Convert(const SrcType& val); +}; + +// Integer types -> Int2x4 +// INT2 values range from -2 to 1 (2-bit signed two's complement) +template +struct ToInt2Converter::value>> { + static int8_t Convert(const SrcType& val) { + // Truncate to 2 least significant bits + uint8_t truncated = static_cast(val) & 0x03; + + // Sign-extend: if bit 1 is set, it's negative in 2-bit two's complement, + // so set the 6 most significant bits to 1. + return static_cast((truncated & 0x2) ? 
(truncated | 0xFC) : truncated); + } +}; + +// Integer types -> UInt2x4 +// UINT2 values range from 0 to 3 (2-bit unsigned) +template +struct ToInt2Converter::value>> { + static uint8_t Convert(const SrcType& val) { + // Truncate to 2 least significant bits + return static_cast(val) & 0x03; + } +}; + +// bool -> (U)Int2x4 +template +struct ToInt2Converter::value>> { + static typename DstType::UnpackedType Convert(bool val) { + return static_cast(val ? 1 : 0); + } +}; + +// float -> (U)Int2x4 +template +struct ToInt2Converter::value>> { + static typename DstType::UnpackedType Convert(const float& val) { + int result = static_cast(std::roundf(val)); + return ToInt2Converter::Convert(result); + } +}; + +// double -> (U)Int2x4 +template +struct ToInt2Converter::value>> { + static typename DstType::UnpackedType Convert(const double& val) { + int result = static_cast(std::round(val)); + return ToInt2Converter::Convert(result); + } +}; + +// float 8 -> (U)Int2x4 +template +struct ToInt2Converter::value && IsOrtInt2Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { + float result = val.ToFloat(); + return ToInt2Converter::Convert(result); + } +}; + +// float 16 -> (U)Int2x4 +template +struct ToInt2Converter::value && IsOrtInt2Type::value>> { + static typename DstType::UnpackedType Convert(const SrcType& val) { + float f_val = static_cast(val); + return ToInt2Converter::Convert(f_val); + } +}; + +// string -> (U)Int2x4 +template +struct ToInt2Converter::value>> { + static typename DstType::UnpackedType Convert(const std::string& val) { + double result = std::stod(val); + return ToInt2Converter::Convert(result); + } +}; + // generic tensor X -> Y template struct TensorCaster { @@ -349,10 +496,10 @@ struct TensorCaster { } }; -// tensor X -> string, if X != (U)Int4x2 +// tensor X -> string, if X != (U)Int4x2 and X != (U)Int2x4 template struct TensorCaster::value>> { + std::enable_if_t::value && !IsOrtInt2Type::value>> { void Cast(const 
OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -363,10 +510,10 @@ struct TensorCaster X, if X != (U)Int4x2 +// tensor string -> X, if X != (U)Int4x2 and X != (U)Int2x4 template struct TensorCaster::value>> { + std::enable_if_t::value && !IsOrtInt2Type::value>> { void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { const std::ptrdiff_t shape_size = narrow(shape.Size()); const auto* in_data = in.Data(); @@ -479,6 +626,228 @@ struct TensorCaster { } }; +// (U)Int2x4 -> string or numeric types +template +struct TensorCaster::value && IsOrtInt2ConversionType::value>> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + // 4 elements per byte, accessed by position 0-3 + auto val = in_data[i >> 2].GetElem(i & 0x3); + out_data[i] = FromInt2Converter::Convert(val); + } + } +}; + +// string or numeric types -> (U)Int2x4 +template +struct TensorCaster::value && IsOrtInt2Type::value>> { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + ptrdiff_t i = 0; + // Process 4 elements at a time + for (; i + 3 < shape_size; i += 4) { + auto val0 = ToInt2Converter::Convert(in_data[i]); + auto val1 = ToInt2Converter::Convert(in_data[i + 1]); + auto val2 = ToInt2Converter::Convert(in_data[i + 2]); + auto val3 = ToInt2Converter::Convert(in_data[i + 3]); + out_data[i >> 2] = DstType(val0, val1, val2, val3); + } + + // Handle remaining elements + if (i < shape_size) { + auto val0 = ToInt2Converter::Convert(in_data[i]); + 
auto val1 = (i + 1 < shape_size) ? ToInt2Converter::Convert(in_data[i + 1]) : 0; + auto val2 = (i + 2 < shape_size) ? ToInt2Converter::Convert(in_data[i + 2]) : 0; + auto val3 = (i + 3 < shape_size) ? ToInt2Converter::Convert(in_data[i + 3]) : 0; + out_data[i >> 2] = DstType(val0, val1, val2, val3); + } + } +}; + +// Int2x4 -> UInt2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t num_quads = narrow((shape.Size() + 3) >> 2); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < num_quads; ++i) { + auto e0 = in_data[i].GetElem(0); + auto e1 = in_data[i].GetElem(1); + auto e2 = in_data[i].GetElem(2); + auto e3 = in_data[i].GetElem(3); + + // Reinterpret: just mask to 2 bits + uint8_t u0 = static_cast(e0) & 0x03; + uint8_t u1 = static_cast(e1) & 0x03; + uint8_t u2 = static_cast(e2) & 0x03; + uint8_t u3 = static_cast(e3) & 0x03; + + out_data[i] = UInt2x4(u0, u1, u2, u3); + } + } +}; + +// UInt2x4 -> Int2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t num_quads = narrow((shape.Size() + 3) >> 2); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < num_quads; ++i) { + auto e0 = in_data[i].GetElem(0); + auto e1 = in_data[i].GetElem(1); + auto e2 = in_data[i].GetElem(2); + auto e3 = in_data[i].GetElem(3); + + // Sign-extend: if bit 1 is set, the value is negative in 2-bit two's complement + int8_t s0 = static_cast((e0 & 0x2) ? (e0 | 0xFC) : e0); + int8_t s1 = static_cast((e1 & 0x2) ? (e1 | 0xFC) : e1); + int8_t s2 = static_cast((e2 & 0x2) ? (e2 | 0xFC) : e2); + int8_t s3 = static_cast((e3 & 0x2) ? 
(e3 | 0xFC) : e3); + + out_data[i] = Int2x4(s0, s1, s2, s3); + } + } +}; + +// Int4x2 -> Int2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 1].GetElem(i & 0x1); + int8_t truncated = static_cast((val & 0x03) << 6) >> 6; + out_data[i >> 2].SetElem(i & 0x3, truncated); + } + } +}; + +// Int4x2 -> UInt2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 1].GetElem(i & 0x1); + uint8_t truncated = static_cast(val) & 0x03; + out_data[i >> 2].SetElem(i & 0x3, truncated); + } + } +}; + +// UInt4x2 -> Int2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 1].GetElem(i & 0x1); + int8_t truncated = static_cast((val & 0x03) << 6) >> 6; + out_data[i >> 2].SetElem(i & 0x3, truncated); + } + } +}; + +// UInt4x2 -> UInt2x4 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 1].GetElem(i & 0x1); + uint8_t truncated = val & 0x03; + out_data[i >> 
2].SetElem(i & 0x3, truncated); + } + } +}; + +// Int2x4 -> Int4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 2].GetElem(i & 0x3); + out_data[i >> 1].SetElem(i & 0x1, val); + } + } +}; + +// Int2x4 -> UInt4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 2].GetElem(i & 0x3); + uint8_t masked = static_cast(val) & 0x0F; + out_data[i >> 1].SetElem(i & 0x1, masked); + } + } +}; + +// UInt2x4 -> Int4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 2].GetElem(i & 0x3); + out_data[i >> 1].SetElem(i & 0x1, static_cast(val)); + } + } +}; + +// UInt2x4 -> UInt4x2 +template <> +struct TensorCaster { + void Cast(const OpKernelContext&, const TensorShape& shape, const Tensor& in, Tensor& out) const { + const ptrdiff_t shape_size = narrow(shape.Size()); + const auto* in_data = in.Data(); + auto* out_data = out.MutableData(); + + for (ptrdiff_t i = 0; i < shape_size; ++i) { + auto val = in_data[i >> 2].GetElem(i & 0x3); + out_data[i >> 1].SetElem(i & 0x1, val); + } + } +}; + #if defined(_M_AMD64) && !defined(_M_ARM64EC) // specializations to use optimized and Windows x64-specific @@ -502,7 +871,7 @@ void 
CastMLFloat16ThroughFloatTensor( // tensor MLFloat16 -> X template struct TensorCaster::value>> { + std::enable_if_t::value && !IsOrtInt2Type::value>> { void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& in, Tensor& out) const { CastMLFloat16ThroughFloatTensor(context, shape, in, out); } @@ -547,6 +916,16 @@ struct TensorCasterNoSat float 8 +template +struct TensorCasterNoSat::value && IsOrtFloat8Type::value>> { + void Cast(const OpKernelContext& context, const TensorShape& shape, const Tensor& src, Tensor& dst) const { + // Int2x4/UInt2x4 always fit inside any Float8 type, so we can reuse the saturate = true implementation. + TensorCaster{}.Cast(context, shape, src, dst); + } +}; + // tensor string -> float 8 template struct TensorCasterNoSat(index / N); + const int64_t i = static_cast(index % N); const int64_t src_offset_batch = batch * data_batch_bytes; const int64_t dst_offset_batch = batch * gathered_batch_bytes; @@ -120,12 +120,14 @@ Status GatherCopyData(const Tensor* indices_tensor, const uint8_t* src_base, uin memcpy(dst_base + dst_offset, src_base + src_offset, narrow(block_size)); } }; - concurrency::ThreadPool::TryParallelFor(tp, SafeInt(M) * N, static_cast(block_size), - [&lambda](ptrdiff_t first, ptrdiff_t last) { - for (int index = static_cast(first), end = static_cast(last); index < end; ++index) { - lambda(index); - } - }); + + concurrency::ThreadPool::TryParallelFor( + tp, SafeInt(M) * N, static_cast(block_size), + [&lambda](ptrdiff_t first, ptrdiff_t last) { + for (ptrdiff_t index = first; index < last; ++index) { + lambda(index); + } + }); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc index ad3faa70ed6af..a0a848eef0dff 100644 --- a/onnxruntime/core/providers/cpu/tensor/gather_nd.cc +++ b/onnxruntime/core/providers/cpu/tensor/gather_nd.cc @@ -66,6 +66,18 @@ Status GatherNDBase::PrepareForCompute(const TensorShape& 
input_shape, const Ten const auto num_slices = indices_shape.SizeToDimension(indices_shape.NumDimensions() - 1); const auto slice_size = input_shape.SizeFromDimension(SafeInt(batch_dims_) + num_slice_dims); const auto num_batches = input_shape.SizeToDimension(SafeInt(batch_dims_)); + + // Validate batch dimensions to prevent division by zero + if (num_batches == 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherND: input tensor batch dimensions cannot be zero"); + } + if (num_slices % num_batches != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "GatherND: indices batch size (", num_slices, + ") is not divisible by input batch size (", num_batches, ")"); + } + const auto input_batch_stride = input_shape.SizeFromDimension(SafeInt(batch_dims_)); const auto num_slices_per_batch = num_slices / num_batches; std::vector sizes_from_slice_dims(onnxruntime::narrow(num_slice_dims)); diff --git a/onnxruntime/core/providers/cpu/tensor/transpose.cc b/onnxruntime/core/providers/cpu/tensor/transpose.cc index 1f29ccbd4d8f8..3e815807e307d 100644 --- a/onnxruntime/core/providers/cpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/cpu/tensor/transpose.cc @@ -18,6 +18,15 @@ namespace { using DefaultDataTypes = element_type_lists::All; } // namespace +namespace { +// Define a type list that extends AllIRv10 with INT2 types +using AllIRv10WithInt2 = + boost::mp11::mp_push_back< + element_type_lists::AllIRv10, + UInt2x4, + Int2x4>; +} // namespace + namespace op_kernel_type_control { // we're using one set of types for all opsets ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( @@ -39,6 +48,14 @@ ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST( kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0, element_type_lists::AllIRv10); +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, Transpose, 25, Input, 0, + AllIRv10WithInt2); + +ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPE_LIST( + kCpuExecutionProvider, 
kOnnxDomain, Transpose, 25, Input, 0, + AllIRv10WithInt2); + } // namespace op_kernel_type_control namespace { @@ -47,6 +64,8 @@ using EnabledDataTypesAllOpsets = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS Transpose, Input, 0); using EnabledDataTypesOpset21 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, Transpose, 21, Input, 0); +using EnabledDataTypesOpset25 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + Transpose, 25, Input, 0); } // namespace /* A permutation [a,b,c,...] indicates that @@ -371,38 +390,38 @@ static Status TransposeImpl(const gsl::span& permutations, const T return DoUntypedTranspose(permutations, input, output, input_shape_override); } -template -static Status UnpackInt4Tensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_allocator) { - using UnpackedType = typename Int4Type::UnpackedType; +template +static Status UnpackSubByteTensor(const Tensor& src, Tensor& dst, AllocatorPtr cpu_allocator) { + using UnpackedType = typename SubByteType::UnpackedType; MLDataType int8_elem_type = DataTypeImpl::GetType(); const TensorShape& shape = src.Shape(); Tensor int8_tensor(int8_elem_type, shape, cpu_allocator); - ORT_RETURN_IF_NOT(Int4Type::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), - "Failed to unpack Int4x2 Tensor to an int8_t Tensor"); + ORT_RETURN_IF_NOT(SubByteType::Unpack(int8_tensor.MutableDataAsSpan(), src.DataAsSpan()), + "Failed to unpack sub-byte Tensor to an int8_t/uint8_t Tensor"); dst = std::move(int8_tensor); return Status::OK(); } -template -static Status DoTransposeInt4(const gsl::span& permutations, const Tensor& input, Tensor& output, - const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { - using Int8Type = typename Int4Type::UnpackedType; +template +static Status DoTransposeSubByte(const gsl::span& permutations, const Tensor& input, Tensor& output, + const TensorShape* input_shape_override, concurrency::ThreadPool* tp) { + using UnpackedType 
= typename SubByteType::UnpackedType; - ORT_RETURN_IF_NOT(input.IsDataType() && output.IsDataType(), - "Expected to transpose int4 tensor"); + ORT_RETURN_IF_NOT(input.IsDataType() && output.IsDataType(), + "Expected to transpose sub-byte tensor"); - // Convert to Tensor, transpose, and then repack back to Tensor. + // Convert to Tensor, transpose, and then repack back to Tensor. AllocatorPtr cpu_allocator = CPUAllocator::DefaultInstance(); Tensor input_unpacked; - Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); + Tensor output_unpacked(DataTypeImpl::GetType(), output.Shape(), cpu_allocator); - ORT_RETURN_IF_ERROR((UnpackInt4Tensor(input, input_unpacked, cpu_allocator))); + ORT_RETURN_IF_ERROR((UnpackSubByteTensor(input, input_unpacked, cpu_allocator))); ORT_RETURN_IF_ERROR(TransposeImpl(permutations, input_unpacked, output_unpacked, input_shape_override, tp)); - ORT_RETURN_IF_NOT(Int4Type::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), - "Failed to pack 8-bit Tensor into 4-bit Tensor"); + ORT_RETURN_IF_NOT(SubByteType::Pack(output.MutableDataAsSpan(), output_unpacked.DataAsSpan()), + "Failed to pack 8-bit Tensor into sub-byte Tensor"); return Status::OK(); } @@ -417,12 +436,23 @@ Status TransposeBase::DoTranspose(const gsl::span& permutations, c return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Mismatched data types between input and output Tensors. 
", input_type, " != ", output_type); } + + // Handle int4 types if (input.IsDataType()) { - return DoTransposeInt4(permutations, input, output, input_shape_override, tp); + return DoTransposeSubByte(permutations, input, output, input_shape_override, tp); } if (input.IsDataType()) { - return DoTransposeInt4(permutations, input, output, input_shape_override, tp); + return DoTransposeSubByte(permutations, input, output, input_shape_override, tp); + } + + // Handle int2 types + if (input.IsDataType()) { + return DoTransposeSubByte(permutations, input, output, input_shape_override, tp); + } + + if (input.IsDataType()) { + return DoTransposeSubByte(permutations, input, output, input_shape_override, tp); } return TransposeImpl(permutations, input, output, input_shape_override, tp); @@ -496,7 +526,7 @@ ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ONNX_CPU_OPERATOR_KERNEL( Transpose, 25, - KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), + KernelDefBuilder().TypeConstraint("T", BuildKernelDefConstraintsFromTypeList()), Transpose); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index 32f5c98da1585..d50a4deca3298 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -15,12 +15,17 @@ #pragma warning(push) // 'fp4_interpretation' : unreferenced parameter #pragma warning(disable : 4100) +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" #endif #include #if defined(_MSC_VER) #pragma warning(pop) +#elif defined(__GNUC__) +#pragma GCC diagnostic pop #endif #endif diff --git a/onnxruntime/core/providers/cuda/cuda_type_conversion.h b/onnxruntime/core/providers/cuda/cuda_type_conversion.h index 38cdce1380fad..04e47a9930710 100644 --- a/onnxruntime/core/providers/cuda/cuda_type_conversion.h +++ b/onnxruntime/core/providers/cuda/cuda_type_conversion.h @@ -14,12 
+14,17 @@ #pragma warning(push) // 'fp4_interpretation' : unreferenced parameter #pragma warning(disable : 4100) +#elif defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-parameter" #endif #include #if defined(_MSC_VER) #pragma warning(pop) +#elif defined(__GNUC__) +#pragma GCC diagnostic pop #endif #endif diff --git a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu index 51c80d272bb96..62801c8da1e5f 100644 --- a/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu +++ b/onnxruntime/core/providers/cuda/reduction/reduction_functions.cu @@ -209,7 +209,7 @@ __device__ void reduce_all( // the size of shared_memory equals to the number of warps. #pragma unroll for (int stride = MAX_NUM_WARPS_PER_BLOCK / 2; stride > 0; stride /= 2) { - if (tid_in_block + stride < num_warps_in_block) { + if (tid_in_block < stride && tid_in_block + stride < num_warps_in_block) { shared_memory[tid_in_block] += shared_memory[tid_in_block + stride]; } __syncthreads(); diff --git a/onnxruntime/core/providers/cuda/tensor/pad.cc b/onnxruntime/core/providers/cuda/tensor/pad.cc index 656890e796a1c..d75c6e947e09c 100644 --- a/onnxruntime/core/providers/cuda/tensor/pad.cc +++ b/onnxruntime/core/providers/cuda/tensor/pad.cc @@ -259,7 +259,7 @@ Status Pad::ComputeInternal(OpKernelContext* ctx) const { TArray fdm_output_strides(dimension_count); TensorPitches output_strides(output_dims); - for (auto i = 0; i < dimension_count; i++) { + for (size_t i = 0; i < dimension_count; i++) { fdm_output_strides[i] = fast_divmod(static_cast(output_strides[i])); } diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index a96d4c82a7fdc..963fa020d033a 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -585,6 +585,13 @@ size_t 
CalcResizeBufferSize(const onnxruntime::UpsampleMode upsample_mode, static_cast(std::accumulate(output_dims.begin(), output_dims.end(), (int64_t)0)); case UpsampleMode::LINEAR: + // For LINEAR mode: + // - bilinear (2-D/4-D) uses mapping for [H, W] + // - trilinear (3-D/5-D) uses mapping for [D, H, W] + if (output_dims.size() == 3 || output_dims.size() == 5) { + return sizeof(LinearMappingInfo) * + static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 3, (int64_t)0)); + } return sizeof(LinearMappingInfo) * static_cast(std::accumulate(output_dims.rbegin(), output_dims.rbegin() + 2, (int64_t)0)); case UpsampleMode::CUBIC: diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc index def6d7e9ea916..ee4f45f5057e0 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.cc @@ -2578,6 +2578,84 @@ const InlinedVector NvExecutionProvider::GetEpContextNodes() const return ep_context_nodes; } +std::string NvExecutionProvider::GetCompiledModelCompatibilityInfo( + const onnxruntime::GraphViewer& graph_viewer) const { + ORT_UNUSED_PARAMETER(graph_viewer); + + // Protect read access to engine_headers_ for thread safety + auto lock = GetApiLock(); + + // Compatibility info is only supported when there is exactly one engine. + // If multiple EPContext nodes/engines exist, return empty so validation is not applicable. 
+ if (engine_headers_.size() > 1) { + return std::string(); + } + + // If we have stored engine headers, return the first one found + // (typically there's only one per EP context) + if (!engine_headers_.empty()) { + return engine_headers_.begin()->second; + } + + // No headers available - validation not supported for this model + return std::string(); +} + +common::Status NvExecutionProvider::ValidateCompiledModelCompatibilityInfo( + const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const { + // If no compatibility info provided, validation not applicable + if (compatibility_info.empty()) { + model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return Status::OK(); + } + + // Decode hex string to binary + std::vector engine_header; + try { + engine_header = HexStringToBinary(compatibility_info); + } catch (const std::exception& ex) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what(); + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return Status::OK(); + } + + // Use TensorRT RTX's getEngineValidity to check compatibility + uint64_t diagnostics = 0; + nvinfer1::EngineValidity validity = runtime_->getEngineValidity( + engine_header.data(), + engine_header.size(), + &diagnostics); + + // Map TensorRT RTX validity to ORT compatibility status + switch (validity) { + case nvinfer1::EngineValidity::kVALID: + LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] Engine is fully compatible with this system"; + model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + break; + + case nvinfer1::EngineValidity::kSUBOPTIMAL: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is compatible but recompilation recommended " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + break; + + case nvinfer1::EngineValidity::kINVALID: + 
LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine is incompatible with this system " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + + default: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown TensorRT validity status: " + << static_cast(validity); + model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + } + + return Status::OK(); +} + Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer, const Node& fused_node, std::unordered_map& input_map, @@ -2854,6 +2932,18 @@ Status NvExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& gr return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "NvTensorRTRTX EP failed to create engine from network for fused node: " + fused_node.Name()); } + + // Capture engine header (first 64 bytes) for compatibility validation + if (serialized_engine->size() >= kTensorRTEngineHeaderSize) { + std::string engine_header_hex = BinaryToHexString( + serialized_engine->data(), + kTensorRTEngineHeaderSize); + engine_headers_[fused_node.Name()] = engine_header_hex; + } else { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine too small to capture header for validation: " + << serialized_engine->size() << " bytes"; + } + trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size())); if (trt_engine == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h index 5c6ca20d75ec6..e415143a6ddd1 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider.h @@ -345,6 +345,13 @@ class NvExecutionProvider : public IExecutionProvider { const InlinedVector GetEpContextNodes() const 
override; + // Engine compatibility validation methods + std::string GetCompiledModelCompatibilityInfo(const onnxruntime::GraphViewer& graph_viewer) const override; + + common::Status ValidateCompiledModelCompatibilityInfo( + const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override; + private: mutable NvExecutionProviderInfo info_; bool external_stream_ = false; @@ -424,6 +431,10 @@ class NvExecutionProvider : public IExecutionProvider { std::unordered_map> profiles_; std::unordered_map dds_output_allocator_maps_; + // Storage for engine headers (64 bytes) for compatibility validation + // Maps fused_node_name -> hex-encoded engine header + mutable std::unordered_map engine_headers_; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc index 5fe37a6c30e33..90e488a1eda18 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_execution_provider_custom_ops.cc @@ -7,18 +7,7 @@ #include "core/framework/provider_options.h" #include "nv_execution_provider_custom_ops.h" #include "nv_execution_provider.h" - -// The filename extension for a shared library is different per platform -#ifdef _WIN32 -#define LIBRARY_PREFIX -#define LIBRARY_EXTENSION ORT_TSTR(".dll") -#elif defined(__APPLE__) -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".dylib" -#else -#define LIBRARY_PREFIX "lib" -#define LIBRARY_EXTENSION ".so" -#endif +#include "nv_platform_utils.h" namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose); @@ -76,14 +65,14 @@ common::Status 
CreateTensorRTCustomOpDomainList(std::vector& // This library contains GroupQueryAttention and RotaryEmbedding plugins for transformer models try { const auto& env = onnxruntime::GetDefaultEnv(); - auto external_plugin_path = env.GetRuntimePath() + + auto external_plugin_path = GetEPLibraryDirectory() + PathString(LIBRARY_PREFIX ORT_TSTR("tensorrt_plugins") LIBRARY_EXTENSION); void* external_plugin_handle = nullptr; auto status = env.LoadDynamicLibrary(external_plugin_path, false, &external_plugin_handle); if (status.IsOK()) { LOGS_DEFAULT(INFO) << "[NvTensorRTRTX EP] External plugins loaded: tensorrt_plugins (GQA + RotaryEmbedding)"; } else { - LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not found in runtime path (optional)"; + LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not found in EP library path (optional)"; } } catch (const std::exception& e) { LOGS_DEFAULT(VERBOSE) << "[NvTensorRTRTX EP] tensorrt_plugins library not available: " << e.what(); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h new file mode 100644 index 0000000000000..f3298a8449157 --- /dev/null +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_platform_utils.h @@ -0,0 +1,67 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Licensed under the MIT License. 
+#pragma once + +#include +#include "core/common/path_string.h" + +#ifdef _WIN32 +#include +#else +#include +#endif + +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + +namespace onnxruntime { +inline PathString GetEPLibraryDirectory() { +#ifdef _WIN32 + HMODULE hModule = NULL; + // Get handle to the DLL executing this code + if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(GetEPLibraryDirectory), + &hModule)) { + return PathString(); + } + + wchar_t buffer[MAX_PATH]; + DWORD len = GetModuleFileNameW(hModule, buffer, MAX_PATH); + if (len == 0 || len >= MAX_PATH) { + return PathString(); + } + + std::wstring path(buffer); + size_t lastSlash = path.find_last_of(L"\\/"); + if (lastSlash != std::wstring::npos) { + return PathString(path.substr(0, lastSlash + 1)); + } + return PathString(); +#else + // Linux and other Unix-like platforms + Dl_info dl_info; + + if (dladdr((void*)&GetEPLibraryDirectory, &dl_info) == 0 || dl_info.dli_fname == nullptr) { + return PathString(); + } + + std::string so_path(dl_info.dli_fname); + size_t last_slash = so_path.find_last_of('/'); + if (last_slash != std::string::npos) { + return PathString(so_path.substr(0, last_slash + 1)); + } + return PathString(); +#endif +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc index e5015e705958d..d1e449eb58870 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/nv_provider_factory.cc @@ -13,6 +13,7 @@ #include 
"core/providers/cuda/shared_inc/cuda_call.h" #include "core/providers/cuda/cuda_stream_handle.h" +#include "onnx_ctx_model_helper.h" #include "nv_provider_factory.h" #include "nv_execution_provider.h" #include "nv_provider_factory_creator.h" @@ -21,6 +22,11 @@ using namespace onnxruntime; +// External declarations +namespace onnxruntime { +extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); +} + namespace onnxruntime { void InitializeRegistry(); @@ -541,7 +547,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { IsStreamAware = IsStreamAwareImpl; CreateSyncStreamForDevice = CreateSyncStreamForDeviceImpl; - + ValidateCompiledModelCompatibilityInfo = ValidateCompiledModelCompatibilityInfoImpl; ort_version_supported = ORT_API_VERSION; // Set to the ORT version we were compiled with. } @@ -584,6 +590,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { * @return True if the device is a supported NVIDIA GPU, false otherwise. */ bool IsOrtHardwareDeviceSupported(const OrtHardwareDevice& device) { +#if _WIN32 const auto& metadata_entries = device.metadata.Entries(); const auto it = metadata_entries.find("LUID"); if (it == metadata_entries.end()) { @@ -625,6 +632,25 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { } return false; +#else + const auto& metadata_entries = device.metadata.Entries(); + const auto it = metadata_entries.find("pci_bus_id"); + if (it == metadata_entries.end()) { + return false; + } + auto& target_id = it->second; + int cuda_device_idx = 0; + if (cudaDeviceGetByPCIBusId(&cuda_device_idx, target_id.c_str()) != cudaSuccess) { + return false; + } + + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, cuda_device_idx) != cudaSuccess) { + return false; + } + // Ampere architecture or newer is required. + return prop.major >= 8; +#endif } // Creates and returns OrtEpDevice instances for all OrtHardwareDevices that this factory supports. 
@@ -661,6 +687,7 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { RETURN_IF_ERROR(factory->ort_api.GetEpApi()->CreateEpDevice(factory, &device, ep_metadata, ep_options, &ep_devices[num_ep_devices])); + factory->ort_api.ReleaseKeyValuePairs(ep_options); factory->ort_api.ReleaseKeyValuePairs(ep_metadata); @@ -735,6 +762,120 @@ struct NvTensorRtRtxEpFactory : OrtEpFactory { return nullptr; } + /** + * This function is called by the public C API GetModelCompatibilityForEpDevices. + * It uses TensorRT RTX runtime directly to call runtime->getEngineValidity() to check the 64-byte engine header. + * + * @param this_ptr Factory instance pointer + * @param devices Hardware devices (not used, validation is done against current system) + * @param num_devices Number of devices + * @param compatibility_info Hex-encoded 64-byte TensorRT RTX engine header (128 hex characters) + * @param model_compatibility Output parameter for compatibility status + * @return OrtStatus* nullptr on success, error status on failure + */ + static OrtStatus* ORT_API_CALL ValidateCompiledModelCompatibilityInfoImpl( + OrtEpFactory* this_ptr, + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept { + auto& factory = *static_cast(this_ptr); + + // Validate input parameters + if (compatibility_info == nullptr || model_compatibility == nullptr) { + return factory.ort_api.CreateStatus(ORT_INVALID_ARGUMENT, + "[NvTensorRTRTX EP] Invalid arguments: compatibility_info or model_compatibility is null"); + } + + // Device parameters not used for header validation + ORT_UNUSED_PARAMETER(devices); + ORT_UNUSED_PARAMETER(num_devices); + + try { + // If no compatibility info provided, validation not applicable + if (compatibility_info[0] == '\0') { + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + + // Decode hex string to binary + std::vector engine_header; + try 
{ + engine_header = HexStringToBinary(std::string(compatibility_info)); + } catch (const std::exception& ex) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Failed to decode engine header: " << ex.what(); + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Validate header size (keep in sync with TensorRT engine header size) + if (engine_header.size() != kTensorRTEngineHeaderSize) { + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Invalid header size: " << engine_header.size() + << " bytes (expected " << kTensorRTEngineHeaderSize << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + return nullptr; + } + + // Create TensorRT runtime for validation + static std::mutex runtime_creation_mutex; + std::unique_ptr runtime; + { + std::lock_guard lock(runtime_creation_mutex); + TensorrtLogger& trt_logger = GetTensorrtLogger(false); + runtime.reset(nvinfer1::createInferRuntime(trt_logger)); + } + + if (!runtime) { + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Failed to create TensorRT runtime"; + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Failed to create TensorRT runtime"); + } + + // Use TensorRT's getEngineValidity to check compatibility + uint64_t diagnostics = 0; + nvinfer1::EngineValidity validity = runtime->getEngineValidity( + engine_header.data(), + engine_header.size(), + &diagnostics); + + // Map TensorRT validity to ORT compatibility status + switch (validity) { + case nvinfer1::EngineValidity::kVALID: + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + break; + + case nvinfer1::EngineValidity::kSUBOPTIMAL: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Engine compatible but recompilation recommended " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_SUPPORTED_PREFER_RECOMPILATION; + break; + + case nvinfer1::EngineValidity::kINVALID: + LOGS_DEFAULT(WARNING) << 
"[NvTensorRTRTX EP] Engine incompatible with this system " + << "(diagnostics: 0x" << std::hex << diagnostics << std::dec << ")"; + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + + default: + LOGS_DEFAULT(WARNING) << "[NvTensorRTRTX EP] Unknown validity status: " + << static_cast(validity); + *model_compatibility = OrtCompiledModelCompatibility_EP_UNSUPPORTED; + break; + } + + return nullptr; + + } catch (const std::exception& ex) { + std::string error_msg = std::string("[NvTensorRTRTX EP] Exception during validation: ") + ex.what(); + LOGS_DEFAULT(ERROR) << error_msg; + return factory.ort_api.CreateStatus(ORT_FAIL, error_msg.c_str()); + } catch (...) { + LOGS_DEFAULT(ERROR) << "[NvTensorRTRTX EP] Unknown exception during validation"; + return factory.ort_api.CreateStatus(ORT_FAIL, + "[NvTensorRTRTX EP] Unknown exception during validation"); + } + } + OrtStatus* CreateMemoryInfoForDevices(int num_devices) { gpu_memory_infos.reserve(num_devices); host_accessible_memory_infos.reserve(num_devices); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc index c1626fa4f36ad..b6a4069c59700 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.cc @@ -14,6 +14,53 @@ namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(bool verbose_log); +/* + * Convert binary data to hex string + */ +std::string BinaryToHexString(const void* data, size_t size) { + static const char hex_chars[] = "0123456789abcdef"; + const uint8_t* bytes = static_cast(data); + std::string result; + result.reserve(size * 2); + + for (size_t i = 0; i < size; ++i) { + result.push_back(hex_chars[(bytes[i] >> 4) & 0xF]); + result.push_back(hex_chars[bytes[i] & 0xF]); + } + return result; +} + +/* + * Convert hex string back to binary + */ +std::vector HexStringToBinary(const 
std::string& hex) { + if (hex.size() % 2 != 0) { + ORT_THROW("Hex string must have even length"); + } + + std::vector result; + result.reserve(hex.size() / 2); + + for (size_t i = 0; i < hex.size(); i += 2) { + uint8_t byte = 0; + + // High nibble + char c = hex[i]; + byte |= (c >= '0' && c <= '9') ? static_cast((c - '0') << 4) : (c >= 'a' && c <= 'f') ? static_cast((c - 'a' + 10) << 4) + : (c >= 'A' && c <= 'F') ? static_cast((c - 'A' + 10) << 4) + : 0; + + // Low nibble + c = hex[i + 1]; + byte |= (c >= '0' && c <= '9') ? static_cast(c - '0') : (c >= 'a' && c <= 'f') ? static_cast(c - 'a' + 10) + : (c >= 'A' && c <= 'F') ? static_cast(c - 'A' + 10) + : 0; + + result.push_back(byte); + } + return result; +} + /* * Check whether the graph has the EP context contrib op. * The op can contain the precompiled engine info for TRT EP to directly load the engine. diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h index 7c52f26cc9177..80263b1ba80d5 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/onnx_ctx_model_helper.h @@ -24,6 +24,12 @@ static const std::string PARTITION_NAME = "partition_name"; static const std::string SDK_VERSION = "ep_sdk_version"; static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft"; +// TensorRT does not currently expose a header size define; keep in sync with TRT engine serialization header size. 
+constexpr size_t kTensorRTEngineHeaderSize = 64; +// Helper functions for engine header validation +std::string BinaryToHexString(const void* data, size_t size); +std::vector HexStringToBinary(const std::string& hex); + bool GraphHasCtxNode(const GraphViewer& graph_viewer, size_t& node_idx); const std::filesystem::path& GetModelPath(const GraphViewer& graph_viewer); std::filesystem::path GetPathOrParentPathOfCtxModel(const std::string& ep_context_file_path); diff --git a/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds b/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds index 094abb3329781..251e39e089275 100644 --- a/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds +++ b/onnxruntime/core/providers/nv_tensorrt_rtx/version_script.lds @@ -2,6 +2,8 @@ VERS_1.0 { global: GetProvider; + CreateEpFactories; + ReleaseEpFactory; # Hide everything else. local: diff --git a/onnxruntime/core/providers/openvino/backend_manager.cc b/onnxruntime/core/providers/openvino/backend_manager.cc index 3426a2781bbc6..892bdec7abe83 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.cc +++ b/onnxruntime/core/providers/openvino/backend_manager.cc @@ -96,15 +96,15 @@ BackendManager::BackendManager(SessionContext& session_context, ptr_stream_t model_stream; std::unique_ptr model_proto; if (subgraph_context_.is_ep_ctx_graph) { - if (!session_context_.reshape.empty()) { + if (!session_context_.reshape.empty() && !subgraph_context_.is_ep_ctx_ovir_encapsulated) { std::string exception_str = "[OpenVINO-EP] Bounded dynamic model execution using provider option reshape_input is not supported for OVEP EPContext model"; ORT_THROW(exception_str); } if (subgraph_context_.is_ep_ctx_ovir_encapsulated) { - model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.onnx_model_path_name.replace_extension("xml").string(), subgraph); + model_stream = 
ep_ctx_handle_.GetModelBlobStream(session_context_.onnx_model_path_name.replace_extension("xml").string(), subgraph, session_context_.device_type); } else { - model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph); + model_stream = ep_ctx_handle_.GetModelBlobStream(session_context_.so_context_file_path, subgraph, session_context_.device_type); } } else { @@ -231,21 +231,8 @@ bool BackendManager::ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& mod bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& subgraph) const { const auto& graph_inputs = subgraph.GetInputs(); - // First validate shapes if provided by user - bool shapes_valid = true; - if (!session_context_.reshape.empty()) { - try { - ValidateInputShapes(session_context_.reshape, graph_inputs); - } catch (const std::exception& e) { - LOGS_DEFAULT(ERROR) << "[OpenVINO-EP] Shape validation failed: " << e.what(); - session_context_.reshape.clear(); // Clear the shape map as it's invalid - shapes_valid = false; - } - } - // Count dynamic inputs and check if reshape covers all of them size_t dynamic_input_count = 0; - bool all_dynamic_inputs_covered = true; for (const auto* input : graph_inputs) { // Skip dangling inputs (no consumers) @@ -273,14 +260,6 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& s // If dynamic, count it and check if reshape covers it if (has_dynamic_dim) { dynamic_input_count++; - - // Check if this dynamic input is covered by reshape input - if (!session_context_.reshape.empty() && - session_context_.reshape.find(input->Name()) == session_context_.reshape.end()) { - all_dynamic_inputs_covered = false; - LOGS_DEFAULT(WARNING) << "[OpenVINO-EP] reshape_input is provided but doesn't cover dynamic input: " - << input->Name(); - } } } @@ -289,23 +268,8 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::GraphViewer& s // Early return if no reshape input provided if 
(session_context_.reshape.empty()) { return has_symbolic_dims; // Return based on whether model has symbolic dims - } - - // For dynamic models with incomplete reshape coverage, clear shapes - if (has_symbolic_dims && !all_dynamic_inputs_covered) { - session_context_.reshape.clear(); - LOGS_DEFAULT(WARNING) << "reshape_input does not cover all dynamic dimensions, " - << "ignoring all provided shapes"; - return true; // Model is dynamic - } - - // If shapes are valid with complete coverage for dynamic model, treat as concrete - if (has_symbolic_dims && shapes_valid && all_dynamic_inputs_covered) { - LOGS_DEFAULT(INFO) << "All dynamic dimensions successfully covered by reshape_input"; - return false; // Model is now effectively static with concrete shapes - } - - return has_symbolic_dims; // Return dynamic status based on symbolic dimensions + } else + return false; } // Check to see if the graph is QDQ @@ -383,12 +347,12 @@ static void DumpOpenVINOEPModel([[maybe_unused]] const std::filesystem::path& on // this is a helper function to set the data fields, it duplicates ExternalDataInfo::SetExternalLocationToProto // but we cannot use that function as it is not part of public provider api. 
-static void SetExternalDataFields(ONNX_NAMESPACE::TensorProto* proto_init, const void* data_ptr, int64_t data_size) { +static void SetExternalDataFields(ONNX_NAMESPACE::TensorProto& proto_init, const void* data_ptr, int64_t data_size) { static constexpr const char* ORT_INTERNAL_MEM_INITIALIZER = "*/_ORT_MEM_ADDR_/*"; - auto* external_data = proto_init->mutable_external_data(); + auto* external_data = proto_init.mutable_external_data(); bool found_location = false, found_offset = false, found_length = false; const int ext_data_size = external_data->size(); - proto_init->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); + proto_init.set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL); for (int j = 0; j < ext_data_size; ++j) { auto& ext_entry = external_data->at(j); @@ -576,11 +540,15 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, if (it == proto_initializer_map.end()) continue; - auto* proto_init = it->second; + if (!it->second) { + ORT_THROW(name + " proto initializer is null!"); + } + + auto& proto_init = *it->second; // If the proto initializer is missing data, fill it in - if (!proto_init->has_raw_data() && src_init->has_raw_data()) { - *proto_init->mutable_raw_data() = src_init->raw_data(); + if (!proto_init.has_raw_data() && src_init->has_raw_data()) { + *(proto_init.mutable_raw_data()) = src_init->raw_data(); } // Only set in-memory external_data fields if the data is in memory @@ -589,10 +557,11 @@ BackendManager::GetModelProtoFromFusedNode(const onnxruntime::Node& fused_node, << src_init->name() << ", data_type: " << src_init->data_type() << ", raw_data size: " << src_init->raw_data().size(); - if (src_init->raw_data().size() > 0) + if (src_init->raw_data().size() > 0) { SetExternalDataFields(proto_init, src_init->raw_data().data(), src_init->raw_data().size()); - else + } else { LOGS(logger, VERBOSE) << "Initializer has empty 
raw_data: skipping initializer '" << src_init->name() << "'..."; + } } else if (onnxruntime::utils::HasExternalDataInMemory(*src_init)) { auto it_ext = external_initializers_offset_and_length.find(name); if (it_ext == external_initializers_offset_and_length.end()) { @@ -690,40 +659,6 @@ BackendManager::ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_p return model_copy; } -void BackendManager::ValidateInputShapes(const reshape_t& shapes, - const std::vector& graph_inputs) const { - for (const auto& [tensor_name, requested_shape] : shapes) { - // Find matching input in graph - const NodeArg* graph_input = nullptr; - for (const auto* input : graph_inputs) { - if (input->Name() == tensor_name) { - graph_input = input; - break; - } - } - - if (!graph_input) { - ORT_THROW("Input '" + tensor_name + "' specified in reshape_input does not exist in the graph"); - } - - const ONNX_NAMESPACE::TensorShapeProto* graph_shape = graph_input->Shape(); - if (!graph_shape) { - ORT_THROW("Graph input '" + tensor_name + "' has no shape information"); - } - - // Check dimensions count matches - size_t graph_dim_count = graph_shape->dim_size(); - size_t requested_dim_count = requested_shape.get_max_shape().size(); - - if (graph_dim_count != requested_dim_count) { - ORT_THROW("Dimensions mismatch for input '" + tensor_name + - "': graph expects " + std::to_string(graph_dim_count) + - " dimensions but reshape_input specifies " + - std::to_string(requested_dim_count) + " dimensions"); - } - } -} - void BackendManager::Compute(OrtKernelContext* context) { Ort::KernelContext ctx(context); std::chrono::high_resolution_clock::time_point start_compute, end_compute; @@ -844,5 +779,11 @@ void BackendManager::RewindKVCache(size_t index) { } } +void BackendManager::ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) { + if (concrete_backend_) { + concrete_backend_->ReorderKVCache(src_indices, dst_indices); + } +} + } // namespace openvino_ep } // namespace 
onnxruntime diff --git a/onnxruntime/core/providers/openvino/backend_manager.h b/onnxruntime/core/providers/openvino/backend_manager.h index 716fe3ef4cc90..f8a74b9cbcfa4 100644 --- a/onnxruntime/core/providers/openvino/backend_manager.h +++ b/onnxruntime/core/providers/openvino/backend_manager.h @@ -31,6 +31,7 @@ class BackendManager { void TryExportCompiledBlobAsEPCtxNode(const onnxruntime::GraphViewer& subgraph, bool include_embed_data); ov::CompiledModel GetOVCompiledModel(); void RewindKVCache(size_t index); + void ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices); private: std::unique_ptr GetModelProtoFromFusedNode( @@ -42,8 +43,6 @@ class BackendManager { std::unordered_set IdentifyDynamicInputs(const onnxruntime::GraphViewer& subgraph, const std::vector& graph_inputs) const; bool ModelHasBatchedInputs(const ONNX_NAMESPACE::ModelProto& model_proto) const; - void ValidateInputShapes(const reshape_t& shapes, - const std::vector& graph_inputs) const; std::shared_ptr ReWriteBatchDimWithOne(const ONNX_NAMESPACE::ModelProto& model_proto); diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.cc b/onnxruntime/core/providers/openvino/backends/basic_backend.cc index d7fc0553fb1d4..7f4d1f74cfb7b 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.cc +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.cc @@ -63,7 +63,8 @@ BasicBackend::BasicBackend(std::unique_ptr& model_pr hw_target, device_config, enable_causallm, - model_file_path()); + model_file_path(), + session_context_); } else { // If the blob is held in an EPContext node, then skip FE+Compile // and directly move on to creating a backend with the executable blob @@ -308,32 +309,18 @@ void BasicBackend::SetOVDeviceConfiguration(ov::AnyMap& device_config) { } } -void BasicBackend::ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, - const ov::PartialShape& partial_shape) const { - // Check if the number of dimensions 
matches - if (static_cast(ort_dims.size()) != partial_shape.rank().get_length()) { - ORT_THROW("Mismatch in number of dimensions between ORT tensor and OpenVINO PartialShape."); - } - // Validate each dimension - for (size_t i = 0; i < ort_dims.size(); ++i) { - const auto& ov_dim = partial_shape[i]; // OpenVINO dimension at index i - int64_t ort_dim = ort_dims[i]; // ORT dimension at index i - - // Check if the ORT dimension is within the specified range - int64_t min_dim = ov_dim.get_min_length(); - int64_t max_dim = ov_dim.get_max_length(); - if (ort_dim < min_dim || ort_dim > max_dim) { - ORT_THROW(" ORT Dimension is out of range"); - } - } -} - void BasicBackend::RewindKVCache(size_t index) { infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) { infer_request->RewindKVCache(index); }); } +void BasicBackend::ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) { + infer_req_pool_->forEachIdleRequest([&](OVInferRequestPtr& infer_request) { + infer_request->ReorderKVCache(src_indices, dst_indices); + }); +} + void BasicBackend::Infer(OrtKernelContext* ctx) const { Ort::KernelContext context(ctx); @@ -374,9 +361,6 @@ void BasicBackend::Infer(OrtKernelContext* ctx) const { // Set the input shape based on the input tensor from ort auto tensor = context.GetInput(input_info.onnx_index); auto ort_shape = tensor.GetTensorTypeAndShapeInfo().GetShape(); - if (input_info.IsBoundedDynamic()) { - ValidateOrtDimsAgainstPartialShape(ort_shape, input_info.shape); - } auto input_shape = ParameterShape(ort_shape); infer_request->SetTensor(input_info.name, diff --git a/onnxruntime/core/providers/openvino/backends/basic_backend.h b/onnxruntime/core/providers/openvino/backends/basic_backend.h index 2cf3d3faa8b47..d8af2ce7fd595 100644 --- a/onnxruntime/core/providers/openvino/backends/basic_backend.h +++ b/onnxruntime/core/providers/openvino/backends/basic_backend.h @@ -36,7 +36,6 @@ struct ParameterInfo { // Query methods bool 
IsStatic() const { return dynamic_flags == 0; } bool IsFullyDynamic() const { return dynamic_flags & 1; } - bool IsBoundedDynamic() const { return dynamic_flags & 2; } bool IsMixed() const { return (dynamic_flags & 3) == 3; } // Setter methods @@ -61,6 +60,7 @@ struct OnnxToOvNetworkBindings { OnnxToOvNetworkBindings(OVExeNetwork& exec_network, SubGraphContext& subgraph_context, SessionContext& session_context) { auto populate = [&](auto& input_output_map, const SubGraphContext::string_index_map_t& onnx_input_map, const auto& ov_parameters) { + auto input_parameter_aligned = (onnx_input_map.size() == ov_parameters.size()); for (const auto& [onnx_name, onnx_param_index] : onnx_input_map) { auto it = std::find_if(ov_parameters.begin(), ov_parameters.end(), [&onnx_name](const auto& ov_parameter_info) { return ov_parameter_info.get_names().contains(onnx_name); }); @@ -82,6 +82,13 @@ struct OnnxToOvNetworkBindings { } } + if (!input_parameter_aligned && !matched_names) { + LOGS_DEFAULT(WARNING) << log_tag << "The input '" << onnx_name + << "' is not used due to OpenVINO optimization. " + "This may cause issues if the input is required."; + continue; + } + ORT_ENFORCE(matched_names, log_tag, "Input names mismatch between OpenVINO and ONNX. 
", onnx_name, " doesn't exist in the list of OpenVINO input tensor names"); @@ -111,6 +118,12 @@ struct OnnxToOvNetworkBindings { info.SetFullyDynamic(has_fully_dynamic); info.SetBoundedDynamic(has_bounded_dynamic); + } else { + // OV needs allocate the output buffer before inference, but the 0 size output graph doesn't need to do a real inference in ONNX + auto shape_size = ov::shape_size(shape.get_shape()); + if (0 == shape_size) { + has_dynamic_io_ = true; + } } input_output_map.push_back(std::move(info)); @@ -138,6 +151,7 @@ class BasicBackend : public IBackend { return exe_network_.Get(); } void RewindKVCache(size_t index) override; + void ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) override; private: bool ValidateSubgraph(std::map>& const_outputs_map); @@ -147,8 +161,6 @@ class BasicBackend : public IBackend { void EnableStreams(ov::AnyMap& device_config); void SetNumThreads(ov::AnyMap& device_config); void SetOVDeviceConfiguration(ov::AnyMap& device_config); - void ValidateOrtDimsAgainstPartialShape(const std::vector& ort_dims, - const ov::PartialShape& partial_shape) const; SessionContext& session_context_; SubGraphContext subgraph_context_; diff --git a/onnxruntime/core/providers/openvino/ibackend.h b/onnxruntime/core/providers/openvino/ibackend.h index 365a4625815d6..4444f37ac7433 100644 --- a/onnxruntime/core/providers/openvino/ibackend.h +++ b/onnxruntime/core/providers/openvino/ibackend.h @@ -18,6 +18,7 @@ class IBackend { virtual ov::CompiledModel GetOVCompiledModel() = 0; virtual ~IBackend() = default; virtual void RewindKVCache(size_t index) {} + virtual void ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) {} }; using ptr_stream_t = std::unique_ptr; class BackendFactory { diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc index 60a461f7159f3..80ce4de4a6a19 100644 --- 
a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.cc @@ -93,7 +93,7 @@ Status EPCtxHandler::AddOVEPCtxNodeToGraph(const GraphViewer& graph_viewer, return Status::OK(); } -std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer) const { +std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& graph_viewer, const std::string& device_type) const { auto first_index = *graph_viewer.GetNodesInTopologicalOrder().begin(); auto node = graph_viewer.GetNode(first_index); ORT_ENFORCE(node != nullptr); @@ -128,10 +128,12 @@ std::unique_ptr EPCtxHandler::GetModelBlobStream(const std::fi // If the model stream is not an XML (i.e. precompiled blob), the OpenVINO SDK version that it was // exported with must match the version that is currently running. native_blob_path = std::move(blob_filepath); - ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), - "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + - ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); - + // Skip SDK version check for NPU devices as they may use different SDK versions. 
+ if (device_type.find("NPU") == std::string::npos) { + ORT_ENFORCE((attrs.count(EP_SDK_VER) == 1) && (attrs.at(EP_SDK_VER).s() == openvino_sdk_version_), + "EPCtx blob was exported / is compatible with OpenVINO SDK version " + attrs.at(EP_SDK_VER).s() + + ", but OpenVINO SDK version currently in use is " + openvino_sdk_version_); + } result.reset(); // Release the stream as we will get the native blob from SharedContext auto shared_context = shared_context_manager_->GetOrCreateSharedContext(native_blob_path); return std::make_unique(shared_context->GetNativeBlobAsStream(partition_name), shared_context->GetNativeBlob(partition_name)); diff --git a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h index fce88005a0605..97e7369fcb0f5 100644 --- a/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/openvino/onnx_ctx_model_helper.h @@ -41,7 +41,7 @@ class EPCtxHandler { const std::string& graph_name, const bool embed_mode, std::string&& model_blob_str) const; - std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& subgraph_view) const; + std::unique_ptr GetModelBlobStream(const std::filesystem::path& so_context_file_path, const GraphViewer& subgraph_view, const std::string& device_type) const; InlinedVector GetEPCtxNodes() const; bool CheckEPCacheContextAttribute(const GraphViewer& subgraph_view, const std::string& target_attr_extn) const; std::shared_ptr Initialize(const std::vector& fused_nodes, const SessionContext& session_context); diff --git a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc index a099f85b2a4b9..e2d14b9e761b6 100644 --- a/onnxruntime/core/providers/openvino/openvino_execution_provider.cc +++ b/onnxruntime/core/providers/openvino/openvino_execution_provider.cc @@ -5,6 +5,7 @@ #include #include #include 
+#include #include "core/providers/shared_library/provider_api.h" #include "core/providers/openvino/openvino_execution_provider.h" #include "core/providers/openvino/contexts.h" @@ -286,6 +287,64 @@ common::Status OpenVINOExecutionProvider::SetEpDynamicOptions(gsl::span std::variant> { + std::vector indices; + while (!input.empty()) { + const auto delimiter_pos = input.find(','); + const auto part = input.substr(0, delimiter_pos); + errno = 0; + char* parse_end = nullptr; + // strtol/stol already skips whitespaces + const auto index = std::strtol(part.data(), &parse_end, 10); + if (parse_end == part.data()) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "Failed to parse kvcache_reorder " + index_type + ": " + std::string(part)); + } + if (index < 0) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "kvcache_reorder " + index_type + " cannot be negative: " + std::string(part)); + } + if (index > std::numeric_limits::max() || errno == ERANGE) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, + "kvcache_reorder " + index_type + " exceed INT32_MAX: " + std::string(part)); + } + indices.push_back(static_cast(index)); + if (delimiter_pos != std::string_view::npos) { + // ignore any trailing chars after the number, can do further checking if needed + input.remove_prefix(part.size() + 1); + } else { + break; + } + } + return indices; + }; + + const auto src_indices = parse_indices(src_string, "src_index"); + if (std::holds_alternative(src_indices)) { + return std::get(src_indices); + } + + const auto dst_indices = parse_indices(dst_string, "dst_index"); + if (std::holds_alternative(dst_indices)) { + return std::get(dst_indices); + } + + // Trigger KVCache Reorder for target Backend with vector arguments + for (auto& backend : backend_managers_) { + backend.ReorderKVCache(std::get>(src_indices), std::get>(dst_indices)); + } } else { // Handle unknown options LOGS_DEFAULT(WARNING) << "Unknown key/value pair - ignoring " << 
key << "/" << value; diff --git a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc index 7eb5b062fe7c8..3443983843881 100644 --- a/onnxruntime/core/providers/openvino/openvino_provider_factory.cc +++ b/onnxruntime/core/providers/openvino/openvino_provider_factory.cc @@ -568,8 +568,12 @@ struct OpenVINO_Provider : Provider { // Parse provider info with the device type ProviderInfo pi; const auto& config_options = session_options.GetConfigOptions(); - ParseProviderInfo(provider_options, &config_options, pi); - ParseConfigOptions(pi); + try { + ParseProviderInfo(provider_options, &config_options, pi); + ParseConfigOptions(pi); + } catch (std::exception& e) { + return Status(common::ONNXRUNTIME, ORT_INVALID_ARGUMENT, e.what()); + } // Create and return the execution provider auto factory = std::make_unique(pi, OVCore::Get()); diff --git a/onnxruntime/core/providers/openvino/ov_bin_manager.cc b/onnxruntime/core/providers/openvino/ov_bin_manager.cc index 88a50377281bc..65d17374adf44 100644 --- a/onnxruntime/core/providers/openvino/ov_bin_manager.cc +++ b/onnxruntime/core/providers/openvino/ov_bin_manager.cc @@ -18,8 +18,11 @@ static inline uint64_t AlignUp(uint64_t value, uint64_t alignment) { // Only supports input operations. class TensorStreamBuf : public std::streambuf { public: - explicit TensorStreamBuf(ov::Tensor& tensor) { - char* data = const_cast(tensor.data()); + explicit TensorStreamBuf(const ov::Tensor& tensor) { + // Suppress warning for tensor.data() returning const in 2026.0. Should be removable after 2026.0 is min supported ov version. 
+ OPENVINO_SUPPRESS_DEPRECATED_START + char* data = const_cast(tensor.data()); // setg requires non-const char* but we won't modify data + OPENVINO_SUPPRESS_DEPRECATED_END size_t size = tensor.get_byte_size(); setg(data, data, data + size); } @@ -66,8 +69,8 @@ class TensorStream : public std::istream { buf_(tensor_) {} private: - ov::Tensor tensor_; // Keep tensor alive - TensorStreamBuf buf_; // Buffer wrapping tensor data + const ov::Tensor tensor_; // Keep tensor alive + TensorStreamBuf buf_; // Buffer wrapping tensor data }; /* @@ -169,11 +172,12 @@ ov::Tensor BinManager::GetNativeBlob(const std::string& blob_name) { } if (mapped_bin_) { - // Create a tensor from memory-mapped external file - blob_container.tensor = ov::Tensor( - ov::element::u8, - ov::Shape{blob_container.serialized_info.size}, - mapped_bin_.data() + blob_container.serialized_info.file_offset); + // Create tensor view from mapped_bin_ (which holds the underlying buffer) + auto blob_offset = blob_container.serialized_info.file_offset; + auto blob_size = blob_container.serialized_info.size; + ov::Coordinate begin{blob_offset}; + ov::Coordinate end{blob_offset + blob_size}; + blob_container.tensor = ov::Tensor(mapped_bin_, begin, end); } else { // Create a tensor from embedded data vector blob_container.tensor = ov::Tensor( diff --git a/onnxruntime/core/providers/openvino/ov_interface.cc b/onnxruntime/core/providers/openvino/ov_interface.cc index 898288554968e..224685735a134 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.cc +++ b/onnxruntime/core/providers/openvino/ov_interface.cc @@ -109,9 +109,13 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, bool model_status = IsStateful(model); LOGS_DEFAULT(INFO) << log_tag << "Model IsStateful() Status:\t" << (model_status ? 
"True" : "False"); + // Flag to add Gather+ScatterElementsUpdate subgraph to reorder KV cache for LLM speculative decoding + bool should_add_kvcache_reorder = false; if (!model_status) { LOGS_DEFAULT(INFO) << log_tag << "Converting from Stateless OV Model to Stateful OV Model" << std::endl; - PatchStatefulDecoder(model); + // TO-DO: extend to NPU device when OpenVINO NPU has related optimization + should_add_kvcache_reorder = hw_target.find("GPU") != std::string::npos; + PatchStatefulDecoder(model, should_add_kvcache_reorder); } if (onnxruntime::openvino_ep::backend_utils::IsDebugEnabled()) { @@ -152,7 +156,7 @@ OVExeNetwork OVCore::StatefulCompileModel(std::shared_ptr& model, LOGS_DEFAULT(INFO) << log_tag << "Compiling OV Model using Stateful Transformation flow"; compiled_model = OVCore::Get()->core.compile_model(model, hw_target, config); - OVExeNetwork exe(compiled_model, hw_target, true); + OVExeNetwork exe(compiled_model, hw_target, true, should_add_kvcache_reorder); return exe; } @@ -204,10 +208,10 @@ OVExeNetwork OVCore::ImportModel(ModelBlobWrapper& model_blob, return OvExceptionBoundary([&]() { ov::CompiledModel obj; #if (OPENVINO_VERSION_MAJOR > 2025 || (OPENVINO_VERSION_MAJOR == 2025 && OPENVINO_VERSION_MINOR >= 3)) - if (model_blob.stream_) { - obj = core.import_model(*model_blob.stream_, hw_target, device_config); - } else { + if (model_blob.tensor_) { obj = core.import_model(model_blob.tensor_, hw_target, device_config); + } else { + obj = core.import_model(*model_blob.stream_, hw_target, device_config); } #else obj = core.import_model(*model_blob.stream_, hw_target, device_config); @@ -226,7 +230,8 @@ OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, std::string& hw_target, const ov::AnyMap& device_config, bool enable_causallm, - std::filesystem::path model_file_path) { + std::filesystem::path model_file_path, + const SessionContext& session_context) { return OvExceptionBoundary([&]() { OVExeNetwork exe; @@ -259,6 +264,11 
@@ OVExeNetwork OVCore::ImportEPCtxOVIREncapsulation(std::istream& model_stream, // Load the model explicitly with XML contents std::shared_ptr model = core.read_model(xml_file_path.string()); + if (!session_context.reshape.empty()) { + LOGS_DEFAULT(INFO) << log_tag << "Reshaping OV-IR model to specified shape"; + model->reshape(session_context.reshape); + } + if (enable_causallm) { exe = OVCore::Get()->StatefulCompileModel(model, hw_target, device_config); } else { @@ -326,7 +336,7 @@ std::shared_ptr OVExeNetwork::CreateInferRequest() { auto infReq = compiled_model_obj.create_infer_request(); std::shared_ptr ovInfReq; if (is_stateful_causallm) { - ovInfReq = std::make_shared(std::move(infReq), target_device); + ovInfReq = std::make_shared(std::move(infReq), target_device, is_kvcache_reorder_added); } else { ovInfReq = std::make_shared(std::move(infReq)); } @@ -371,8 +381,8 @@ void OVInferRequest::Infer() { "In Error Couldn't start Inference"); } -StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device) - : OVInferRequest(std::move(infer_request)), target_device(device) { +StatefulOVInferRequest::StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool kvcache_reorder_added) + : OVInferRequest(std::move(infer_request)), target_device(device), is_kvcache_reorder_added(kvcache_reorder_added) { bool gpu_or_npu = ((device.find("NPU") != std::string::npos) || (device.find("GPU") != std::string::npos)); _npu_logits_slice_required = IsNPULogitsSliceRequired(); @@ -463,6 +473,32 @@ void StatefulOVInferRequest::PreProcessInferRequest() { // TODO(ankit): Address this issue and implement the fix at the appropriate layer. 
FillTensor("beam_idx", ov::element::i32, {1}, 0); + if (is_kvcache_reorder_added) { + ov::Shape dst_idx_shape = ovInfReq.get_tensor("dst_idx").get_shape(); + const auto kv_num_heads = dst_idx_shape[1]; + const auto kv_head_size = dst_idx_shape[3]; + if (kv_src_indices.size() > 0) { + ov::Tensor src_idx_tensor = ov::Tensor(ov::element::i32, {kv_src_indices.size()}); + const auto src_idx_ptr = src_idx_tensor.data(); + for (size_t i = 0; i < kv_src_indices.size(); ++i) { + src_idx_ptr[i] = static_cast(kv_src_indices[i]); + } + ovInfReq.set_tensor("src_idx", src_idx_tensor); + + ov::Tensor dst_idx_tensor = ov::Tensor(ov::element::i32, {1, kv_num_heads, kv_dst_indices.size(), kv_head_size}); + const auto dst_idx_ptr = dst_idx_tensor.data(); + for (size_t i = 0; i < kv_num_heads; ++i) { + for (size_t j = 0; j < kv_dst_indices.size(); ++j) { + std::fill_n(dst_idx_ptr + (i * kv_dst_indices.size() + j) * kv_head_size, kv_head_size, kv_dst_indices[j]); + } + } + ovInfReq.set_tensor("dst_idx", dst_idx_tensor); + } else { + FillTensor("src_idx", ov::element::i32, {0}, 0); + FillTensor("dst_idx", ov::element::i32, {1, kv_num_heads, 0, kv_head_size}, 0); + } + } + // If 'prefill use full chat history' mode is enabled, we need to cache input_ids and position_ids. 
if (prefill_use_full_chat_history) { auto input_ids_tensor = ovInfReq.get_tensor("input_ids"); @@ -499,6 +535,31 @@ void StatefulOVInferRequest::PreProcessInferRequest() { void StatefulOVInferRequest::Infer() { PreProcessInferRequest(); OVInferRequest::Infer(); + PostProcessInferRequest(); +} + +void StatefulOVInferRequest::PostProcessInferRequest() { + if (is_kvcache_reorder_added) { + kv_src_indices.clear(); + kv_dst_indices.clear(); + } +} + +void StatefulOVInferRequest::ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) { + // Validate input parameters + if (src_indices.size() != dst_indices.size()) { + ORT_THROW(log_tag + + "ReorderKVCache: src_indices and dst_indices must have the same size. " + "Got src_indices.size()=" + + std::to_string(src_indices.size()) + + ", dst_indices.size()=" + std::to_string(dst_indices.size())); + } + + LOGS_DEFAULT(INFO) << log_tag << "ReorderKVCache: Reordering OpenVINO-internal KVCache state with " + << src_indices.size() << " index pairs"; + + kv_src_indices = src_indices; + kv_dst_indices = dst_indices; } void StatefulOVInferRequest::RewindKVCache(size_t index) { diff --git a/onnxruntime/core/providers/openvino/ov_interface.h b/onnxruntime/core/providers/openvino/ov_interface.h index 8fc28b8885e5d..20693e3349ba8 100644 --- a/onnxruntime/core/providers/openvino/ov_interface.h +++ b/onnxruntime/core/providers/openvino/ov_interface.h @@ -39,6 +39,7 @@ class OVCore; class OVInferRequest; class OVExeNetwork; struct ModelBlobWrapper; +struct SessionContext; typedef ov::Tensor OVTensor; typedef ov::ProfilingInfo OVProfilingInfo; @@ -77,7 +78,8 @@ struct OVCore : WeakSingleton { std::string& hw_target, const ov::AnyMap& device_config, bool enable_causallm, - std::filesystem::path model_file_path); + std::filesystem::path model_file_path, + const SessionContext& session_context); std::vector GetAvailableDevices() const; std::vector GetAvailableDevices(const std::string& device_type) const; @@ -89,10 +91,11 
@@ class OVExeNetwork { ov::CompiledModel compiled_model_obj; std::string target_device; bool is_stateful_causallm; + bool is_kvcache_reorder_added = false; public: - explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false) - : compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm) {} + explicit OVExeNetwork(ov::CompiledModel compiled_model, std::string device, bool stateful_causallm = false, bool kvcache_reorder_added = false) + : compiled_model_obj(std::move(compiled_model)), target_device(std::move(device)), is_stateful_causallm(stateful_causallm), is_kvcache_reorder_added(kvcache_reorder_added) {} OVExeNetwork() : compiled_model_obj(ov::CompiledModel()), is_stateful_causallm(false) {} ov::CompiledModel& Get() { return compiled_model_obj; } std::shared_ptr CreateInferRequest(); @@ -134,14 +137,16 @@ class OVInferRequest { return ovInfReq; } virtual void RewindKVCache([[maybe_unused]] size_t index) {} + virtual void ReorderKVCache([[maybe_unused]] const std::vector& src_indices, [[maybe_unused]] const std::vector& dst_indices) {} }; class StatefulOVInferRequest : public OVInferRequest { public: - explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device); + explicit StatefulOVInferRequest(ov::InferRequest infer_request, std::string device, bool kvcache_reorder_added = false); void Infer() override; void RewindKVCache(size_t index) override; + void ReorderKVCache(const std::vector& src_indices, const std::vector& dst_indices) override; void FillTensor(const std::string& tensor_name, const ov::element::Type& type, const std::vector& shape, int32_t fill_value); void CacheTensor(const std::string& tensor_name, std::vector& cache); @@ -151,13 +156,19 @@ class StatefulOVInferRequest : public OVInferRequest { private: void PreProcessInferRequest(); + void PostProcessInferRequest(); std::string target_device; + std::vector 
cached_input_ids; + std::vector cached_position_ids; + std::vector kv_src_indices; + std::vector kv_dst_indices; + // If prefill_use_full_chat_history is true, cache the "input_ids" & "position_ids" tensors, // and ensure that full chat history is passed for each prefill call. bool prefill_use_full_chat_history = false; - std::vector cached_input_ids; - std::vector cached_position_ids; + // If kvcache_reorder is added, will include kv_src/dst_indices as input + bool is_kvcache_reorder_added = false; bool IsNPULogitsSliceRequired(); bool _npu_logits_slice_required = false; diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.cc b/onnxruntime/core/providers/openvino/ov_shared_context.cc index f48284d0cc974..900196c3f652a 100644 --- a/onnxruntime/core/providers/openvino/ov_shared_context.cc +++ b/onnxruntime/core/providers/openvino/ov_shared_context.cc @@ -10,9 +10,10 @@ namespace onnxruntime { namespace openvino_ep { -SharedContext::SharedContext(std::filesystem::path bin_path) - : bin_path_(std::move(bin_path)), - bin_manager_(bin_path_) { +SharedContext::SharedContext(const std::filesystem::path& bin_path) + : bin_path_(bin_path), + bin_manager_(bin_path_), + weight_file_manager_(WeightFileManager::Get()) { } static bool InRange(size_t offset, size_t size, size_t total_size) { @@ -35,7 +36,7 @@ void SharedContext::WeightsFile::LoadWeights(size_t file_offset, void* data, siz file_.read(static_cast(data), size); } -void* SharedContext::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { +const void* SharedContext::WeightsFile::TryGetOrCreateDeviceMapping(std::optional& remote_context) { std::string dev_name{}; if (remote_context) { dev_name = remote_context->get_device_name(); @@ -53,8 +54,12 @@ void* SharedContext::WeightsFile::TryGetOrCreateDeviceMapping(std::optionalsecond = MappingContainer{.ptr_ = mmaped_tensor.data(), .tensor_ = mmaped_tensor}; + OPENVINO_SUPPRESS_DEPRECATED_END } } @@ -70,16 +75,21 @@ void 
SharedContext::LoadTensorFromFile( const auto weights_location = model_dir / value.serialized.location; auto& weights_file = weight_files_[weights_location]; if (!weights_file) { - weights_file = std::make_unique(weights_location); + weights_file = weight_file_manager_->GetOrCreateWeightsFile(weights_location); } ov::Tensor tensor; - uint8_t* mmaped_weights = static_cast(weights_file->TryGetOrCreateDeviceMapping(remote_context)); + const uint8_t* mmaped_weights = static_cast(weights_file->TryGetOrCreateDeviceMapping(remote_context)); if (mmaped_weights) { // We have memory mapped weights. Create a Tensor view into it for this value. ORT_ENFORCE(InRange(value.serialized.data_offset, value.serialized.size, weights_file->Size()), "File offset + size outside of external initializer file"); - void* mmapped_offset = static_cast(mmaped_weights + value.serialized.data_offset); + const void* mmapped_offset = static_cast(mmaped_weights + value.serialized.data_offset); +#if OPENVINO_VERSION_AT_LEAST(2026, 0) + // In OV 2026.0 we can pass read-only tensors as inputs. 
tensor = ov::Tensor(element_type, dimensions, mmapped_offset); +#else + tensor = ov::Tensor(element_type, dimensions, const_cast(mmapped_offset)); +#endif } else { ORT_ENFORCE(remote_context, "Unexpected: Don't have remote context and memory mapped weights is null!"); // Can't mmap the file to device tensor, create a host tensor and copy the data diff --git a/onnxruntime/core/providers/openvino/ov_shared_context.h b/onnxruntime/core/providers/openvino/ov_shared_context.h index aee6d5570d8fa..99af8bf208805 100644 --- a/onnxruntime/core/providers/openvino/ov_shared_context.h +++ b/onnxruntime/core/providers/openvino/ov_shared_context.h @@ -19,10 +19,13 @@ namespace onnxruntime { namespace openvino_ep { +class WeightFileManager; + class SharedContext : public std::enable_shared_from_this { public: - explicit SharedContext(std::filesystem::path bin_path); + explicit SharedContext(const std::filesystem::path& bin_path); SharedContext() : SharedContext("") {} + virtual ~SharedContext() {} struct Metadata { struct Value { @@ -83,14 +86,13 @@ class SharedContext : public std::enable_shared_from_this { return BinManager::GetBinPathForModel(model_path); } - private: struct WeightsFile { ORT_DISALLOW_COPY_AND_ASSIGNMENT(WeightsFile); WeightsFile() = delete; virtual ~WeightsFile() = default; explicit WeightsFile(const std::filesystem::path& filename); void LoadWeights(size_t file_offset, void* data, size_t size); - void* TryGetOrCreateDeviceMapping(std::optional& remote_context); + const void* TryGetOrCreateDeviceMapping(std::optional& remote_context); size_t Size() const { return weights_size_; } private: @@ -98,13 +100,15 @@ class SharedContext : public std::enable_shared_from_this { std::filesystem::path file_path_; size_t weights_size_; struct MappingContainer { - void* ptr_{nullptr}; + const void* ptr_{nullptr}; ov::Tensor tensor_; }; std::map imported_device_tensors_; }; - void LoadTensorFromFile( + private: + void + LoadTensorFromFile( Metadata::Value& value, const 
std::filesystem::path& model_dir, std::optional& remote_context, @@ -114,10 +118,29 @@ class SharedContext : public std::enable_shared_from_this { mutable std::shared_mutex mutex_; std::filesystem::path bin_path_; BinManager bin_manager_; - std::unordered_map> weight_files_; + std::shared_ptr weight_file_manager_; + std::unordered_map> weight_files_; Metadata::Map metadata_; }; +class WeightFileManager : public WeakSingleton { + public: + using WeightsFile = SharedContext::WeightsFile; + std::shared_ptr GetOrCreateWeightsFile(const std::filesystem::path& weights_path) { + auto absolute_path = std::filesystem::absolute(weights_path); + std::lock_guard lock(mutex_); + auto [it, inserted] = files_.try_emplace(absolute_path, nullptr); + if (inserted) { + it->second = std::make_shared(absolute_path); + } + return it->second; + } + + private: + mutable std::mutex mutex_; + std::unordered_map> files_; +}; + class SharedContextManager : public WeakSingleton { public: std::shared_ptr GetOrCreateActiveSharedContext(const std::filesystem::path& model_path) { diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc index fd2b5797a1f40..85c027fcf6b93 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.cc @@ -75,7 +75,8 @@ std::string GetInputOutputName(std::shared_ptr ov_model, void FuseCacheReorder(std::shared_ptr ov_model, std::vector& not_kv_inputs, const std::vector& key_value_input_names, - int gather_dim) { + int gather_dim, + const bool should_add_kvcache_reorder) { if (ModelHasInputOutputNames(ov_model, "beam_idx")) { throw std::runtime_error("Model already has fused cache"); } @@ -91,6 +92,7 @@ void FuseCacheReorder(std::shared_ptr ov_model, std::string main_input_name = GetInputOutputName(ov_model, input_name_candidates); auto input_batch = ov_model->input(main_input_name).get_partial_shape()[0]; 
+ auto update_shape = ov_model->input(key_value_input_names[0]).get_partial_shape(); auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape({std::move(input_batch)})); beam_idx->set_friendly_name("beam_idx"); @@ -98,6 +100,23 @@ void FuseCacheReorder(std::shared_ptr ov_model, ov_model->add_parameters({beam_idx}); not_kv_inputs.push_back(beam_idx->get_friendly_name()); + std::shared_ptr src_idx; + std::shared_ptr dst_idx; + + if (should_add_kvcache_reorder) { + src_idx = std::make_shared(ov::element::i32, ov::PartialShape({update_shape[2]})); + src_idx->set_friendly_name("src_idx"); + src_idx->output(0).get_tensor().add_names({"src_idx"}); + ov_model->add_parameters({src_idx}); + not_kv_inputs.push_back(src_idx->get_friendly_name()); + + dst_idx = std::make_shared(ov::element::i32, update_shape); + dst_idx->set_friendly_name("dst_idx"); + dst_idx->output(0).get_tensor().add_names({"dst_idx"}); + ov_model->add_parameters({dst_idx}); + not_kv_inputs.push_back(dst_idx->get_friendly_name()); + } + // Go over all cache parameters and fuse _reorder_cache with indices provided by the new parameter beam_idx for (const auto& input_name : key_value_input_names) { auto parameter_output_port = ov_model->input(input_name); @@ -108,9 +127,25 @@ void FuseCacheReorder(std::shared_ptr ov_model, beam_idx, ov::opset13::Constant::create(ov::element::i64, {}, {gather_dim})); + std::shared_ptr output_node; + if (should_add_kvcache_reorder) { + auto updatekv_gather_op = + std::make_shared(gather_op, + src_idx, + ov::opset13::Constant::create(ov::element::i64, {}, {2})); + + auto updatekv_op = std::make_shared(gather_op, + dst_idx, + updatekv_gather_op, + ov::opset13::Constant::create(ov::element::i64, {}, {2})); + output_node = updatekv_op; + } else { + output_node = gather_op; + } + // Replace the source output for all consumers of the input tensor for (auto& consumer : consumers) { - consumer.replace_source_output(gather_op->output(0)); + 
consumer.replace_source_output(output_node->output(0)); } } @@ -247,7 +282,7 @@ std::pair, std::vector> ExtractInputKVTens } // Updated PatchStatefulDecoder function -void PatchStatefulDecoder(std::shared_ptr model) { +void PatchStatefulDecoder(std::shared_ptr model, const bool should_add_kvcache_reorder) { // Use the dynamic pattern-based extraction logic auto [key_value_output_names, extracted_patterns] = ExtractKVPatternsFromOutputs(model); auto [key_value_input_names, not_kv_inputs] = ExtractInputKVTensors(model, extracted_patterns); @@ -269,7 +304,7 @@ void PatchStatefulDecoder(std::shared_ptr model) { // batch_dim = 1 if config.model_type == "chatglm" and not hasattr(config, "rope_ratio") else 0 auto batch_dim = 0; - FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim); + FuseCacheReorder(model, not_kv_inputs, key_value_input_names, batch_dim, should_add_kvcache_reorder); MakeStateful(model, key_value_input_names, key_value_output_names); } diff --git a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h index 0b89c4ed02e13..c434e95f3cbb6 100644 --- a/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h +++ b/onnxruntime/core/providers/openvino/ov_stateful_patch_utils.h @@ -13,6 +13,7 @@ #include "openvino/pass/manager.hpp" #include "openvino/pass/make_stateful.hpp" +#include "openvino/opsets/opset12.hpp" #include "openvino/opsets/opset13.hpp" namespace onnxruntime { @@ -25,13 +26,14 @@ bool ModelHasInputOutputNames(std::shared_ptr model, const std::strin void FuseCacheReorder(std::shared_ptr ov_model, std::vector& not_kv_inputs, const std::vector& key_value_input_names, - int gather_dim); + int gather_dim, + const bool should_add_kvcache_reorder = false); void MakeStateful(std::shared_ptr& ov_model, const std::vector& key_value_input_names, const std::vector& key_value_output_names); -void PatchStatefulDecoder(std::shared_ptr model); +void 
PatchStatefulDecoder(std::shared_ptr model, const bool should_add_kvcache_reorder = false); bool HasOpWithType(const std::shared_ptr& function, const std::string& type_name); diff --git a/onnxruntime/core/providers/openvino/ov_versions/capability.cc b/onnxruntime/core/providers/openvino/ov_versions/capability.cc index 40036212ca125..bb171fb435256 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/capability.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/capability.cc @@ -102,8 +102,16 @@ std::vector> GetCapability::Execute() { if (unsupported_nodes.empty()) { std::vector inputs; std::vector outputs; + auto input_nodes = graph_viewer_.GetInputs(); + // Input is not a tensor, OV only handle a tensor input + for (auto& node : input_nodes) { + auto shape = node->Shape(); + if (!shape) { + return result; + } + } // Fill inputs with names - Iterable2String(inputs, graph_viewer_.GetInputs()); + Iterable2String(inputs, input_nodes); /* In scenarios, when there are no inputs or all inputs being initializers, ConstantFolding optimization in onnxruntime pre-computes the value.*/ diff --git a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc index 373b2121a9b60..b2fc34eab524f 100644 --- a/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc +++ b/onnxruntime/core/providers/openvino/ov_versions/data_ops.cc @@ -35,41 +35,22 @@ namespace openvino_ep { // Ops which are supported only in models(as intermediate nodes) and not in unit tests std::set ops_supported_only_in_model = { - "Add", "Cast", "Celu", - "Concat", "ConstantOfShape", - "DequantizeLinear", "Dropout", "Einsum", - "Exp", - "Expand", - "EyeLike", "GatherElements", "GatherND", "GridSample", - "Identity", "LayerNormalization", - "Loop", "LSTM", - "NonMaxSuppression", - "NonZero", - "Not", "OneHot", "Pad", - "QuantizeLinear", "RandomNormalLike", - "Range", "ReduceMin", - "Resize", - "Round", - "Shape", "Slice", - "Split", - 
"Tile", - "TopK", - "Trilu"}; + "TopK"}; // Ops which are supported as functions (as composite ops) std::set ops_supported_as_function = { @@ -269,6 +250,8 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); + supported_types_initializer_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32)); supported_types_initializer_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64)); supported_types_initializer_.insert( @@ -317,6 +300,8 @@ void DataOps::populate_types_supported() { std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL)); supported_types_cpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT)); + supported_types_cpu_.insert( + std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32)); supported_types_cpu_.insert( std::make_pair(V_2020_4, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32)); supported_types_cpu_.insert( @@ -367,6 +352,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"DynamicQuantizeLinear", V_2025_2, {"All"}}); no_dimension_supported_.push_back({"Equal", V_2022_1, {"CPU"}}); no_dimension_supported_.push_back({"Equal", V_2023_0, {"GPU"}}); + no_dimension_supported_.push_back({"Exp", V_2020_4, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Expand", V_2023_3, {"CPU"}}); no_dimension_supported_.push_back({"Expand", V_2024_3, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Floor", V_2020_4, {"All"}}); @@ -382,6 +368,7 @@ void DataOps::populate_op_mode_supported() { no_dimension_supported_.push_back({"Mul", V_2020_4, {"All"}}); 
no_dimension_supported_.push_back({"Neg", V_2023_0, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"Pow", V_2023_0, {"CPU", "GPU"}}); + no_dimension_supported_.push_back({"PRelu", V_2020_4, {"CPU", "GPU"}}); no_dimension_supported_.push_back({"QuantizeLinear", V_2021_4, {"All"}}); no_dimension_supported_.push_back({"Range", V_2021_2, {"All"}}); no_dimension_supported_.push_back({"ReduceMax", V_2021_4, {"All"}}); @@ -489,6 +476,38 @@ void DataOps::populate_op_mode_supported() { }}; op_list_.insert({"Upsample", obj}); } + { + UnsupportedOpMode obj = {{V_2023_1, V_2023_2, V_2023_3, V_2024_0, V_2024_1, V_2024_2, + V_2024_3, V_2024_4, V_2024_5, V_2024_6, V_2025_0, V_2025_1, V_2025_2, V_2025_3, V_2025_4}, + [this](const Node* node, const InitializedTensorSet&) { + auto& attributes = node->GetAttributes(); + if (attributes.count("coordinate_transformation_mode") > 0) { + auto coordinate_transformation_mode = + attributes.at("coordinate_transformation_mode").s(); + if (coordinate_transformation_mode == "tf_crop_and_resize" || + coordinate_transformation_mode == "half_pixel_symmetric") { + return true; + } + } + if (attributes.count("antialias") > 0) { + auto antialias_mode = + attributes.at("antialias").i(); + auto resize_mode = attributes.at("mode").s(); + if (antialias_mode == 1 && + (resize_mode == "linear" || + resize_mode == "cubic")) { + return true; + } + } + if (attributes.count("exclude_outside") > 0) { + if (attributes.at("exclude_outside").i() == 1) { + return true; + } + } + return false; + }}; + op_list_.insert({"Resize", obj}); + } } bool DataOps::op_is_supported(std::string name, std::vector& op_list) { diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index d468894080b3d..0e49c0f897bea 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -94,6 +94,7 @@ Status 
GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), + "", main_context_node.Name(), qnn_models, max_spill_fill_size); @@ -127,6 +128,18 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "The file path in ep_cache_context does not exist or is not accessible."); } + std::string context_binary_path_str = context_binary_path.string(); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (qnn_backend_manager->FileMappingIsEnabled()) { + return qnn_backend_manager->LoadCachedQnnContextFromBuffer(nullptr, + 0, + context_binary_path_str, + main_context_node.Name(), + qnn_models, + max_spill_fill_size); + } +#endif + size_t buffer_size{0}; std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary); ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file."); @@ -144,6 +157,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, cache_file.close(); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), + context_binary_path_str, main_context_node.Name(), qnn_models, max_spill_fill_size); diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc index cfa0c430b053b..202efb7706664 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.cc @@ -51,7 +51,7 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateSimpleOpBuilder("Sum", *this); CreateSimpleOpBuilder("Tanh", *this); - CreateSimpleOpBuilder("Concat", *this); + CreateConcatOpBuilder("Concat", *this); CreateSimpleOpBuilder("QuantizeLinear", *this); 
CreateSimpleOpBuilder("DequantizeLinear", *this); @@ -212,6 +212,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() { CreateGatherNDOpBuilder("GatherND", *this); } + { + CreateQuickGeluOpBuilder("QuickGelu", *this); + } + { CreateModOpBuilder("Mod", *this); } @@ -227,6 +231,11 @@ OpBuilderRegistrations::OpBuilderRegistrations() { { CreateInverseOpBuilder("Inverse", *this); } + + { + CreateFusedMatMulOpBuilder("FusedMatMul", *this); + CreateMatMulNBitsOpBuilder("MatMulNBits", *this); + } } const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) { diff --git a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h index 1312069892671..1a51dac1cdef5 100644 --- a/onnxruntime/core/providers/qnn/builder/op_builder_factory.h +++ b/onnxruntime/core/providers/qnn/builder/op_builder_factory.h @@ -126,6 +126,11 @@ void CreateThresholdedReluOpBuilder(const std::string& op_type, OpBuilderRegistr void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateFusedMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); + +void CreateMatMulNBitsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations); } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index b51732bf0fe12..0bb3accb4d754 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -13,34 +13,6 @@ bool 
IsOptionalNodeUnitIODef(const NodeUnitIODef& node_io_def) { const NodeArg& arg = node_io_def.node_arg; return !arg.Exists() || arg.Name().empty(); } - -// Function to check whether we should skip processing null input which has 0 dim in shape. -// Such null inputs often exist in models saved from PyTorch, especially for Concat. -bool DoesConcatInputShapeContainZero(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - const NodeUnitIODef& node_io_def, - const logging::Logger& logger) { - // Although the 0 dim issue should be handled for all op types, restricting in Concat for now since current cases - // only happen on one of Concat inputs. One may rename the function and relax the checking here to extend for other - // ops. - if (node_unit.OpType() != "Concat") { - return false; - } - - std::vector input_shape; - if (!qnn_model_wrapper.GetOnnxShape(node_io_def.node_arg, input_shape)) { - return false; - } - - for (const uint32_t& dim : input_shape) { - if (dim == 0) { - LOGS(logger, WARNING) << "Tensor has 0 dim, ignore this input: " << node_io_def.node_arg.Name(); - return true; - } - } - - return false; -} } // namespace std::string BaseOpBuilder::GetOpBuilderType() const { @@ -154,9 +126,7 @@ Status BaseOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const auto& inputs = node_unit.Inputs(); const auto input_count = GetInputCountQnnRequired(node_unit); for (size_t input_i = 0; input_i < input_count; ++input_i) { - if (!DoesConcatInputShapeContainZero(qnn_model_wrapper, node_unit, inputs[input_i], logger)) { - ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names)); - } + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, inputs[input_i], logger, input_names)); } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 51f6523559987..789350ed886fe 100644 --- 
a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -12,6 +12,7 @@ namespace onnxruntime { namespace qnn { + class BatchNormOpBuilder : public BaseOpBuilder { public: BatchNormOpBuilder() : BaseOpBuilder("BatchNormOpBuilder") {} @@ -262,30 +263,57 @@ class BatchNormOpBuilder : public BaseOpBuilder { return Status::OK(); } + // Maybe dequantizes a 1D BatchNorm parameter tensor to double values. + Status MaybeDequantizeParamTensor(const TensorInfo& info, + const uint8_t* raw_ptr, + const size_t raw_ptr_length, + std::string_view tensor_name, + std::vector& out) const { + uint32_t channel = info.shape[0]; + out.resize(channel); + ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(info.qnn_data_type, channel, raw_ptr_length)); + + const bool is_quantized = info.quant_param.IsQuantized(); + const bool is_per_channel = info.quant_param.IsPerChannel(); + const Qnn_QuantizeParams_t& quant_param = info.quant_param.Get(); + if (is_per_channel) { + // Validate per-channel quantization parameters for 1D BatchNorm tensors. + // For 1D tensors, axis must be 0 and numScaleOffsets must equal channel count. 
+ ORT_RETURN_IF_NOT(quant_param.axisScaleOffsetEncoding.axis == 0, + "Per-channel quantization axis must be 0 for 1D ", tensor_name, " tensor, got ", + quant_param.axisScaleOffsetEncoding.axis); + ORT_RETURN_IF_NOT(quant_param.axisScaleOffsetEncoding.numScaleOffsets == channel, + "Per-channel quantization scale/offset count (", + quant_param.axisScaleOffsetEncoding.numScaleOffsets, + ") must equal channel count (", channel, ") for ", tensor_name, " tensor."); + } + + int offset = 0; + for (uint32_t i = 0; i < channel; ++i) { + double value = 0.0; + ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(info.qnn_data_type, raw_ptr + offset, value, offset)); + // Dequantize if needed + if (is_quantized) { + if (is_per_channel) { + value = utils::Dequantize(quant_param.axisScaleOffsetEncoding.scaleOffset[i].offset, + quant_param.axisScaleOffsetEncoding.scaleOffset[i].scale, + value); + } else { + value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + value); + } + } + out[i] = value; + } + return Status::OK(); + } + Status PreprocessMean(const TensorInfo& mean_info, const uint8_t* mean_raw_ptr, const size_t mean_raw_ptr_length, std::vector& mean_out) const { - // tensor length (channel) - uint32_t channel = mean_info.shape[0]; - mean_out.resize(channel); - ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(mean_info.qnn_data_type, channel, mean_raw_ptr_length)); - - const bool is_quantized = mean_info.quant_param.IsQuantized(); - ORT_RETURN_IF_NOT(!is_quantized || mean_info.quant_param.IsPerTensor(), - "BatchNormalization's input_mean does not support per-channel quantization"); - int i = 0; - int offset = 0; - const Qnn_QuantizeParams_t& quant_param = mean_info.quant_param.Get(); - for (; i < static_cast(channel); ++i) { - double mean_value = 0.0; - ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(mean_info.qnn_data_type, mean_raw_ptr + offset, mean_value, offset)); - mean_out[i] = (is_quantized) ? 
utils::Dequantize(quant_param.scaleOffsetEncoding.offset, - quant_param.scaleOffsetEncoding.scale, - mean_value) - : mean_value; - } - return Status::OK(); + return MaybeDequantizeParamTensor(mean_info, mean_raw_ptr, mean_raw_ptr_length, "mean", mean_out); } Status PreprocessStd(const TensorInfo& var_info, @@ -293,25 +321,12 @@ class BatchNormOpBuilder : public BaseOpBuilder { const size_t var_raw_ptr_length, const float epsilon, std::vector& std_out) const { - // tensor length (channel) - uint32_t channel = var_info.shape[0]; - std_out.resize(channel); - ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(var_info.qnn_data_type, channel, var_raw_ptr_length)); - - const bool is_quantized = var_info.quant_param.IsQuantized(); - ORT_RETURN_IF_NOT(!is_quantized || var_info.quant_param.IsPerTensor(), - "BatchNormalization's input_var does not support per-channel quantization"); - int i = 0; - int offset = 0; - const Qnn_QuantizeParams_t& quant_param = var_info.quant_param.Get(); - for (; i < static_cast(channel); ++i) { - double var_value = 0.0; - ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(var_info.qnn_data_type, var_raw_ptr + offset, var_value, offset)); - std_out[i] = (is_quantized) ? 
utils::Dequantize(quant_param.scaleOffsetEncoding.offset, - quant_param.scaleOffsetEncoding.scale, - var_value) - : var_value; - std_out[i] = std::sqrt(std_out[i] + static_cast(epsilon)); + std::vector var_dequantized; + ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(var_info, var_raw_ptr, var_raw_ptr_length, "variance", var_dequantized)); + + std_out.resize(var_dequantized.size()); + for (size_t i = 0; i < var_dequantized.size(); ++i) { + std_out[i] = std::sqrt(var_dequantized[i] + static_cast(epsilon)); } return Status::OK(); } @@ -323,25 +338,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { double& rmax, double& rmin, std::vector& scale_out) const { - // tensor length (channel) - uint32_t channel = scale_info.shape[0]; - scale_out.resize(channel); - ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(scale_info.qnn_data_type, channel, scale_raw_ptr_length)); - - const bool is_quantized = scale_info.quant_param.IsQuantized(); - ORT_RETURN_IF_NOT(!is_quantized || scale_info.quant_param.IsPerTensor(), - "BatchNormalization's scale input does not support per-channel quantization"); - int i = 0; - int offset = 0; - const Qnn_QuantizeParams_t& quant_param = scale_info.quant_param.Get(); - for (; i < static_cast(channel); ++i) { - double scale_value = 0.0; - ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(scale_info.qnn_data_type, scale_raw_ptr + offset, scale_value, offset)); - scale_out[i] = (is_quantized) ? 
utils::Dequantize(quant_param.scaleOffsetEncoding.offset, - quant_param.scaleOffsetEncoding.scale, - scale_value) - : scale_value; - scale_out[i] = scale_out[i] / std_double_tensor[i]; + ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(scale_info, scale_raw_ptr, scale_raw_ptr_length, "scale", scale_out)); + + for (size_t i = 0; i < scale_out.size(); ++i) { + scale_out[i] /= std_double_tensor[i]; rmax = std::max(rmax, scale_out[i]); rmin = std::min(rmin, scale_out[i]); } @@ -356,25 +356,10 @@ class BatchNormOpBuilder : public BaseOpBuilder { double& rmax, double& rmin, std::vector& bias_out) const { - // tensor length (channel) - uint32_t channel = bias_info.shape[0]; - bias_out.resize(channel); - ORT_RETURN_IF_ERROR(AssertUnpackedTensorSize(bias_info.qnn_data_type, channel, bias_raw_ptr_length)); - - const bool is_quantized = bias_info.quant_param.IsQuantized(); - ORT_RETURN_IF_NOT(!is_quantized || bias_info.quant_param.IsPerTensor(), - "BatchNormalization's bias input does not support per-channel quantization"); - int i = 0; - int offset = 0; - const Qnn_QuantizeParams_t& quant_param = bias_info.quant_param.Get(); - for (; i < static_cast(channel); ++i) { - double bias_value = 0.0; - ORT_RETURN_IF_ERROR(GetValueOnQnnDataType(bias_info.qnn_data_type, bias_raw_ptr + offset, bias_value, offset)); - bias_out[i] = (is_quantized) ? 
utils::Dequantize(quant_param.scaleOffsetEncoding.offset, - quant_param.scaleOffsetEncoding.scale, - bias_value) - : bias_value; - bias_out[i] = bias_out[i] - (mean_double_tensor[i] * scale_double_tensor[i]); + ORT_RETURN_IF_ERROR(MaybeDequantizeParamTensor(bias_info, bias_raw_ptr, bias_raw_ptr_length, "bias", bias_out)); + + for (size_t i = 0; i < bias_out.size(); ++i) { + bias_out[i] -= mean_double_tensor[i] * scale_double_tensor[i]; rmax = std::max(rmax, bias_out[i]); rmin = std::min(rmin, bias_out[i]); } @@ -390,10 +375,15 @@ class BatchNormOpBuilder : public BaseOpBuilder { bool symmetric = false; if (info.quant_param.IsQuantized()) { size_t data_size = double_tensor.size(); - // QNN BatchNorm int32 bias requires symmetric quantizated + // QNN BatchNorm requires symmetric quantization (zero_point=0) for signed params if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { data_size *= sizeof(int32_t); symmetric = true; + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + data_size *= sizeof(int16_t); + symmetric = true; + } else if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + data_size *= sizeof(uint16_t); } raw_tensor.resize(data_size); float scale = 0.0f; @@ -406,7 +396,6 @@ class BatchNormOpBuilder : public BaseOpBuilder { symmetric)); quant_param = QnnQuantParamsWrapper(scale, zero_point); for (size_t i = 0; i < double_tensor.size(); ++i) { - // onnx only supports 8 bits quantization int quant_value_int = 0; ORT_RETURN_IF_ERROR(utils::Quantize(double_tensor[i], scale, zero_point, info.qnn_data_type, quant_value_int)); if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_8) { @@ -414,12 +403,19 @@ class BatchNormOpBuilder : public BaseOpBuilder { } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_8) { int8_t quant_value = static_cast(quant_value_int); raw_tensor[i] = *reinterpret_cast(&quant_value); + } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_16) { + int16_t quant_value = static_cast(quant_value_int); + 
size_t pos = i * sizeof(int16_t); + std::memcpy(&raw_tensor[pos], reinterpret_cast(&quant_value), sizeof(int16_t)); + } else if (info.qnn_data_type == QNN_DATATYPE_UFIXED_POINT_16) { + uint16_t quant_value = static_cast(quant_value_int); + size_t pos = i * sizeof(uint16_t); + std::memcpy(&raw_tensor[pos], reinterpret_cast(&quant_value), sizeof(uint16_t)); } else if (info.qnn_data_type == QNN_DATATYPE_SFIXED_POINT_32) { int32_t quant_value = static_cast(quant_value_int); size_t pos = i * sizeof(int32_t); std::memcpy(&raw_tensor[pos], reinterpret_cast(&quant_value), sizeof(int32_t)); } else { - // TODO(adrianlizarraga): Should support 16-bit quantization as well. ORT_RETURN_IF(true, "Qnn Data Type: %d not supported yet.", info.qnn_data_type); } } @@ -437,6 +433,45 @@ class BatchNormOpBuilder : public BaseOpBuilder { const std::vector out_dtypes) const override ORT_MUST_USE_RESULT; }; +namespace { + +// Helper to check if a BatchNorm param is constant - either direct initializer or through a DQ node. +bool IsParamConstant(const QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::string& name) { + if (qnn_model_wrapper.IsConstantInput(name)) { + return true; + } + // Check if param comes through a DQ node with constant input + for (const Node* dq_node : node_unit.GetDQNodes()) { + if (dq_node->OutputDefs()[0]->Name() == name) { + return qnn_model_wrapper.IsConstantInput(dq_node->InputDefs()[0]->Name()); + } + } + return false; +} + +// Adjust BatchNorm param types for QNN HTP compatibility. +// Modifies scale/bias types in-place; quantization happens in Postprocess. 
+void OverrideParamTypeForRequantize(Qnn_DataType_t x_dtype, + Qnn_DataType_t& scale_dtype, + Qnn_DataType_t& bias_dtype, + bool is_scale_has_negative_values = true) { + // QNN HTP with UFIXED_POINT_16 input doesn't support SFIXED_POINT_8 scale + if (x_dtype == QNN_DATATYPE_UFIXED_POINT_16 && scale_dtype == QNN_DATATYPE_SFIXED_POINT_8) { + scale_dtype = is_scale_has_negative_values ? QNN_DATATYPE_SFIXED_POINT_16 : QNN_DATATYPE_UFIXED_POINT_8; + } + + // QNN HTP requires quantized bias for quantized ops + bool is_quantized = (x_dtype == QNN_DATATYPE_UFIXED_POINT_8 || x_dtype == QNN_DATATYPE_SFIXED_POINT_8 || + x_dtype == QNN_DATATYPE_UFIXED_POINT_16 || x_dtype == QNN_DATATYPE_SFIXED_POINT_16); + if (is_quantized && (bias_dtype == QNN_DATATYPE_FLOAT_32 || bias_dtype == QNN_DATATYPE_FLOAT_16)) { + bias_dtype = QNN_DATATYPE_SFIXED_POINT_32; + } +} + +} // namespace + // BatchNorm is sensitive with data layout, no special validation so far // The nodes from 1st call of GetCapability do not get layout transformer applied, it's still NCHW // The nodes from 2nd call of GetCapability get layout transformer applied, it's NHWC @@ -464,14 +499,14 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, std::vector scale_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[1].node_arg, scale_shape), "Cannot get shape of input 1 (scale)."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[1].node_arg.Name()), + ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[1].node_arg.Name()), "QNN BatchNorm doesn't support dynamic scale."); ORT_RETURN_IF(scale_shape.size() != 1 || scale_shape[0] != num_channels, "QNN BatchNorm input 1 (scale) must have 1D shape [channel]."); std::vector bias_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[2].node_arg, bias_shape), "Cannot get shape of input 2 (bias)."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[2].node_arg.Name()), + 
ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[2].node_arg.Name()), "QNN BatchNorm doesn't support dynamic bias."); ORT_RETURN_IF(bias_shape.size() != 1 || bias_shape[0] != num_channels, @@ -481,14 +516,14 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[3].node_arg, mean_shape), "Cannot get shape of input 3 (mean)."); ORT_RETURN_IF(mean_shape.size() != 1 || mean_shape[0] != num_channels, "QNN BatchNorm input 3 (mean) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[3].node_arg.Name()), + ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[3].node_arg.Name()), "QNN BatchNorm doesn't support dynamic mean."); std::vector var_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[4].node_arg, var_shape), "Cannot get shape of input 4 (var)."); ORT_RETURN_IF(var_shape.size() != 1 || var_shape[0] != num_channels, "QNN BatchNorm input 4 (var) must have 1D shape [channel]."); - ORT_RETURN_IF_NOT(qnn_model_wrapper.IsConstantInput(inputs[4].node_arg.Name()), + ORT_RETURN_IF_NOT(IsParamConstant(qnn_model_wrapper, node_unit, inputs[4].node_arg.Name()), "QNN BatchNorm doesn't support dynamic var."); ORT_RETURN_IF(node_unit.Outputs().size() > 1, "QNN BatchNorm only support 1 output."); @@ -528,11 +563,15 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[3], mean_info)); ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[4], var_info)); - // scale, bias, mean, and var must be initializers - ORT_RETURN_IF_NOT(scale_info.is_initializer, "scale must be initializers"); - ORT_RETURN_IF_NOT(bias_info.is_initializer, "bias must be initializers"); - ORT_RETURN_IF_NOT(mean_info.is_initializer, "mean must be initializers"); - ORT_RETURN_IF_NOT(var_info.is_initializer, "var must be initializers"); + // Get input 
tensor info to determine if this is a quantized op + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info)); + const bool is_quantized_op = input_info.quant_param.IsQuantized(); + + // Check if bias needs conversion (will be done after preprocessing) + const bool bias_is_float = !bias_info.quant_param.IsQuantized() && + (bias_info.qnn_data_type == QNN_DATATYPE_FLOAT_32 || + bias_info.qnn_data_type == QNN_DATATYPE_FLOAT_16); std::vector scale_unpacked_tensor; std::vector bias_unpacked_tensor; @@ -582,6 +621,15 @@ Status BatchNormOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, bias_rmin, bias_double_tensor)); + // Apply QNN HTP type conversions + OverrideParamTypeForRequantize(input_info.qnn_data_type, + scale_info.qnn_data_type, + bias_info.qnn_data_type, + scale_rmin < 0.0); + if (is_quantized_op && bias_is_float) { + bias_info.quant_param = QnnQuantParamsWrapper(1.0f, 0); // Placeholder, computed in Postprocess + } + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(scale_name)) { std::vector scale_raw_tensor; QnnQuantParamsWrapper scale_quant_param = scale_info.quant_param; @@ -650,10 +698,17 @@ Status BatchNormOpBuilder::CheckHtpDataTypes(const std::vector i const std::vector out_dtypes) const { bool is_supported_dtype = false; // in_dtypes: [X, scale, B, input_mean, input_var] - std::vector all_dtypes(in_dtypes.begin(), in_dtypes.begin() + 3); // out_dtypes: [Y, running_mean, running_var] - all_dtypes.insert(all_dtypes.end(), out_dtypes.begin(), out_dtypes.begin() + 1); - // FP16 + Qnn_DataType_t x_dtype = in_dtypes[0]; + Qnn_DataType_t scale_dtype = in_dtypes[1]; + Qnn_DataType_t bias_dtype = in_dtypes[2]; + Qnn_DataType_t y_dtype = out_dtypes[0]; + + // We likely need to re-quantize scale/bias for HTP compatibility, override dtypes before checking. + // Note: We conservatively assume scale may have negative values during validation. 
+ OverrideParamTypeForRequantize(x_dtype, scale_dtype, bias_dtype); + std::vector all_dtypes{x_dtype, scale_dtype, bias_dtype, y_dtype}; + // FP16/FP32 if ( (all_dtypes == std::vector{QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16, QNN_DATATYPE_FLOAT_16}) || (all_dtypes == std::vector{QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32, QNN_DATATYPE_FLOAT_32})) { @@ -678,7 +733,7 @@ Status BatchNormOpBuilder::CheckHtpDataTypes(const std::vector i } ORT_RETURN_IF_NOT(is_supported_dtype, "QNN Batchnorm unsupported datatype on HTP."); return Status::OK(); -}; +} } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc new file mode 100644 index 0000000000000..542447b1818f2 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/concat_op_builder.cc @@ -0,0 +1,111 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class ConcatOpBuilder : public BaseOpBuilder { + public: + ConcatOpBuilder() : BaseOpBuilder("ConcatOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ConcatOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status ConcatOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool /*do_op_validation*/) const { + const auto& inputs = node_unit.Inputs(); + + for (const auto& input : inputs) { + const auto& input_name = input.node_arg.Name(); + bool has_zero_dim = false; + + // Check if the tensor has a 0 dimension + if (qnn_model_wrapper.IsConstantInput(input_name)) { + // Process constant inputs (initializers) + const auto* input_tensor = qnn_model_wrapper.GetConstantTensor(input_name); + if (input_tensor != nullptr) { + const auto& shape = input_tensor->dims(); + if (std::find(shape.begin(), shape.end(), 0) != shape.end()) { + // Found a 0 dimension, skip this input + LOGS(logger, VERBOSE) << "Constant input tensor " << input_name << " has a 0 dimension, excluding from Concat"; + has_zero_dim = true; + } + } + } else { + // Process non-constant inputs + std::vector shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input.node_arg, shape), "Cannot get shape"); + + if 
(std::find(shape.begin(), shape.end(), 0) != shape.end()) { + // Found a 0 dimension, skip this input + LOGS(logger, VERBOSE) << "Input tensor " << input_name << " has a 0 dimension, excluding from Concat"; + has_zero_dim = true; + } + } + + // Process the input if it doesn't have a 0 dimension + if (!has_zero_dim) { + ORT_RETURN_IF_ERROR(ProcessInput(qnn_model_wrapper, input, logger, input_names)); + } + } + + // If all inputs have 0 dimensions, return an error as Concat requires at least one non-zero dimension input + if (input_names.empty()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Concat operation requires at least one input without a 0 dimension"); + } + + return Status::OK(); +} + +Status ConcatOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + if (input_names.size() < 1) { + return Status::OK(); + } + + std::vector param_tensor_names; + + // Process axis attribute + int32_t default_axis = 0; + Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT; + ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis)); + QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_CONCAT_PARAM_AXIS, axis_qnn_scalar); + param_tensor_names.push_back(axis_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(axis_param)); + + // Process outputs + return ProcessOutputs(qnn_model_wrapper, node_unit, + std::move(input_names), + std::move(param_tensor_names), + logger, do_op_validation, GetQnnOpType(node_unit.OpType())); +} + +void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/fused_matmul_op_builder.cc 
b/onnxruntime/core/providers/qnn/builder/opbuilder/fused_matmul_op_builder.cc new file mode 100644 index 0000000000000..7bd6790d87ccc --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/fused_matmul_op_builder.cc @@ -0,0 +1,365 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +// FusedMatMul operator is decomposed into MatMul with optional transposition and alpha scaling. +class FusedMatMulOpBuilder : public BaseOpBuilder { + public: + FusedMatMulOpBuilder() : BaseOpBuilder("FusedMatMulOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedMatMulOpBuilder); + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger, + std::vector& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + std::vector&& input_names, const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + private: + Status ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names) const ORT_MUST_USE_RESULT; + + Status GetFusedMatMulAttributes(const NodeUnit& node_unit, + bool& transA, + bool& transB, + bool& transBatchA, + bool& transBatchB, + float& alpha) const ORT_MUST_USE_RESULT; + + Status ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::vector& perm, + std::vector& param_tensor_names) const; + + void CreateBatchTransposePermVector(const 
std::vector& input_shape, std::vector& perm, bool trans_mat = false) const; + + Status HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const TensorInfo& input_info, + const std::string& input_name, + std::string& transposed_name, + bool trans_mat, + bool do_op_validation) const; +}; + +Status FusedMatMulOpBuilder::GetFusedMatMulAttributes(const NodeUnit& node_unit, + bool& transA, + bool& transB, + bool& transBatchA, + bool& transBatchB, + float& alpha) const { + NodeAttrHelper node_helper(node_unit); + + transA = node_helper.Get("transA", static_cast(0)) != 0; + transB = node_helper.Get("transB", static_cast(0)) != 0; + + transBatchA = node_helper.Get("transBatchA", static_cast(0)) != 0; + transBatchB = node_helper.Get("transBatchB", static_cast(0)) != 0; + + alpha = node_helper.Get("alpha", 1.0f); + + return Status::OK(); +} + +Status FusedMatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, + const logging::Logger& logger, std::vector& input_names, + bool /*do_op_validation*/) const { + const auto& inputs = node_unit.Inputs(); + + if (inputs.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "FusedMatMul requires exactly 2 inputs, got ", inputs.size()); + } + + TensorInfo input_info_0{}; + TensorInfo input_info_1{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info_0)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info_1)); + + ORT_RETURN_IF_ERROR(ProcessMatMulInputs(qnn_model_wrapper, node_unit, logger, input_names)); + + return Status::OK(); +} + +Status FusedMatMulOpBuilder::ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names) const { + const auto& inputs = node_unit.Inputs(); + + // Process input A + const std::string& input_a_name = inputs[0].node_arg.Name(); + if 
(qnn_model_wrapper.IsQnnTensorWrapperExist(input_a_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_a_name; + } else { + QnnTensorWrapper input_a_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[0], input_a_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_a_tensor)), "Failed to add input A tensor."); + } + input_names.emplace_back(input_a_name); + + // Process input B + const std::string& input_b_name = inputs[1].node_arg.Name(); + if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_b_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_b_name; + } else { + QnnTensorWrapper input_b_tensor; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], input_b_tensor)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_b_tensor)), "Failed to add input B tensor."); + } + input_names.emplace_back(input_b_name); + + return Status::OK(); +} + +Status FusedMatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& /*logger*/, + bool do_op_validation) const { + bool transA = false; + bool transB = false; + bool transBatchA = false; + bool transBatchB = false; + float alpha = 1.0f; + ORT_RETURN_IF_ERROR(GetFusedMatMulAttributes(node_unit, transA, transB, transBatchA, transBatchB, alpha)); + + TensorInfo input_a_info{}; + TensorInfo input_b_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_a_info)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[1], input_b_info)); + + std::vector matmul_param_tensor_names; + + // Set transpose parameters for last two dimensions + // Skip using transpose_in0 param when both transA and transBatchA are present + // Only use transpose_in0 when transA is present and transBatchA is not present + if (!(transA && transBatchA)) { + 
Qnn_Scalar_t transpose_a_scalar = QNN_SCALAR_INIT; + transpose_a_scalar.dataType = QNN_DATATYPE_BOOL_8; + transpose_a_scalar.bool8Value = transA ? 1 : 0; + QnnParamWrapper transpose_a_param(node_unit.Index(), node_unit.Name(), + QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, transpose_a_scalar); + matmul_param_tensor_names.push_back(transpose_a_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(transpose_a_param)); + } + + // Skip using transpose_in1 param when both transB and transBatchB are present + // Only use transpose_in1 when transB is present and transBatchB is not present + if (!(transB && transBatchB)) { + Qnn_Scalar_t transpose_b_scalar = QNN_SCALAR_INIT; + transpose_b_scalar.dataType = QNN_DATATYPE_BOOL_8; + transpose_b_scalar.bool8Value = transB ? 1 : 0; + QnnParamWrapper transpose_b_param(node_unit.Index(), node_unit.Name(), + QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, transpose_b_scalar); + matmul_param_tensor_names.push_back(transpose_b_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(transpose_b_param)); + } + + // QNN doesn't directly support batch dimension transposition in MatMul + // We need to insert additional transpose operations before the MatMul if transBatchA or transBatchB is true + std::string input_a_for_matmul = input_names[0]; + std::string input_b_for_matmul = input_names[1]; + + if (transBatchA && input_a_info.shape.size() > 2) { + std::string transposed_a_name; + ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_a_info, + input_a_for_matmul, transposed_a_name, transA, do_op_validation)); + input_a_for_matmul = transposed_a_name; + } + + if (transBatchB && input_b_info.shape.size() > 2) { + std::string transposed_b_name; + ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_b_info, + input_b_for_matmul, transposed_b_name, transB, do_op_validation)); + input_b_for_matmul = transposed_b_name; + } + + const std::string& output_name = 
node_unit.Outputs()[0].node_arg.Name(); + TensorInfo output_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info)); + + if (alpha == 1.0f) { + // When alpha is 1.0f, MatMul output is the final output + Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + + QnnTensorWrapper output_tensor(output_name, + tensor_type, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::vector(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), + "Failed to add final output tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + utils::GetUniqueName(node_unit.Name() + "_matmul"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_MAT_MUL, + {input_a_for_matmul, input_b_for_matmul}, + {output_name}, + std::move(matmul_param_tensor_names), + do_op_validation), + "Failed to create MatMul node for FusedMatMul."); + } else { + // When alpha is not 1.0f, we need an intermediate tensor for MatMul output + // and then apply alpha scaling + std::string matmul_output_name = utils::GetUniqueName(node_unit.Name() + "_matmul_output"); + + QnnTensorWrapper matmul_output_tensor(matmul_output_name, + QNN_TENSOR_TYPE_NATIVE, + output_info.qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(matmul_output_tensor)), + "Failed to add MatMul output tensor."); + + Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? 
QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + + QnnTensorWrapper output_tensor(output_name, + tensor_type, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::vector(output_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), + "Failed to add output tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + utils::GetUniqueName(node_unit.Name() + "_matmul"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_MAT_MUL, + {input_a_for_matmul, input_b_for_matmul}, + {matmul_output_name}, + std::move(matmul_param_tensor_names), + do_op_validation), + "Failed to create MatMul node for FusedMatMul."); + + std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha"); + std::vector alpha_shape{1}; + Qnn_DataType_t alpha_qnn_data_type = output_info.qnn_data_type; + std::vector alpha_data; + + // The alpha tensor data type should match the MatMul output data type for element-wise multiply + if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) { + alpha_data.resize(sizeof(MLFloat16)); + MLFloat16 alpha_fp16(alpha); + memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16)); + } else { + alpha_data.resize(sizeof(float)); + memcpy(alpha_data.data(), &alpha, sizeof(float)); + } + + QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name, + QNN_TENSOR_TYPE_STATIC, + alpha_qnn_data_type, + QnnQuantParamsWrapper(), + std::move(alpha_shape), + std::move(alpha_data)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)), + "Failed to add alpha tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + utils::GetUniqueName(node_unit.Name() + "_alpha_scale"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_MULTIPLY, + {matmul_output_name, alpha_tensor_name}, + {output_name}, + {}, + do_op_validation), + "Failed to create alpha scaling node for FusedMatMul."); + } + + return Status::OK(); +} + +Status 
FusedMatMulOpBuilder::ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const std::vector& perm, + std::vector& param_tensor_names) const { + QnnParamWrapper transpose_param(node_unit.Index(), node_unit.Name(), QNN_OP_TRANSPOSE_PARAM_PERM, + {static_cast(perm.size())}, std::vector(perm)); + param_tensor_names.push_back(transpose_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(transpose_param)); + + return Status::OK(); +} + +void FusedMatMulOpBuilder::CreateBatchTransposePermVector(const std::vector& input_shape, + std::vector& perm, + bool trans_mat) const { + const size_t shape_size = input_shape.size(); + + perm.clear(); + perm.reserve(shape_size); + + // 1. Add batch dimensions (1 to shape_size-2) + for (size_t i = 1; i < shape_size - 1; ++i) { + perm.push_back(static_cast(i)); + } + + // 2. Add the second-to-last dimension based on trans_mat + perm.push_back(trans_mat ? static_cast(shape_size - 1) : 0); + + // 3. Add the last dimension based on trans_mat + perm.push_back(trans_mat ? 
0 : static_cast(shape_size - 1)); +} + +Status FusedMatMulOpBuilder::HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const TensorInfo& input_info, + const std::string& input_name, + std::string& transposed_name, + bool trans_mat, + bool do_op_validation) const { + transposed_name = utils::GetUniqueName(node_unit.Name() + "_transposed_" + input_name.substr(input_name.find_last_of('/') + 1)); + + // Create perm vector for batch transpose + std::vector perm; + CreateBatchTransposePermVector(input_info.shape, perm, trans_mat); + + std::vector transpose_params; + ORT_RETURN_IF_ERROR(ProcessPermAttribute(qnn_model_wrapper, node_unit, perm, transpose_params)); + + // Calculate transposed shape directly using the permutation + std::vector transposed_shape(input_info.shape.size()); + for (size_t i = 0; i < perm.size(); ++i) { + transposed_shape[i] = input_info.shape[perm[i]]; + } + + QnnTensorWrapper transposed_tensor(transposed_name, + QNN_TENSOR_TYPE_NATIVE, + input_info.qnn_data_type, + input_info.quant_param.Copy(), + std::move(transposed_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(transposed_tensor)), + "Failed to add transposed tensor."); + + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode( + utils::GetUniqueName(node_unit.Name() + "_transpose_" + input_name.substr(input_name.find_last_of('/') + 1)), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_TRANSPOSE, + {input_name}, + {transposed_name}, + std::move(transpose_params), + do_op_validation), + "Failed to create transpose node."); + + return Status::OK(); +} + +void CreateFusedMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/matmulnbits_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/matmulnbits_op_builder.cc new file mode 
100755 index 0000000000000..b606b01d1d6ed --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/matmulnbits_op_builder.cc @@ -0,0 +1,381 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_quant_params_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { +/* Op Resolution + --> Incoming ONNX Node + 1. MatMulNBits + Attributes : INT64 + - accuracy_level : 4 + - bits : 4 + - block_size : 32 + - K : + - N : + + Inputs + - A : : (fp16/32) : [batch_size{1}, sequence_len, K] + - B : Init : (uint8) : [N, K/block_size, (block_size * bits) / 8] + - scales : Init : (fp32) : [N * K / block_size] + - zero_points : (optional)Init : (uint8) : [N * K / (block_size * 2)] + - bias : (optional)Init : [fp16/32] : [N] + + Outputs + - Y : : (fp16/32) : [batch_size{1}, sequence_len, N] + + <-- Outgoing QNN Node + 1. FullyConnected + Attributes + - + Inputs + - Input : (fp16/32) : [batch_size{1}, sequence_len, K] + - Weight : Static : (qint4) : [N, K] + - Scales : fp32 : [(N * K) / block_size{32}] + - Offsets : int32_t : [(N * K) / block_size{32}] + - Bias : Static :(fp16/32) : [1, N] + Outputs + - Output : (fp16/32) : [batch_size{1} * sequence_len, N] + + 2. 
Reshape + Inputs + - Input : (fp16/32) : [batch_size{1} * sequence_len, N] + Outputs + - Output : (fp16/32) : [batch_size{1}, sequence_len, N] +*/ + +class MatMulNBitsOpBuilder : public BaseOpBuilder { + public: + MatMulNBitsOpBuilder() : BaseOpBuilder("MatMulNBitsOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(MatMulNBitsOpBuilder); + + Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + protected: + Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; + + private: + void DQQToSignedFixedPoint4(std::vector& quant_data, int64_t num_blocks, int64_t block_size) const; +}; + +void MatMulNBitsOpBuilder::DQQToSignedFixedPoint4(std::vector& quant_data, + int64_t num_blocks, + int64_t block_size) const { + for (int64_t block_idx = 0; block_idx < num_blocks; ++block_idx) { + uint32_t zero_point = 8; + for (int64_t val_idx = 0; val_idx < (block_size / 2); ++val_idx) { + SafeInt safe_index = block_idx; + safe_index *= (block_size / 2); + safe_index += val_idx; + + size_t index = gsl::narrow_cast(safe_index); + uint8_t quant_value_4x2 = quant_data[index]; + + int8_t quant_upper_value = + gsl::narrow_cast(((quant_value_4x2 >> 4) & 0xF) - zero_point); + int8_t quant_lower_value = + gsl::narrow_cast(((quant_value_4x2 >> 0) & 0xF) - zero_point); + + quant_data[index] = ((quant_upper_value & 0xF) << 4) | (quant_lower_value & 0xF); + } + } +} + +Status MatMulNBitsOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger) const { + 
bool is_gpu_backend = IsGpuBackend(qnn_model_wrapper.GetQnnBackendType()); + ORT_RETURN_IF_NOT(is_gpu_backend, "MatMulNBits Op Supported Only for Qnn Gpu Backend"); + + Qnn_DataType_t input_datatype = QNN_DATATYPE_FLOAT_32; + Qnn_DataType_t datatype = QNN_DATATYPE_FLOAT_32; + + NodeAttrHelper node_helper(node_unit); + + // Extract Parameters + const int64_t bits = node_helper.Get("bits", static_cast(4)); + const int64_t block_size = node_helper.Get("block_size", static_cast(32)); + + const int64_t K = node_helper.Get("K", static_cast(1)); + const int64_t N = node_helper.Get("N", static_cast(1)); + + ORT_RETURN_IF_NOT(bits == 4, "Invalid bits. Qnn Gpu Only Supports MatMulNBits with bits == 4"); + ORT_RETURN_IF_NOT(block_size == 32, "Invalid block_size. Qnn Gpu Only Supports MatMulNBits with block_size == 32"); + ORT_RETURN_IF_NOT((K % block_size) == 0, "K must be divisible by block_size"); + ORT_RETURN_IF_NOT(((N * K) % (2 * block_size)) == 0, + "Invalid configuration. N * K must be divisible by 2 * block_size"); + + const int64_t num_blocks = (N * K) / block_size; + ORT_RETURN_IF_NOT(num_blocks > 0, "Invalid configuration. (N * K) / block_size must be > 0"); + + const auto& inputs = node_unit.Inputs(); + // 1. input : Datatype should be float16 or float32 + // Float16 Dlc serialization failing, Skipping float16 support for this op builder + // TODO :: Add Float16 Support + { + const NodeUnitIODef& input_tensor = inputs[0]; + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_tensor, input_info)); + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(input_tensor.quant_param.has_value(), + input_tensor.node_arg.TypeAsProto(), + input_datatype)); + ORT_RETURN_IF(input_datatype != QNN_DATATYPE_FLOAT_32, "Unsupported Input datatype"); + } + + // 2. weight : weight supported with packed int4 into int8. 
+ { + const NodeUnitIODef& input_tensor = inputs[1]; + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_tensor, input_info)); + + const std::vector input_shape = input_info.shape; + SafeInt safe_total_elements = std::accumulate(input_shape.begin(), + input_shape.end(), + SafeInt{1}, + std::multiplies<>()); + const int64_t total_elements = static_cast(safe_total_elements); + ORT_RETURN_IF_NOT(((total_elements * 2) == (N * K)), + "Invalid B dimensions. Qnn Gpu Only Supports MatMulNBits with bits == 4 " + "in packed format"); + } + + // 3. scales : scales only float32 datatype + { + const NodeUnitIODef& input_tensor = inputs[2]; + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_tensor, input_info)); + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(input_tensor.quant_param.has_value(), + input_tensor.node_arg.TypeAsProto(), + input_datatype)); + ORT_RETURN_IF(input_datatype != QNN_DATATYPE_FLOAT_32, "Unsupported Input datatype"); + } + + // 4. If input 3 exists, it has to be zero point. 
+ if (inputs.size() > 3 && inputs[3].node_arg.Exists()) { + const NodeUnitIODef& input_tensor = inputs[3]; + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_tensor, input_info)); + ORT_RETURN_IF_ERROR(utils::GetQnnDataType(input_tensor.quant_param.has_value(), + input_tensor.node_arg.TypeAsProto(), + datatype)); + ORT_RETURN_IF((datatype != QNN_DATATYPE_UINT_8), "Invalid zero point datatype."); + + std::vector per_block_uint8_offset; + const auto& zero_points_tensor_name = input_tensor.node_arg.Name(); + const auto& zero_points_tensor_proto = qnn_model_wrapper.GetConstantTensor(zero_points_tensor_name); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*zero_points_tensor_proto, + per_block_uint8_offset)); + + ORT_RETURN_IF_NOT((per_block_uint8_offset.size() * 2) == (num_blocks * sizeof(uint8_t)), + "Only packed uint4 into uint8 offset supported by op builder"); + const uint8_t expected_offset_value = 0b10001000; + for (size_t i = 0; i < per_block_uint8_offset.size(); i++) { + ORT_RETURN_IF_NOT(per_block_uint8_offset[i] == expected_offset_value, "Unsupported zero point value"); + } + } + + ORT_RETURN_IF((inputs.size() > 4 && inputs[4].node_arg.Exists()) || + (inputs.size() > 5 && inputs[5].node_arg.Exists()), + "Unsupported inputs g_idx or bias"); + + // Validate Process + std::vector input_names; + ORT_RETURN_IF_ERROR(ProcessInputs(qnn_model_wrapper, node_unit, logger, input_names, true)); + ORT_RETURN_IF_ERROR(ProcessAttributesAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), + logger, true)); + + return Status::OK(); +} + +Status MatMulNBitsOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + const logging::Logger& logger, + std::vector& input_names, + bool do_op_validation) const { + if (do_op_validation) { + bool is_gpu_backend = IsGpuBackend(qnn_model_wrapper.GetQnnBackendType()); + ORT_RETURN_IF_NOT(is_gpu_backend, "MatMulNBits Op Supported Only for Qnn Gpu 
Backend"); + } + NodeAttrHelper node_helper(node_unit); + + // Extract Parameters + const int64_t block_size = node_helper.Get("block_size", static_cast(32)); + const int64_t K = node_helper.Get("K", static_cast(1)); + const int64_t N = node_helper.Get("N", static_cast(1)); + + // Prepare essential parameters + const int64_t num_blocks = (N * K) / block_size; + const auto& inputs = node_unit.Inputs(); + + // 1. Add Input + { + const NodeUnitIODef& input_tensor = inputs[0]; + const std::string& input_tensor_name = input_tensor.node_arg.Name(); + if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_tensor_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_tensor_name; + } else { + TensorInfo input_info{}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(input_tensor, input_info)); + + QnnTensorWrapper input_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(input_info, + input_tensor_name, + input_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor_wrapper)), + "Failed to add tensor."); + } + input_names.push_back(input_tensor_name); + } + + // 2. 
Add Weights and add its Quantization Data + { + const auto& weight_tensor = inputs[1]; + const auto& scales_tensor = inputs[2]; + + const auto& weight_tensor_name = weight_tensor.node_arg.Name(); + if (qnn_model_wrapper.IsQnnTensorWrapperExist(weight_tensor_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << weight_tensor_name; + } else { + const std::vector block_sizes = {1, gsl::narrow_cast(block_size)}; + + // 2.1 Quantization Weight Data + std::vector quant_data; + Qnn_TensorType_t weight_tensor_type = qnn_model_wrapper.GetTensorType(weight_tensor_name); + const auto& weight_tensor_proto = qnn_model_wrapper.GetConstantTensor(weight_tensor_name); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*weight_tensor_proto, + quant_data, + false)); + + // 2.2 Quantization Scales + std::vector per_block_uint8_scale; + const auto& scale_tensor_proto = qnn_model_wrapper.GetConstantTensor(scales_tensor.node_arg.Name()); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*scale_tensor_proto, + per_block_uint8_scale)); + ORT_RETURN_IF_NOT(per_block_uint8_scale.size() == (num_blocks * sizeof(float)), + "Scale Initializer Invalid Size"); + float* per_block_float_scale_ptr = reinterpret_cast(per_block_uint8_scale.data()); + const std::vector per_block_float_scale(per_block_float_scale_ptr, + per_block_float_scale_ptr + num_blocks); + + // 2.3 Quantization Offsets : QNN Support only symmetric quantization with default value of 0 + std::vector per_block_int32_offset(num_blocks, 0); + + // 2.4 Transform quantized weights to signed fixed point 4. 
+ DQQToSignedFixedPoint4(quant_data, num_blocks, block_size); + + // 2.5 Create Quantization Parameter and create Weight Tensor + QnnQuantParamsWrapper quantize_param = QnnQuantParamsWrapper(per_block_float_scale, + per_block_int32_offset, + block_sizes, + QNN_DATATYPE_SFIXED_POINT_4); + + std::vector weight_shape = {static_cast(N), static_cast(K)}; + QnnTensorWrapper weight_tensor_wrapper(weight_tensor_name, + weight_tensor_type, + QNN_DATATYPE_SFIXED_POINT_4, + std::move(quantize_param), + std::move(weight_shape), + std::move(quant_data)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(weight_tensor_wrapper)), + "Failed to add tensor."); + } + input_names.push_back(weight_tensor_name); + } + + return Status::OK(); +} + +Status MatMulNBitsOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + if (do_op_validation) { + bool is_gpu_backend = IsGpuBackend(qnn_model_wrapper.GetQnnBackendType()); + ORT_RETURN_IF_NOT(is_gpu_backend, "MatMulNBits Op Supported Only for Qnn Gpu Backend"); + } + + NodeAttrHelper node_helper(node_unit); + // Extract Parameters + const int64_t N = node_helper.Get("N", static_cast(1)); + + // 1. Add Output for Reshape + const NodeUnitIODef& output_tensor = node_unit.Outputs()[0]; + const std::string& output_tensor_name = output_tensor.node_arg.Name(); + if (qnn_model_wrapper.IsQnnTensorWrapperExist(output_tensor_name)) { + LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << output_tensor_name; + } else { + QnnTensorWrapper output_tensor_wrapper; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(output_tensor, output_tensor_wrapper)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)), + "Failed to add output"); + } + + // 2. 
Add Output for Pre Reshape(FullyConnected) + const std::string pre_reshape_name = utils::GetUniqueName(output_tensor_name, "_pre_reshape"); + TensorInfo output_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(output_tensor, output_info)); + std::vector pre_reshape_shape(2); + pre_reshape_shape[0] = static_cast(std::accumulate(output_info.shape.begin(), + output_info.shape.end(), + SafeInt{1}, + std::multiplies<>()) / + N); + pre_reshape_shape[1] = gsl::narrow_cast(N); + QnnTensorWrapper output_tensor_wrapper(pre_reshape_name, + QNN_TENSOR_TYPE_NATIVE, + output_info.qnn_data_type, + output_info.quant_param.Copy(), + std::vector(pre_reshape_shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)), + "Failed to add tensor."); + + // 3. Add FullyConnected Op + const std::string fully_connected_node_name = utils::GetUniqueName(node_unit, QNN_OP_FULLY_CONNECTED); + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(fully_connected_node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_FULLY_CONNECTED, + std::move(input_names), + {pre_reshape_name}, + {}, + do_op_validation), + "Failed to add fused Matmul node."); + + // 4. 
Add Reshape Op + const bool is_graph_output = qnn_model_wrapper.IsGraphOutput(output_tensor_name); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.AddReshapeNode(pre_reshape_name, + output_tensor_name, + pre_reshape_shape, + output_info.shape, + output_info.qnn_data_type, + output_info.quant_param, + do_op_validation, + false, + is_graph_output)); + + return Status::OK(); +} + +void CreateMatMulNBitsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc new file mode 100644 index 0000000000000..02a9a5cc06f1e --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/quick_gelu_op_builder.cc @@ -0,0 +1,142 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/qnn/builder/opbuilder/base_op_builder.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_utils.h" + +namespace onnxruntime { +namespace qnn { + +class QuickGeluOpBuilder : public BaseOpBuilder { + public: + QuickGeluOpBuilder() : BaseOpBuilder("QuickGeluOpBuilder") {} + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluOpBuilder); + + protected: + Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const override ORT_MUST_USE_RESULT; +}; + +Status QuickGeluOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& node_unit, + std::vector&& input_names, + const logging::Logger& logger, + bool do_op_validation) const { + LOGS(logger, VERBOSE) << "Processing QuickGelu operator: " << node_unit.Name(); + + const std::string& input_name = input_names[0]; + const auto& outputs = node_unit.Outputs(); + const std::string& output_name = outputs[0].node_arg.Name(); + + NodeAttrHelper node_helper(node_unit); + float alpha = node_helper.Get("alpha", 1.702f); + + TensorInfo input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info)); + + // Skip alpha multiplication when alpha is 1.0 to reduce accumulated error + constexpr float alpha_epsilon = 1e-6f; + const bool skip_alpha_mul = std::abs(alpha - 1.0f) < alpha_epsilon; + + std::string sigmoid_input_name; + std::string sigmoid_output_name = utils::GetUniqueName(node_unit.Name() + "_sigmoid"); + + if (skip_alpha_mul) { + sigmoid_input_name = input_name; + } else { + const std::string alpha_mul_output_name = utils::GetUniqueName(node_unit.Name() + "_alpha_mul"); + sigmoid_input_name = alpha_mul_output_name; + + // The alpha tensor data type should match the input data type for 
element-wise multiply + std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha"); + std::vector alpha_shape{1}; + Qnn_DataType_t alpha_qnn_data_type = input_info.qnn_data_type; + std::vector alpha_data; + + if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) { + alpha_data.resize(sizeof(MLFloat16)); + MLFloat16 alpha_fp16(alpha); + memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16)); + } else { + alpha_data.resize(sizeof(float)); + memcpy(alpha_data.data(), &alpha, sizeof(float)); + } + + QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name, + QNN_TENSOR_TYPE_STATIC, + alpha_qnn_data_type, + QnnQuantParamsWrapper(), + std::move(alpha_shape), + std::move(alpha_data)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)), "Failed to add alpha tensor."); + + QnnTensorWrapper alpha_mul_output_tensor_wrapper(alpha_mul_output_name, + QNN_TENSOR_TYPE_NATIVE, + input_info.qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(input_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_mul_output_tensor_wrapper)), + "Failed to add alpha_mul_output tensor."); + + // Step 1: Create Mul node for alpha * x + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_alpha_mul"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_MULTIPLY, + {alpha_tensor_name, input_name}, + {alpha_mul_output_name}, + {}, + do_op_validation), + "Failed to create alpha_mul node."); + } + + QnnTensorWrapper sigmoid_output_tensor_wrapper(sigmoid_output_name, + QNN_TENSOR_TYPE_NATIVE, + input_info.qnn_data_type, + QnnQuantParamsWrapper(), + std::vector(input_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sigmoid_output_tensor_wrapper)), + "Failed to add sigmoid_output tensor."); + + Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? 
QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE; + QnnTensorWrapper output_tensor_wrapper(output_name, + tensor_type, + input_info.qnn_data_type, + input_info.quant_param.Copy(), + std::vector(input_info.shape)); + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)), + "Failed to add output tensor."); + + // Step 2: Create Sigmoid node for sigmoid(alpha * x) or sigmoid(x) + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_sigmoid"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_SIGMOID, + {sigmoid_input_name}, + {sigmoid_output_name}, + {}, + do_op_validation), + "Failed to create sigmoid node."); + + // Step 3: Create Mul node for x * sigmoid(alpha * x) or x * sigmoid(x) + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_final_mul"), + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_ELEMENT_WISE_MULTIPLY, + {input_name, sigmoid_output_name}, + {output_name}, + {}, + do_op_validation), + "Failed to create final_mul node."); + + return Status::OK(); +} + +void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { + op_registrations.AddOpBuilder(op_type, std::make_unique()); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 164e4c3157f62..eba0a8c2615aa 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -29,6 +29,10 @@ #include "core/providers/qnn/builder/qnn_configs_helper.h" #include "core/providers/qnn/builder/qnn_utils.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#include "core/providers/qnn/builder/qnn_windows_file_mapper.h" +#endif + // Flag to determine if Backend should do node validation for each opNode added #define DO_GRAPH_NODE_VALIDATIONS 1 @@ -441,7 +445,6 @@ void 
QnnLogging(const char* format, QnnLog_Level_t level, uint64_t timestamp, va_list argument_parameter) { - ORT_UNUSED_PARAMETER(level); ORT_UNUSED_PARAMETER(timestamp); if (!::onnxruntime::logging::LoggingManager::HasDefaultLogger()) { @@ -451,7 +454,8 @@ void QnnLogging(const char* format, } const auto& logger = ::onnxruntime::logging::LoggingManager::DefaultLogger(); - const auto severity = ::onnxruntime::logging::Severity::kVERBOSE; + // Map QNN log level to ORT severity + logging::Severity severity = QnnBackendManager::MapQNNLogLevelToOrtSeverity(level); const auto data_type = ::onnxruntime::logging::DataType::SYSTEM; if (logger.OutputIsEnabled(severity, data_type)) { @@ -525,6 +529,22 @@ QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity } } +/* static */ logging::Severity QnnBackendManager::MapQNNLogLevelToOrtSeverity(QnnLog_Level_t qnn_log_level) { + // Map QNN log level to ORT log severity + switch (qnn_log_level) { + case QNN_LOG_LEVEL_VERBOSE: + case QNN_LOG_LEVEL_DEBUG: + return logging::Severity::kVERBOSE; + case QNN_LOG_LEVEL_INFO: + return logging::Severity::kINFO; + case QNN_LOG_LEVEL_WARN: + return logging::Severity::kWARNING; + case QNN_LOG_LEVEL_ERROR: + default: + return logging::Severity::kERROR; + } +} + Status QnnBackendManager::ResetQnnLogLevel(std::optional ort_log_level) { std::lock_guard lock(logger_recursive_mutex_); if (!backend_setup_completed_ || logger_ == nullptr) { @@ -770,22 +790,148 @@ Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t return Status::OK(); } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +// Callback required for allocating file mapping resources +static Qnn_ErrorHandle_t MapDmaDataCallback(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "MapDmaDataCallback: notify_param is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + auto 
callback_info = reinterpret_cast(notify_param); + + if (callback_info->backend_manager == nullptr) { + LOGS_DEFAULT(ERROR) << "MapDmaDataCallback: QnnBackendManager is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + return callback_info->backend_manager->MapDmaData(request, response, + callback_info->mapped_file_ptr, + callback_info->file_size); +} + +Qnn_ErrorHandle_t QnnBackendManager::MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* const mapped_base_ptr, + const size_t file_size) { + if (!file_mapped_weights_enabled_) { + LOGS(*logger_, WARNING) << "Attempting to map DMA data but file mapping has been disabled, " + << "possibly due to an error in a previous request."; + return QNN_CONTEXT_ERROR_ABORTED; + } + + if (mapped_base_ptr == nullptr) { + LOGS(*logger_, ERROR) << "Attempting to map DMA data for null memory mapped base pointer"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + LOGS(*logger_, INFO) << "Mapping DMA data for request: memory mapped base pointer(" + << mapped_base_ptr << "), offset(" << request.offset + << "), size(" << request.size << "), total file size(" + << file_size << ") isBackendMappingNeeded(" + << request.isBackendMappingNeeded << ")"; + + auto size = request.size; + if (size == 0 || !request.isBackendMappingNeeded) { + LOGS(*logger_, ERROR) << "Mapping request size must be > 0 with backend mapping required"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // offset & size are type uint64_t + // Should never be an issue, but if this occurs then there is something inherently wrong with QNN + if ((UINT64_MAX - request.offset) < size) { + LOGS(*logger_, ERROR) << "Critical error in QNN: mapping request offset + size will overflow 64 bits"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // file_size will be promoted to 64 bits on 32-bit systems + if ((request.offset + size) > file_size) { + LOGS(*logger_, ERROR) << "Requested offset and size includes 
memory outside of mapped file"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + void* unaligned_data_ptr = static_cast(mapped_base_ptr) + request.offset; + rpcmem_library_->Api().register_buf(unaligned_data_ptr, size, NULL, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_library_->Api().to_fd(unaligned_data_ptr); + if (fd == -1) { + LOGS(*logger_, ERROR) << "Failed to register DMA data mapping to RPCMEM"; + return QNN_COMMON_ERROR_SYSTEM; + } + + LOGS(*logger_, INFO) << "Created DMA data mapping with address: " << unaligned_data_ptr; + + response->dmaBuffer.fd = fd; + response->dmaBuffer.data = unaligned_data_ptr; + response->dataStartOffset = 0; + response->alignedSize = size; + + return QNN_SUCCESS; +} + +// Callback required for releasing file mapping resources +static Qnn_ErrorHandle_t ReleaseDmaDataCallback(Qnn_ContextBinaryDmaDataMem_t data_mem, void* notify_param) { + if (notify_param == nullptr) { + LOGS_DEFAULT(ERROR) << "ReleaseDmaDataCallback: notify_param is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + auto callback_info = reinterpret_cast(notify_param); + + if (callback_info->backend_manager == nullptr) { + LOGS_DEFAULT(ERROR) << "ReleaseDmaDataCallback: QnnBackendManager is null"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + return callback_info->backend_manager->ReleaseDmaData(data_mem, callback_info->mapped_file_ptr); +} + +// Use LOGS_DEFAULT here as this function will be called during destruction of QnnBackendManager +// At time of destruction, usage of logger_ will not be available and will result in a seg fault +Qnn_ErrorHandle_t QnnBackendManager::ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, + void* mapped_base_ptr) { + if (mapped_base_ptr == nullptr) { + LOGS_DEFAULT(ERROR) << "Attempting to release DMA data for null memory mapped pointer"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + LOGS_DEFAULT(INFO) << "Releasing DMA data mapping for memory mapped 
pointer(" + << mapped_base_ptr << "), address(" << data_mem.dmaBuffer.data + << "), size: (" << data_mem.memSize << ")"; + + if (data_mem.dmaBuffer.data == nullptr || data_mem.memSize == 0) { + LOGS_DEFAULT(ERROR) << "Mapping release request address must not be null and size must be > 0"; + return QNN_CONTEXT_ERROR_INVALID_ARGUMENT; + } + + // Deregister file mapped data from NPU regardless of file_mapped_weights_enabled_ + // as there may be file mapped data registered to the NPU prior to any mapping error + void* unaligned_data_ptr = data_mem.dmaBuffer.data; + rpcmem_library_->Api().register_buf(unaligned_data_ptr, data_mem.memSize, -1, + rpcmem::RPCMEM_ATTR_IMPORT_BUFFER | rpcmem::RPCMEM_ATTR_READ_ONLY); + + auto fd = rpcmem_library_->Api().to_fd(unaligned_data_ptr); + if (fd != -1) { + LOGS_DEFAULT(ERROR) << "Failed to deregister buffer from RPCMEM: " << unaligned_data_ptr; + return QNN_CONTEXT_ERROR_MEM_ALLOC; + } + return QNN_SUCCESS; +} +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + // callback required to add context handles to class list // when using contextCreateFromBinaryListAsync() -void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, - Qnn_GraphHandle_t graph, - const char* graphName, - QnnContext_createFromBinaryAsyncNotifyType_t notifyType, - void* notifyParam, - Qnn_ErrorHandle_t status) { +static void ContextCreateAsyncCallback(Qnn_ContextHandle_t context, + Qnn_GraphHandle_t /* graph */, + const char* /* graph_name */, + QnnContext_createFromBinaryAsyncNotifyType_t /* notify_type */, + void* notify_param, + Qnn_ErrorHandle_t /* status */) { auto qnn_backend_manager = SharedContext::GetInstance().GetSharedQnnBackendManager(); if (context) { - qnn_backend_manager->ProcessContextFromBinListAsync(context, notifyParam); - } - - if (nullptr == graphName || graph || notifyType || status) { - // Avoid compilation unused var warning error + qnn_backend_manager->ProcessContextFromBinListAsync(context, notify_param); } } @@ -809,6 +955,41 @@ void 
QnnBackendManager::ProcessContextFromBinListAsync(Qnn_ContextHandle_t conte } } +Status QnnBackendManager::GetFileSizeIfValid(const std::string& filepath, + size_t& file_size) { + std::error_code ec; + ORT_RETURN_IF(!std::filesystem::exists(filepath, ec), "Context binary does not exist: ", filepath); + ORT_RETURN_IF(ec, "Failed to read file: ", filepath, + ", error: ", ec.message()); + + auto size = std::filesystem::file_size(filepath, ec); + ORT_RETURN_IF(ec, "Failed to retrieve size of file: ", filepath, + ", error: ", ec.message()); + + ORT_RETURN_IF(size == 0, "File is empty: ", filepath); + ORT_RETURN_IF(size > SIZE_MAX, "File (", filepath, ") file size (", size, + " bytes) exceeds maximum value of size_t for this platform (", SIZE_MAX, " bytes)."); + + file_size = static_cast(size); + return Status::OK(); +} + +Status QnnBackendManager::ReadContextBinIfValid(const std::string& context_bin_filepath, + std::vector& buffer) { + size_t buffer_size; + ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_size)); + + buffer.resize(buffer_size); + + std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); + ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to read context binary from: ", context_bin_filepath); + + const auto& read_result = cache_file.read(buffer.data(), buffer_size); + ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + + return Status::OK(); +} + Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map) { #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) QnnContext_Config_t context_config_resource_sharing = QNN_CONTEXT_CONFIG_INIT; @@ -845,10 +1026,27 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord #endif nullptr}; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (file_mapped_weights_enabled_ && file_mapper_) { + // Retry logic -- if context creation failed with file mapped 
weights, then retry with feature disabled + auto res = CreateContextFromListAsyncWithCallback(configs, context_bin_map); + if (!res.IsOK()) { + LOGS(*logger_, WARNING) << res.ErrorMessage() << ". Retrying with feature disabled."; + } else { + return Status::OK(); + } + } +#endif + return CreateContextFromListAsync(configs, context_bin_map); +} + +Status QnnBackendManager::CreateContextFromListAsync(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { std::vector context_params_list; std::vector context_paramsv1_list; std::vector context_params_ptr_list; - std::vector> buffer_list; + std::vector> buffer_list; context_params_list.reserve(context_bin_map.size()); context_params_ptr_list.reserve(context_bin_map.size() + 1); @@ -856,22 +1054,14 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord for (auto& it : context_bin_map) { auto context_bin_filepath = it.first; - std::ifstream cache_file(context_bin_filepath.c_str(), std::ifstream::binary); - ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to retrieve context binary from: ", context_bin_filepath); + std::vector buffer; + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, buffer)); - cache_file.seekg(0, cache_file.end); - size_t buffer_size = static_cast(cache_file.tellg()); - ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered."); - - cache_file.seekg(0, cache_file.beg); - std::unique_ptr buffer = std::make_unique(buffer_size); - ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file."); - const auto& read_result = cache_file.read(buffer.get(), buffer_size); - ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file."); + size_t buffer_size = buffer.size(); + buffer_list.push_back(std::move(buffer)); - cache_file.close(); QnnContext_ParamsV1_t context_params_v1 = {nullptr, - buffer.get(), + buffer_list.back().data(), buffer_size, nullptr, ContextCreateAsyncCallback, @@ -880,7 +1070,6 
@@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_1, {context_params_v1}}; - buffer_list.push_back(std::move(buffer)); context_params_list.push_back(std::move(context_params)); context_paramsv1_list.push_back(std::move(context_params_v1)); context_params_ptr_list.push_back(&context_params_list.back()); @@ -892,15 +1081,76 @@ Status QnnBackendManager::CreateContextVtcmBackupBufferSharingEnabled(std::unord configs, nullptr); - context_params_ptr_list.clear(); - context_paramsv1_list.clear(); - context_params_list.clear(); - buffer_list.clear(); - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context. Error: ", QnnErrorHandleToString(result), ", Code:", result); return Status::OK(); } +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +Status QnnBackendManager::CreateContextFromListAsyncWithCallback(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map) { + std::vector context_params_list; + std::vector context_paramsv2_list; + std::vector context_callbacks_list; + std::vector context_params_ptr_list; + + context_params_list.reserve(context_bin_map.size()); + context_paramsv2_list.reserve(context_bin_map.size()); + context_callbacks_list.reserve(context_bin_map.size()); + context_params_ptr_list.reserve(context_bin_map.size() + 1); + + for (auto& it : context_bin_map) { + auto context_bin_filepath = it.first; + + size_t buffer_size; + ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_size)); + + void* buffer; + ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &buffer)); + + auto notify_param_ptr = std::make_unique(buffer, buffer_size, this); + + Qnn_ContextBinaryCallback_t context_file_map_callbacks; + context_file_map_callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER; + context_file_map_callbacks.dmaBufferCallback.version = 
QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1; + context_file_map_callbacks.dmaBufferCallback.v1.dataProvide = MapDmaDataCallback; + context_file_map_callbacks.dmaBufferCallback.v1.dataRelease = ReleaseDmaDataCallback; + context_file_map_callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast(notify_param_ptr.get()); + + file_mapping_notify_params_.push_back(std::move(notify_param_ptr)); + context_callbacks_list.push_back(std::move(context_file_map_callbacks)); + + // Callbacks require QnnContext_ParamsV2_t which is new to QNN API 2.32 + QnnContext_ParamsV2_t context_params_v2 = {nullptr, + buffer, + buffer_size, + nullptr, + ContextCreateAsyncCallback, + it.second.get(), + &context_callbacks_list.back()}; + + QnnContext_Params_t context_params = {QnnContext_ParamsVersion_t::QNN_CONTEXT_PARAMS_VERSION_2, + {}}; + + context_paramsv2_list.push_back(std::move(context_params_v2)); + + context_params.v2 = &context_paramsv2_list.back(); + context_params_list.push_back(std::move(context_params)); + context_params_ptr_list.push_back(&(context_params_list.back())); + } + context_params_ptr_list.push_back(nullptr); + auto result = qnn_interface_.contextCreateFromBinaryListAsync(backend_handle_, + device_handle_, + context_params_ptr_list.data(), + configs, + nullptr); + + ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context with file mapping enabled. 
Error: ", + QnnErrorHandleToString(result), ", Code:", result); + return Status::OK(); +} +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Status QnnBackendManager::SetContextPriority(ContextPriority context_priority) { QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority, context_priority_config)); @@ -918,7 +1168,7 @@ Status QnnBackendManager::ResetContextPriority() { return SetContextPriority(context_priority_); } -Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { +Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing, bool enable_htp_extended_udma_mode) { if (true == context_created_) { LOGS_DEFAULT(INFO) << "Context created already."; return Status::OK(); @@ -934,8 +1184,16 @@ Status QnnBackendManager::CreateContext(bool enable_htp_weight_sharing) { QnnContext_Config_t context_priority_config = QNN_CONTEXT_CONFIG_INIT; ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, context_priority_config)); + QnnContext_Config_t context_config_extended_udma = QNN_CONTEXT_CONFIG_INIT; + QnnHtpContext_CustomConfig_t udma_custom_config; + udma_custom_config.option = QNN_HTP_CONTEXT_CONFIG_OPTION_USE_EXTENDED_UDMA; + udma_custom_config.useExtendedUdma = enable_htp_extended_udma_mode; + context_config_extended_udma.option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + context_config_extended_udma.customConfig = &udma_custom_config; + const QnnContext_Config_t* npu_context_configs[] = {&context_priority_config, &context_config_weight_sharing, + &context_config_extended_udma, nullptr}; const QnnContext_Config_t* empty_context_configs[] = {nullptr}; @@ -1098,6 +1356,7 @@ Status QnnBackendManager::GetMaxSpillFillBufferSize(unsigned char* buffer, } Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + const std::string& context_bin_filepath, std::string node_name, QnnModelLookupTable& qnn_models, int64_t max_spill_fill_size) { 
@@ -1106,6 +1365,24 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t nullptr == qnn_sys_interface_.systemContextFree; ORT_RETURN_IF(result, "Failed to get valid function pointer."); + void* bin_buffer = nullptr; +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + if (file_mapped_weights_enabled_) { + ORT_RETURN_IF(!file_mapper_, "Attempting to use File Mapping feature but file_mapper_ is uninitialized"); + + ORT_RETURN_IF_ERROR(GetFileSizeIfValid(context_bin_filepath, buffer_length)); + + ORT_RETURN_IF(buffer_length == 0, "Context bin has a size of 0 bytes: ", context_bin_filepath); + ORT_RETURN_IF_ERROR(file_mapper_->GetContextBinMappedMemoryPtr(context_bin_filepath, &bin_buffer)); + + } else { + ORT_RETURN_IF(buffer == nullptr, "Attempting to load QNN context from buffer but buffer is null"); + bin_buffer = static_cast(buffer); + } +#else + bin_buffer = static_cast(buffer); +#endif + QnnSystemContext_Handle_t sys_ctx_handle = nullptr; auto rt = qnn_sys_interface_.systemContextCreate(&sys_ctx_handle); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create system handle."); @@ -1113,7 +1390,7 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t const QnnSystemContext_BinaryInfo_t* binary_info = nullptr; Qnn_ContextBinarySize_t binary_info_size{0}; rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle, - static_cast(buffer), + bin_buffer, buffer_length, &binary_info, &binary_info_size); @@ -1188,6 +1465,26 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, "Invalid function pointer for contextCreateFromBinary."); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_ContextBinaryCallback_t callbacks; + if (file_mapped_weights_enabled_ && file_mapper_) { + ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinaryWithCallback, + "Invalid function pointer for contextCreateFromBinaryWithCallback."); + + auto
notify_param_ptr = std::make_unique(bin_buffer, buffer_length, this); + + callbacks.type = QNN_CONTEXT_CALLBACK_DMA_BUFFER; + callbacks.dmaBufferCallback.version = QNN_CONTEXT_CALLBACK_DMA_BUFFER_VERSION_1; + callbacks.dmaBufferCallback.v1.dataProvide = MapDmaDataCallback; + callbacks.dmaBufferCallback.v1.dataRelease = ReleaseDmaDataCallback; + callbacks.dmaBufferCallback.v1.notifyParam = reinterpret_cast(notify_param_ptr.get()); + + file_mapping_notify_params_.push_back(std::move(notify_param_ptr)); + } +#else + ORT_UNUSED_PARAMETER(context_bin_filepath); +#endif + qnn::profile::ProfilingInfo profiling_info; #ifdef QNN_SYSTEM_PROFILE_API_ENABLED if (ProfilingEnabled()) { @@ -1195,13 +1492,41 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t } #endif - rt = qnn_interface_.contextCreateFromBinary(backend_handle_, - device_handle_, - context_configs, - static_cast(buffer), - buffer_length, - &context, - profile_backend_handle_); +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + std::vector backup_buffer; + if (file_mapped_weights_enabled_ && file_mapper_) { + rt = qnn_interface_.contextCreateFromBinaryWithCallback(backend_handle_, + device_handle_, + context_configs, + &callbacks, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_, + NULL); + + if (rt != QNN_SUCCESS) { + LOGS(*logger_, WARNING) << "Failed to create context with file mapping enabled. Error: " + << QnnErrorHandleToString(rt) << ", Code : " << rt + << ". 
Retrying with feature disabled."; + + // Read context bin from file since file mapping has failed + ORT_RETURN_IF_ERROR(ReadContextBinIfValid(context_bin_filepath, backup_buffer)); + + bin_buffer = static_cast(backup_buffer.data()); + } + } +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + + if (!file_mapped_weights_enabled_ || rt != QNN_SUCCESS) { + rt = qnn_interface_.contextCreateFromBinary(backend_handle_, + device_handle_, + context_configs, + bin_buffer, + buffer_length, + &context, + profile_backend_handle_); + } #ifdef QNN_SYSTEM_PROFILE_API_ENABLED if (ProfilingEnabled()) { @@ -1249,7 +1574,10 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool need_load_system_lib, bool share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, - std::unordered_map>>& context_bin_map) { + bool enable_file_mapped_weights, + std::shared_ptr rpcmem_library, + std::unordered_map>>& context_bin_map, + bool enable_htp_extended_udma_mode) { std::lock_guard lock(logger_recursive_mutex_); if (backend_setup_completed_) { LOGS(logger, VERBOSE) << "Backend setup already!"; @@ -1288,6 +1616,20 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, } else { status = LoadQnnSerializerBackend(); } + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + // Backend is determined after LoadBackend() or LoadQnnSerializerBackend() + if (enable_file_mapped_weights && !file_mapper_ && GetQnnBackendType() == QnnBackendType::HTP) { + ORT_RETURN_IF(!rpcmem_library, "RPCMem Library is required for file mapping but is uninitialized."); + rpcmem_library_ = rpcmem_library; + file_mapped_weights_enabled_ = true; + file_mapper_ = std::make_unique(logger); + } +#else + ORT_UNUSED_PARAMETER(enable_file_mapped_weights); + ORT_UNUSED_PARAMETER(rpcmem_library); +#endif + if (status.IsOK()) { LOGS(logger, VERBOSE) << "LoadBackend succeed."; } @@ -1346,7 +1688,7 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, if (status.IsOK() && 
(vtcm_backup_buffer_sharing_enabled_ || !load_from_cached_context)) { status = vtcm_backup_buffer_sharing_enabled_ ? CreateContextVtcmBackupBufferSharingEnabled(context_bin_map) - : CreateContext(enable_htp_weight_sharing); + : CreateContext(enable_htp_weight_sharing, enable_htp_extended_udma_mode); if (status.IsOK()) { LOGS(logger, VERBOSE) << "CreateContext succeed."; @@ -1529,7 +1871,6 @@ void QnnBackendManager::ReleaseResources() { } backend_setup_completed_ = false; - return; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index f1c6c19bb1311..dfa40a2c8aa0d 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -25,6 +25,7 @@ #include "System/QnnSystemInterface.h" #include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/rpcmem_library.h" #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" @@ -32,6 +33,10 @@ #include "core/providers/qnn/builder/qnn_profile_serializer.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#include "core/providers/qnn/builder/qnn_file_mapping_interface.h" +#endif + namespace onnxruntime { namespace qnn { @@ -154,6 +159,7 @@ class QnnBackendManager : public std::enable_shared_from_this std::unique_ptr GetContextBinaryBuffer(uint64_t& written_buffer_size); Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + const std::string& context_bin_filepath, std::string node_name, std::unordered_map>& qnn_models, int64_t max_spill_fill_size); @@ -163,7 +169,10 @@ class QnnBackendManager : public std::enable_shared_from_this Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context, bool need_load_system_lib, bool 
share_ep_contexts, bool enable_vtcm_backup_buffer_sharing, - std::unordered_map>>& context_bin_map); + bool enable_file_mapped_weights, + std::shared_ptr rpcmem_library, + std::unordered_map>>& context_bin_map, + bool enable_htp_extended_udma_mode); Status CreateHtpPowerCfgId(uint32_t deviceId, uint32_t coreId, uint32_t& htp_power_config_id); @@ -212,6 +221,8 @@ class QnnBackendManager : public std::enable_shared_from_this void SetQnnBackendType(uint32_t backend_id); QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } + uint32_t GetSocModel() const { return soc_model_; } + const std::string& GetSdkVersion() { return sdk_build_version_; } Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id); @@ -246,6 +257,34 @@ class QnnBackendManager : public std::enable_shared_from_this bool ProfilingEnabled() { return profiling_enabled_; } #endif + bool FileMappingIsEnabled() { + return file_mapped_weights_enabled_; + } + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Qnn_ErrorHandle_t MapDmaData(Qnn_ContextBinaryDataRequest_t request, + Qnn_ContextBinaryDmaDataResponse_t* response, + void* const mapped_base_ptr, + const size_t file_size); + + Qnn_ErrorHandle_t ReleaseDmaData(Qnn_ContextBinaryDmaDataMem_t data_mem, void* mapped_base_ptr); +#endif + + QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); + static logging::Severity MapQNNLogLevelToOrtSeverity(QnnLog_Level_t qnn_log_level); + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + typedef struct FileMappingCallbackInfo { + void* const mapped_file_ptr; + const size_t file_size; + QnnBackendManager* const backend_manager; + + FileMappingCallbackInfo(void* ptr, size_t size, QnnBackendManager* manager) + : mapped_file_ptr(ptr), file_size(size), backend_manager(manager) {} + + } FileMappingCallbackInfo_t; +#endif + private: Status LoadBackend(); @@ -261,11 +300,26 @@ class QnnBackendManager : public std::enable_shared_from_this Status ReleaseProfilehandle(); - Status CreateContext(bool 
enable_htp_weight_sharing); + Status CreateContext(bool enable_htp_weight_sharing, bool enable_htp_extended_udma_mode); + + Status GetFileSizeIfValid(const std::string& filepath, size_t& file_size); + + Status ReadContextBinIfValid(const std::string& context_bin_filepath, + std::vector& buffer); Status CreateContextVtcmBackupBufferSharingEnabled(std::unordered_map>>& context_bin_map); + Status CreateContextFromListAsync(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + Status CreateContextFromListAsyncWithCallback(const QnnContext_Config_t** configs, + std::unordered_map>>& context_bin_map); +#endif + Status ReleaseContext(); // Sets the ORT logger and creates a corresponding QNN logger with the same log level. @@ -325,7 +379,6 @@ class QnnBackendManager : public std::enable_shared_from_this const char* QnnProfileErrorToString(QnnProfile_Error_t error); std::string QnnErrorHandleToString(Qnn_ErrorHandle_t error); - QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); // Adds a new QNN context. 
// Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance @@ -451,6 +504,15 @@ class QnnBackendManager : public std::enable_shared_from_this bool context_created_ = false; bool backend_setup_completed_ = false; bool vtcm_backup_buffer_sharing_enabled_ = false; + bool file_mapped_weights_enabled_ = false; + +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + std::unique_ptr file_mapper_ = nullptr; + // Notify params for file mapping must persist throughout lifetime of + // QnnBackendManager for release of DMA data callback on destruction + std::vector> file_mapping_notify_params_; +#endif + // NPU backend requires quantized model QnnBackendType qnn_backend_type_ = QnnBackendType::CPU; Qnn_ProfileHandle_t profile_backend_handle_ = nullptr; @@ -469,6 +531,8 @@ class QnnBackendManager : public std::enable_shared_from_this // Mapping of thread id to on-run-start/end power configs std::mutex per_thread_power_configs_mutex_; std::unordered_map per_thread_power_configs_; + + std::shared_ptr rpcmem_library_ = nullptr; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.cc b/onnxruntime/core/providers/qnn/builder/qnn_def.cc index 9f28e2609faa1..3d7193d70e6f5 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.cc @@ -456,7 +456,7 @@ bool CreateTensorInQnnGraph(const QNN_INTERFACE_VER_TYPE& qnn_interface, return false; } // verify size expressed by the dims matches the raw tensor size - uint32_t qnn_tensor_size = CalcQnnTensorNumElems(qnn_tensor) * gsl::narrow_cast(data_size); + const auto qnn_tensor_size = utils::GetQnnTensorDataSizeInBytes(qnn_tensor); auto qnn_tensor_buf_size = GetQnnTensorClientBuf(qnn_tensor).dataSize; if (qnn_tensor_size != qnn_tensor_buf_size) { ss << "Data length mismatch for static tensor. 
node_name: " << node_name diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 625166f62d166..847de084c49f6 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -19,6 +19,12 @@ namespace qnn { #define QNN_SYSTEM_PROFILE_API_ENABLED #endif +#if defined(_WIN32) && (defined(__aarch64__) || defined(_M_ARM64)) +#if QNN_API_VERSION_MAJOR > 2 || ((QNN_API_VERSION_MAJOR) == 2 && (QNN_API_VERSION_MINOR >= 32)) +#define QNN_FILE_MAPPED_WEIGHTS_AVAILABLE +#endif +#endif + // QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose // except the following flags for dlopen, others should be done only // when we really need them diff --git a/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h new file mode 100644 index 0000000000000..f99cc7b1ee5dd --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_file_mapping_interface.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include + +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_def.h" + +namespace onnxruntime { +namespace qnn { + +class FileMappingInterface { + public: + virtual ~FileMappingInterface() = default; + + virtual Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath, + void** mapped_data_ptr) = 0; +}; + +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc index 1e4ba6afe6f0b..6032623541384 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc @@ -222,6 +222,187 @@ Status QnnModelWrapper::ValidateQnnNode(const std::string& node_name, return Status::OK(); } +bool QnnModelWrapper::CreateBF16CastTensor(const std::string& tensor_name, + const std::vector& shape, + Qnn_TensorType_t tensor_type) { + QnnTensorWrapper bf16_tensor(tensor_name, tensor_type, QNN_DATATYPE_BFLOAT_16, + QnnQuantParamsWrapper(), std::vector(shape)); + if (!AddTensorWrapper(std::move(bf16_tensor))) { + LOGS(logger_, ERROR) << "BF16: Failed to add tensor: " << tensor_name; + return false; + } + return true; +} + +bool QnnModelWrapper::ProcessBF16InputConversion(const std::string& qnn_node_name, + const std::vector& input_names, + std::vector& converted_input_names, + std::vector& cast_ops_to_add) { + ORT_UNUSED_PARAMETER(qnn_node_name); + + for (size_t i = 0; i < input_names.size(); ++i) { + const auto& input_name = input_names[i]; + + auto it = model_tensors_map_.find(input_name); + if (it == model_tensors_map_.end()) { + LOGS(logger_, ERROR) << "BF16: Input tensor not found: " << input_name; + return false; + } + + auto& tensor_wrapper = it->second; + Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType(); + Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType(); + 
bool is_graph_input_or_init = IsGraphInput(input_name) || IsConstantInput(input_name) || IsGraphOutput(input_name); + + if (is_graph_input_or_init && tensor_dtype == QNN_DATATYPE_FLOAT_32) { + // Insert Cast node for FP32 graph inputs/initializers: FP32 -> BF16 + std::string cast_output_name = input_name + "_bf16_intermediate"; + + if (!IsQnnTensorWrapperExist(cast_output_name)) { + std::vector shape = tensor_wrapper.GetTensorDims(); + + if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) { + return false; + } + + LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << input_name << " -> " << cast_output_name; + + QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, + std::vector{input_name}, + std::vector{cast_output_name}, + std::vector{}); + cast_ops_to_add.push_back(std::move(cast_op)); + } + converted_input_names.push_back(cast_output_name); + } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) { + // Convert intermediate FP32 tensors to BF16 directly + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16); + converted_input_names.push_back(input_name); + } else if (tensor_type == QNN_TENSOR_TYPE_STATIC && !IsConstantInput(input_name) && tensor_dtype == QNN_DATATYPE_FLOAT_32) { + // Initializers that are created in QNN and are not present in ONNX + std::string cast_output_name = input_name + "_bf16_intermediate"; + if (!IsQnnTensorWrapperExist(cast_output_name)) { + std::vector shape = tensor_wrapper.GetTensorDims(); + if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) { + return false; + } + LOGS(logger_, VERBOSE) << "BF16: Adding Cast op for static tensor " << input_name << " -> " << cast_output_name; + QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, + std::vector{input_name}, + std::vector{cast_output_name}, + std::vector{}); + cast_ops_to_add.push_back(std::move(cast_op)); + } + 
converted_input_names.push_back(cast_output_name); + } else { + converted_input_names.push_back(input_name); + } + } + + return true; +} + +bool QnnModelWrapper::ProcessBF16OutputConversion(const std::string& qnn_node_name, + const std::vector& output_names, + std::vector& converted_output_names, + std::vector>& graph_output_cast_ops) { + ORT_UNUSED_PARAMETER(qnn_node_name); + + for (size_t i = 0; i < output_names.size(); ++i) { + const auto& output_name = output_names[i]; + + auto it = model_tensors_map_.find(output_name); + if (it == model_tensors_map_.end()) { + continue; + } + auto& tensor_wrapper = it->second; + Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType(); + Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType(); + + if (IsGraphOutput(output_name) && + (tensor_dtype == QNN_DATATYPE_FLOAT_32 || tensor_dtype == QNN_DATATYPE_BFLOAT_16)) { + // For FP32 graph outputs, insert Cast node to convert BF16 back to FP32 + std::string bf16_output_name = utils::GetUniqueName(output_name, "_bf16_intermediate"); + + if (!IsQnnTensorWrapperExist(bf16_output_name)) { + std::vector shape = tensor_wrapper.GetTensorDims(); + + if (!CreateBF16CastTensor(bf16_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) { + return false; + } + LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << bf16_output_name << " -> " << output_name; + graph_output_cast_ops.push_back({bf16_output_name, output_name}); + } + converted_output_names.push_back(bf16_output_name); + } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) { + // Convert intermediate FP32 tensors to BF16 directly + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16); + converted_output_names.push_back(output_name); + } else { + converted_output_names.push_back(output_name); + } + } + + return true; +} + +bool QnnModelWrapper::ApplyBF16ConversionForValidation(const std::vector& input_names, + const std::vector& output_names, + std::vector& 
validation_input_names, + std::vector& validation_output_names) { + // Temporarily convert FP32 tensors to BF16 for validation + for (const auto& input_name : input_names) { + auto it = model_tensors_map_.find(input_name); + if (it == model_tensors_map_.end()) { + LOGS(logger_, ERROR) << "BF16: Validation failed - input tensor not found: " << input_name; + return false; + } + + auto& tensor_wrapper = it->second; + if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) { + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16); + } + validation_input_names.push_back(input_name); + } + + for (const auto& output_name : output_names) { + auto it = model_tensors_map_.find(output_name); + if (it != model_tensors_map_.end()) { + auto& tensor_wrapper = it->second; + if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) { + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16); + } + } + validation_output_names.push_back(output_name); + } + + return true; +} + +void QnnModelWrapper::RestoreFP32AfterValidation(const std::vector& input_names, + const std::vector& output_names) { + // Restore FP32 data types after validation + for (const auto& input_name : input_names) { + auto it = model_tensors_map_.find(input_name); + if (it != model_tensors_map_.end()) { + auto& tensor_wrapper = it->second; + if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) { + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32); + } + } + } + + for (const auto& output_name : output_names) { + auto it = model_tensors_map_.find(output_name); + if (it != model_tensors_map_.end()) { + auto& tensor_wrapper = it->second; + if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) { + SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32); + } + } + } +} + bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, const std::string& package_name, const std::string& 
qnn_node_type, @@ -233,15 +414,31 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::vector input_tensors; std::vector output_tensors; std::vector params; - if (!CreateQnnInputOutputTensors(qnn_node_name, input_names, input_tensors, do_op_validation)) { - return false; - } - if (!CreateQnnInputOutputTensors(qnn_node_name, output_names, output_tensors, do_op_validation)) { - return false; + // Apply BF16 conversion for validation if enabled + std::vector validation_input_names; + std::vector validation_output_names; + + // Use RAII guard for BF16 conversion to ensure cleanup + std::unique_ptr bf16_guard; + + if (IsBF16ConversionEnabled()) { + LOGS(logger_, VERBOSE) << "[BF16] Validation with BF16 conversion enabled"; + if (!ApplyBF16ConversionForValidation(input_names, output_names, validation_input_names, validation_output_names)) { + LOGS(logger_, ERROR) << "[BF16] ApplyBF16ConversionForValidation failed for node: " << qnn_node_name; + return false; + } + // Create the guard after successful conversion + bf16_guard = std::make_unique(this, input_names, output_names); + } else { + validation_input_names = input_names; + validation_output_names = output_names; } - if (!CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) { + // Create tensors for validation + if (!CreateQnnInputOutputTensors(qnn_node_name, validation_input_names, input_tensors, do_op_validation) || + !CreateQnnInputOutputTensors(qnn_node_name, validation_output_names, output_tensors, do_op_validation) || + !CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) { return false; } @@ -257,6 +454,7 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, std::string error_msg; bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg); + if (!rt) { // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more // specific 
validation error (instead of "failed to add node"). @@ -264,6 +462,7 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, } return rt; } else { + // Standard execution - just add the node to the op list QnnOpProperty qnn_op(qnn_node_name, package_name, qnn_node_type, std::move(input_names), std::move(output_names), std::move(param_tensor_names)); qnn_op_property_list_.push_back(std::move(qnn_op)); @@ -271,6 +470,70 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name, } } +bool QnnModelWrapper::ProcessBF16Conversions(std::vector& final_ops) { + std::vector processed_ops; + std::vector input_cast_ops; + + for (const auto& op_property : qnn_op_property_list_) { + // Make copies of the strings to avoid reference invalidation + std::string qnn_node_name = op_property.GetNodeName(); + std::string package_name = op_property.GetPackageName(); + std::string qnn_node_type = op_property.GetNodeType(); + std::vector input_names = op_property.GetInputNames(); + std::vector output_names = op_property.GetOutputNames(); + std::vector param_tensor_names = op_property.GetParamTensorNames(); + + LOGS(logger_, VERBOSE) << "[BF16] Processing node for BF16 conversion: " << qnn_node_name; + + std::vector converted_input_names; + std::vector converted_output_names; + std::vector> graph_output_cast_ops; + + if (!ProcessBF16InputConversion(qnn_node_name, input_names, converted_input_names, input_cast_ops)) { + LOGS(logger_, ERROR) << "[BF16] ProcessBF16InputConversion failed for node: " << qnn_node_name; + return false; + } + + if (!ProcessBF16OutputConversion(qnn_node_name, output_names, converted_output_names, graph_output_cast_ops)) { + LOGS(logger_, ERROR) << "[BF16] ProcessBF16OutputConversion failed for node: " << qnn_node_name; + return false; + } + + // Add the main node with BF16-converted tensor names + LOGS(logger_, VERBOSE) << "[BF16] Adding main node with converted tensors: " << qnn_node_name; + 
processed_ops.emplace_back(std::move(qnn_node_name), std::move(package_name), std::move(qnn_node_type), + std::move(converted_input_names), std::move(converted_output_names), + std::move(param_tensor_names)); + + // Add Cast operations for graph outputs to convert BF16 back to FP32 + LOGS(logger_, VERBOSE) << "[BF16] Adding " << graph_output_cast_ops.size() << " output cast operations"; + for (size_t i = 0; i < graph_output_cast_ops.size(); ++i) { + const auto& [bf16_name, fp32_name] = graph_output_cast_ops[i]; + std::string cast_node_name = bf16_name; + LOGS(logger_, VERBOSE) << "[BF16] Adding output Cast op[" << i << "]: " << cast_node_name + << " (" << bf16_name << " -> " << fp32_name << ")"; + + processed_ops.emplace_back(std::move(cast_node_name), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST, + std::vector{bf16_name}, + std::vector{fp32_name}, + std::vector{}); + } + } + + // Prepend input cast ops to the beginning of processed_ops + final_ops.reserve(input_cast_ops.size() + processed_ops.size()); + + for (auto& cast_op : input_cast_ops) { + final_ops.push_back(std::move(cast_op)); + } + + for (auto& op : processed_ops) { + final_ops.push_back(std::move(op)); + } + + return true; +} + bool QnnModelWrapper::ComposeQnnGraph(bool build_json_qnn_graph) { LOGS(logger_, VERBOSE) << "Compose Qnn Graph."; // ORT_RETURN_IF(qnn_op_property_list_.empty(), "Empty Qnn op list, no graph to compose."); @@ -278,7 +541,19 @@ bool QnnModelWrapper::ComposeQnnGraph(bool build_json_qnn_graph) { return false; } - for (const auto& op_property : qnn_op_property_list_) { + // Determine which ops to process + const std::vector* ops_to_process = &qnn_op_property_list_; + std::vector bf16_processed_ops; + + if (IsBF16ConversionEnabled()) { + if (!ProcessBF16Conversions(bf16_processed_ops)) { + return false; + } + ops_to_process = &bf16_processed_ops; + } + + // Create QNN graph ops from the op properties + for (const auto& op_property : *ops_to_process) { std::vector input_tensors; 
std::vector output_tensors; std::vector params; @@ -606,7 +881,8 @@ void QnnModelWrapper::GetGraphInputOutputTensorWrapper(const std::vector& unpacked_tensor) const { + std::vector& unpacked_tensor, + const bool unpack_4_bit_to_8_bit) const { if (initializer.data_location() == onnx::TensorProto_DataLocation_EXTERNAL) { ORT_RETURN_IF_ERROR(onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer_.ModelPath(), unpacked_tensor)); @@ -616,12 +892,13 @@ Status QnnModelWrapper::UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& int32_t onnx_data_type = initializer.data_type(); - // If this is an int4, we need to unpack it because QNN treats int4 as a full int8. - if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { + // If this is an int4, + // If unpack_4_bit_to_8_bit is true, we need to unpack it because QNN HTP treats int4 as a full int8. + if (unpack_4_bit_to_8_bit && onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_INT4) { TensorShape shape(qnn::utils::GetInitializerShape(initializer)); const size_t num_int4_elems = shape.Size(); ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8(num_int4_elems, unpacked_tensor)); - } else if (onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { + } else if (unpack_4_bit_to_8_bit && onnx_data_type == ONNX_NAMESPACE::TensorProto_DataType_UINT4) { TensorShape shape(qnn::utils::GetInitializerShape(initializer)); const size_t num_uint4_elems = shape.Size(); ORT_RETURN_IF_ERROR(qnn::utils::UnpackInt4ToInt8(num_uint4_elems, unpacked_tensor)); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h index f0d145c2938c8..c5aaf32dfb274 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.h @@ -19,6 +19,10 @@ namespace onnxruntime { namespace qnn { +// Forward declarations +class QnnModelWrapper; +class BF16ConversionGuard; + // Stores 
information about an ONNX input or output tensor. // Filled out by QnnModelWrapper::GetTensorInfo() struct TensorInfo { @@ -32,9 +36,13 @@ struct TensorInfo { struct ModelSettings { bool offload_graph_io_quantization = false; bool htp_shared_memory = false; + bool htp_bf16_enable = false; }; class QnnModelWrapper { + // Allow BF16ConversionGuard to access private RestoreFP32AfterValidation method + friend class BF16ConversionGuard; + public: QnnModelWrapper(const GraphViewer& graph_viewer, const logging::Logger& logger, @@ -237,7 +245,8 @@ class QnnModelWrapper { } Status UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer, - std::vector& unpacked_tensor) const; + std::vector& unpacked_tensor, + const bool unpack_4_bit_to_8_bit = true) const; QnnBackendType GetQnnBackendType() const { return qnn_backend_type_; } @@ -323,6 +332,36 @@ class QnnModelWrapper { void GetGraphInputOutputTensorWrapper(const std::vector& names, std::vector& wrappers_list); + // BF16 conversion helper methods + bool IsBF16ConversionEnabled() const { + return model_settings_.htp_bf16_enable && + (qnn_backend_type_ == QnnBackendType::HTP || qnn_backend_type_ == QnnBackendType::SERIALIZER); + } + + bool ProcessBF16InputConversion(const std::string& qnn_node_name, + const std::vector& input_names, + std::vector& converted_input_names, + std::vector& cast_ops_to_add); + + bool ProcessBF16OutputConversion(const std::string& qnn_node_name, + const std::vector& output_names, + std::vector& converted_output_names, + std::vector>& graph_output_cast_ops); + + bool ApplyBF16ConversionForValidation(const std::vector& input_names, + const std::vector& output_names, + std::vector& validation_input_names, + std::vector& validation_output_names); + + void RestoreFP32AfterValidation(const std::vector& input_names, + const std::vector& output_names); + + bool CreateBF16CastTensor(const std::string& tensor_name, + const std::vector& shape, + Qnn_TensorType_t tensor_type); + + bool 
ProcessBF16Conversions(std::vector& final_ops); + const GraphViewer& graph_viewer_; const logging::Logger& logger_; const QNN_INTERFACE_VER_TYPE& qnn_interface_; @@ -398,5 +437,40 @@ inline Status AddQnnScalar(QnnModelWrapper& qnn_model_wrapper, return Status::OK(); } +// RAII guard to ensure FP32 restoration after BF16 conversion for validation +class BF16ConversionGuard { + public: + BF16ConversionGuard(QnnModelWrapper* wrapper, + const std::vector& input_names, + const std::vector& output_names) + : wrapper_(wrapper), + input_names_(input_names), + output_names_(output_names) {} + + ~BF16ConversionGuard() { + if (wrapper_) { + try { + wrapper_->RestoreFP32AfterValidation(input_names_, output_names_); + } catch (...) { + // Destructors must not throw exceptions + // Silently catch any exceptions during cleanup + } + } + } + + // Prevent copying + BF16ConversionGuard(const BF16ConversionGuard&) = delete; + BF16ConversionGuard& operator=(const BF16ConversionGuard&) = delete; + + // Prevent moving to avoid double-cleanup issues + BF16ConversionGuard(BF16ConversionGuard&&) = delete; + BF16ConversionGuard& operator=(BF16ConversionGuard&&) = delete; + + private: + QnnModelWrapper* wrapper_; + std::vector input_names_; // Store by value, not reference + std::vector output_names_; // Store by value, not reference +}; + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc index 5395e69531336..17f0ff0f2b8dd 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.cc @@ -20,8 +20,11 @@ QnnQuantParamsWrapper::QnnQuantParamsWrapper(const QnnQuantParamsWrapper& other) size_t num_scaleoffsets = 0; if (other.IsLPBQ()) { num_scaleoffsets = other.per_channel_scales_size_; + } else if (other.IsBlockQuantized()) { + block_encoding_tensor_rank_ = 
other.block_encoding_tensor_rank_; + num_scaleoffsets = other.num_blocks_; } - Status status = Init(other.params_, num_scaleoffsets); + Status status = Init(other.params_, num_scaleoffsets, block_encoding_tensor_rank_); assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding. } @@ -30,8 +33,11 @@ QnnQuantParamsWrapper& QnnQuantParamsWrapper::operator=(const QnnQuantParamsWrap size_t num_scaleoffsets = 0; if (other.IsLPBQ()) { num_scaleoffsets = other.per_channel_scales_size_; + } else if (other.IsBlockQuantized()) { + block_encoding_tensor_rank_ = other.block_encoding_tensor_rank_; + num_scaleoffsets = other.num_blocks_; } - Status status = Init(other.params_, num_scaleoffsets); + Status status = Init(other.params_, num_scaleoffsets, block_encoding_tensor_rank_); assert(status.IsOK()); // Expect other QnnQuantParamsWrapper to always have a supported quantization encoding. } @@ -156,6 +162,39 @@ QnnQuantParamsWrapper::QnnQuantParamsWrapper(gsl::span per_channel_ params_.blockwiseExpansion = lpbqPtr; } +// Construct a BlockEncoding BQ quantization param. +QnnQuantParamsWrapper::QnnQuantParamsWrapper( + gsl::span scales, + gsl::span offsets, + gsl::span block_sizes, + Qnn_DataType_t tensor_data_type) { + ORT_UNUSED_PARAMETER(tensor_data_type); + assert(block_sizes.size() > 0); + assert(scales.size() > 0); + assert(scales.size() == offsets.size()); // Logic error if sizes don't match. 
+ + num_blocks_ = static_cast(scales.size()); + params_.encodingDefinition = QNN_DEFINITION_DEFINED; + params_.quantizationEncoding = QNN_QUANTIZATION_ENCODING_BLOCK; + + block_encoding_tensor_rank_ = static_cast(block_sizes.size()); + block_encoding_axis_data_ = std::make_unique(block_encoding_tensor_rank_); + std::memcpy(block_encoding_axis_data_.get(), + block_sizes.data(), + static_cast(block_encoding_tensor_rank_) * sizeof(uint32_t)); + params_.blockEncoding.blockSize = block_encoding_axis_data_.get(); + + // Deep copy the scale offsets + if (num_blocks_ > 0) { + block_encoding_scale_offsets_data_ = std::make_unique(num_blocks_); + for (size_t i = 0; i < num_blocks_; ++i) { + block_encoding_scale_offsets_data_[i].offset = offsets[i]; + block_encoding_scale_offsets_data_[i].scale = scales[i]; + } + params_.blockEncoding.scaleOffset = block_encoding_scale_offsets_data_.get(); + } +} + // Get a copy of scales. Works for both per-tensor and per-channel. Status QnnQuantParamsWrapper::GetScales(/*out*/ std::vector& scales) const { ORT_RETURN_IF_NOT(params_.encodingDefinition == QNN_DEFINITION_DEFINED, "Unquantized qparams does not have scales"); @@ -195,6 +234,18 @@ Status QnnQuantParamsWrapper::GetScales(/*out*/ std::vector& scales) cons } break; } + case QNN_QUANTIZATION_ENCODING_BLOCK: { + scales.resize(num_blocks_); + + if (num_blocks_ > 0) { + gsl::span scale_offsets(params_.blockEncoding.scaleOffset, num_blocks_); + + for (size_t i = 0; i < num_blocks_; i++) { + scales[i] = scale_offsets[i].scale; + } + } + break; + } default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params_.quantizationEncoding); @@ -208,7 +259,7 @@ QnnQuantParamsWrapper QnnQuantParamsWrapper::Copy() const { } // Initializes by copying from a Qnn_QuantizeParams_t. 
-Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const size_t lpbq_num_scaleoffsets) { +Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const size_t num_scaleoffsets, const size_t tensor_rank) { if (per_channel_data_) { per_channel_data_.reset(nullptr); params_ = QNN_QUANTIZE_PARAMS_INIT; @@ -278,7 +329,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz break; } case QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION: { - assert(lpbq_num_scaleoffsets && "Can't create BlockwiseExpansion encoding object with zero ScaleOffsets"); + assert(num_scaleoffsets && "Can't create BlockwiseExpansion encoding object with zero ScaleOffsets"); params_.encodingDefinition = params.encodingDefinition; params_.quantizationEncoding = params.quantizationEncoding; @@ -291,7 +342,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz params_.blockwiseExpansion = bwe_aligned_dst; // Deep copy the scaleoffsets - const size_t so_num_elems = lpbq_num_scaleoffsets; + const size_t so_num_elems = num_scaleoffsets; const size_t so_num_bytes = so_num_elems * sizeof(Qnn_ScaleOffset_t); constexpr std::uintptr_t so_align = alignof(Qnn_ScaleOffset_t); per_channel_data_ = std::make_unique(so_num_bytes + so_align); @@ -301,7 +352,7 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz params_.blockwiseExpansion->scaleOffsets = so_aligned_dst; // Deep copy blockscales - const size_t bs_num_elems = lpbq_num_scaleoffsets * params.blockwiseExpansion->numBlocksPerAxis; + const size_t bs_num_elems = num_scaleoffsets * params.blockwiseExpansion->numBlocksPerAxis; const size_t bs_num_bytes = bs_num_elems * sizeof(uint8_t); constexpr std::uintptr_t bs_align = alignof(uint8_t); block_scales_data_ = std::make_unique(bs_num_bytes + bs_align); @@ -310,6 +361,28 @@ Status QnnQuantParamsWrapper::Init(const Qnn_QuantizeParams_t& params, const siz 
params_.blockwiseExpansion->blocksScale8 = bs_aligned_dst; break; } + case QNN_QUANTIZATION_ENCODING_BLOCK: { + assert(num_scaleoffsets && "Can't create Block encoding object with zero ScaleOffsets"); + params_.encodingDefinition = params.encodingDefinition; + params_.quantizationEncoding = params.quantizationEncoding; + + block_encoding_tensor_rank_ = static_cast(tensor_rank); + block_encoding_axis_data_ = std::make_unique(block_encoding_tensor_rank_); + std::memcpy(block_encoding_axis_data_.get(), + params.blockEncoding.blockSize, + static_cast(block_encoding_tensor_rank_) * sizeof(uint32_t)); + params_.blockEncoding.blockSize = block_encoding_axis_data_.get(); + + // Deep copy the scale offsets + block_encoding_scale_offsets_data_ = std::make_unique(num_scaleoffsets); + for (size_t i = 0; i < num_scaleoffsets; ++i) { + block_encoding_scale_offsets_data_[i].scale = params.blockEncoding.scaleOffset[i].scale; + block_encoding_scale_offsets_data_[i].offset = params.blockEncoding.scaleOffset[i].offset; + } + params_.blockEncoding.scaleOffset = block_encoding_scale_offsets_data_.get(); + + break; + } default: return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported QNN quantization encoding: ", params.quantizationEncoding); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h index a74733037a9d0..a70c329e56c14 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_quant_params_wrapper.h @@ -34,11 +34,16 @@ class QnnQuantParamsWrapper { QnnQuantParamsWrapper(gsl::span per_channel_float_scales, gsl::span per_block_int_scales, gsl::span offsets, int64_t axis, int64_t block_size, bool is_int4); + // Construct a BQ quantization param. 
+ QnnQuantParamsWrapper( + gsl::span scales, gsl::span offsets, + gsl::span block_size, Qnn_DataType_t tensor_data_type); + Qnn_QuantizeParams_t& Get() { return params_; } const Qnn_QuantizeParams_t& Get() const { return params_; } // Initialize this object from a raw Qnn_QuantizeParam_t object. - Status Init(const Qnn_QuantizeParams_t& params, const size_t lpbq_num_scaleoffsets = 0); + Status Init(const Qnn_QuantizeParams_t& params, const size_t num_scaleoffsets = 0, const size_t tensor_rank = 0); // Initialize this object from a (potentially) quantized ONNX tensor. // QnnModelWrapper provides utilities for unpacking scale and zero-point ONNX initializers. @@ -67,6 +72,11 @@ class QnnQuantParamsWrapper { (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCKWISE_EXPANSION); } + bool IsBlockQuantized() const { + return params_.encodingDefinition == QNN_DEFINITION_DEFINED && + (params_.quantizationEncoding == QNN_QUANTIZATION_ENCODING_BLOCK); + } + // Get a copy of scales. Works for both per-tensor and per-channel. 
Status GetScales(/*out*/ std::vector& scales) const; @@ -163,6 +173,12 @@ class QnnQuantParamsWrapper { uint32_t per_channel_scales_size_; std::unique_ptr block_scales_data_; std::unique_ptr blockwise_expansion_data_; + + // Stores BlockEncoding axis and scale offset data + uint32_t block_encoding_tensor_rank_ = 0; + uint32_t num_blocks_ = 0; + std::unique_ptr block_encoding_axis_data_; + std::unique_ptr block_encoding_scale_offsets_data_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index e4f9d490678b8..4d7189d672af7 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -34,10 +34,13 @@ size_t GetElementSizeByType(const Qnn_DataType_t& data_type) { {QNN_DATATYPE_UINT_64, 8}, {QNN_DATATYPE_FLOAT_16, 2}, {QNN_DATATYPE_FLOAT_32, 4}, + {QNN_DATATYPE_BFLOAT_16, 2}, {QNN_DATATYPE_BOOL_8, 1}, + {QNN_DATATYPE_SFIXED_POINT_4, sizeof(Int4x2)}, {QNN_DATATYPE_SFIXED_POINT_8, 1}, {QNN_DATATYPE_SFIXED_POINT_16, 2}, {QNN_DATATYPE_SFIXED_POINT_32, 4}, + {QNN_DATATYPE_UFIXED_POINT_4, sizeof(Int4x2)}, {QNN_DATATYPE_UFIXED_POINT_8, 1}, {QNN_DATATYPE_UFIXED_POINT_16, 2}, {QNN_DATATYPE_UFIXED_POINT_32, 4}, @@ -104,11 +107,25 @@ size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type) { } // Unreachable } +size_t GetQnnTensorDataSizeInBytes(size_t num_elements, Qnn_DataType_t element_type) { + SafeInt safe_num_elements = num_elements; + if (element_type == QNN_DATATYPE_SFIXED_POINT_4 || element_type == QNN_DATATYPE_UFIXED_POINT_4) { + return (safe_num_elements + 1) / 2; + } + return (safe_num_elements * GetElementSizeByType(element_type)); +} size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType_t element_type) { ORT_ENFORCE(!shape.empty(), "Empty shape not allowed."); // TODO can we just treat empty shape as a scalar? 
- SafeInt data_length = GetElementSizeByType(element_type); - return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{}); + SafeInt num_elements = std::accumulate(shape.begin(), shape.end(), SafeInt{1}, std::multiplies<>{}); + return GetQnnTensorDataSizeInBytes(num_elements, element_type); +} + +size_t GetQnnTensorDataSizeInBytes(const Qnn_Tensor_t& tensor) { + uint32_t rank = GetQnnTensorRank(tensor); + uint32_t* dims = GetQnnTensorDims(tensor); + gsl::span shape{dims, static_cast(rank)}; + return GetQnnTensorDataSizeInBytes(shape, GetQnnTensorDataType(tensor)); } bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor) { @@ -202,6 +219,9 @@ std::ostream& operator<<(std::ostream& out, const Qnn_DataType_t& data_type) { case QNN_DATATYPE_FLOAT_32: out << "QNN_DATATYPE_FLOAT_32"; break; + case QNN_DATATYPE_BFLOAT_16: + out << "QNN_DATATYPE_BFLOAT_16"; + break; case QNN_DATATYPE_SFIXED_POINT_8: out << "QNN_DATATYPE_SFIXED_POINT_8"; break; @@ -995,7 +1015,7 @@ Status QuantizeData(gsl::span data, gsl::span shape const size_t num_dims = shape.size(); const size_t num_elems = ShapeSizeCalc(shape, 0, num_dims); ORT_RETURN_IF_NOT(num_elems == data.size(), "Shape mismatch with data to quantize"); - size_t expected_num_quant_bytes = GetElementSizeByType(data_type) * data.size(); + size_t expected_num_quant_bytes = GetQnnTensorDataSizeInBytes(data.size(), data_type); ORT_RETURN_IF_NOT(quant_bytes.size() == expected_num_quant_bytes, "Cannot quantize data because output buffer is not the correct size"); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index 3d46aa5e3c9ae..6e188d5d41260 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -78,7 +78,9 @@ class QnnJSONGraph { size_t GetElementSizeByType(ONNX_NAMESPACE::TensorProto_DataType onnx_type); +size_t GetQnnTensorDataSizeInBytes(size_t num_elements, 
Qnn_DataType_t element_data_type); size_t GetQnnTensorDataSizeInBytes(gsl::span shape, Qnn_DataType_t element_data_type); +size_t GetQnnTensorDataSizeInBytes(const Qnn_Tensor_t& tensor); bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor); diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc new file mode 100644 index 0000000000000..71f562d59d847 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.cc @@ -0,0 +1,113 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_windows_file_mapper.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + +#include + +#include + +#include + +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +WindowsFileMapper::WindowsFileMapper(const logging::Logger& logger) + : logger_(&logger) { +} + +WindowsFileMapper::~WindowsFileMapper() { +} + +static void UnmapFile(void* addr) noexcept { + bool successful = UnmapViewOfFile(addr); + if (!successful) { + const auto error_code = GetLastError(); + LOGS_DEFAULT(ERROR) << "Failed to unmap view of file with ptr: " << addr + << ", Error code: " << error_code << ", \"" + << std::system_category().message(error_code) << "\""; + } +} + +Status WindowsFileMapper::GetContextBinMappedMemoryPtr(const std::string& bin_filepath, + void** mapped_data_ptr) { + LOGS(*logger_, INFO) << "Creating context bin file mapping for " + << bin_filepath; + + ORT_RETURN_IF(bin_filepath.empty(), "Context bin file path is empty"); + + std::lock_guard lock(map_mutex_); + auto map_it = mapped_memory_ptrs_.find(bin_filepath); + if (map_it != mapped_memory_ptrs_.end()) { + *mapped_data_ptr = map_it->second.get(); + LOGS(*logger_, INFO) << "Found existing mapview memory pointer (" << mapped_data_ptr + << ") for context bin file: " << bin_filepath; + return Status::OK(); + } + + 
std::wstring bin_filepath_wstr(bin_filepath.begin(), bin_filepath.end()); + wil::unique_hfile file_handle{CreateFile2(bin_filepath_wstr.c_str(), + GENERIC_READ, + FILE_SHARE_READ, + OPEN_EXISTING, + NULL)}; + if (file_handle.get() == INVALID_HANDLE_VALUE) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to create file handle for context bin", bin_filepath, + ". Error code: ", error_code, ", \"", + std::system_category().message(error_code), "\""); + } + + LOGS(*logger_, VERBOSE) << "Created file handle (" << file_handle.get() << ") for context bin: " + << bin_filepath; + + wil::unique_hfile file_mapping_handle{CreateFileMappingW(file_handle.get(), + nullptr, + PAGE_READONLY, + 0x00, + 0x00, + nullptr)}; + if (file_mapping_handle.get() == INVALID_HANDLE_VALUE) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to create file mapping handle for context bin", + bin_filepath, ". Error code: ", error_code, ", \"", + std::system_category().message(error_code), "\""); + } + + LOGS(*logger_, VERBOSE) << "Created file mapping with handle (" << file_mapping_handle.get() + << ") for context bin:" << bin_filepath; + + void* const mapped_base_ptr = MapViewOfFile(file_mapping_handle.get(), + FILE_MAP_READ, + 0, 0, 0); + + if (mapped_base_ptr == nullptr) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, + "Failed to retrieve mapview pointer for context bin", + bin_filepath, ". 
Error code: ", error_code, ", \"", + std::system_category().message(error_code), "\""); + } + + LOGS(*logger_, INFO) << "Created mapview pointer with address " << mapped_base_ptr + << " for context bin " << bin_filepath; + + onnxruntime::Env::MappedMemoryPtr mapped_memory_ptr{reinterpret_cast(mapped_base_ptr), + [mapped_base_ptr](void*) { + UnmapFile(mapped_base_ptr); + }}; + + *mapped_data_ptr = mapped_memory_ptr.get(); + mapped_memory_ptrs_.emplace(bin_filepath, std::move(mapped_memory_ptr)); + + return Status::OK(); +} +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE diff --git a/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h new file mode 100644 index 0000000000000..742255b26f07d --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_windows_file_mapper.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/qnn/builder/qnn_file_mapping_interface.h" +#ifdef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + +#include +#include +#include +#include + +#include + +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class WindowsFileMapper : public FileMappingInterface { + public: + explicit WindowsFileMapper(const logging::Logger& logger); + ~WindowsFileMapper() override; + + // Creates a file mapping of the context binary and returns the + // mapview pointer of the file mapping + Status GetContextBinMappedMemoryPtr(const std::string& bin_filepath, + void** mapped_data_ptr) override; + + private: + // A container of smart pointers of mapview memory pointers to mapped context bins + // key: filepath to context bin, value: smart pointer of mapview memory pointers + std::mutex map_mutex_; + std::unordered_map mapped_memory_ptrs_; + const logging::Logger* logger_; +}; + +} // namespace qnn +} // namespace onnxruntime + +#endif // QNN_FILE_MAPPED_WEIGHTS_AVAILABLE diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 737216b81139c..c3d8328b37411 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -475,6 +475,21 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio #endif } + static const std::string DISABLE_FILE_MAPPED_WEIGHTS = "disable_file_mapped_weights"; + auto disable_file_mapped_weights_pos = provider_options_map.find(DISABLE_FILE_MAPPED_WEIGHTS); + if (disable_file_mapped_weights_pos != provider_options_map.end()) { + if ("1" == disable_file_mapped_weights_pos->second) { + enable_file_mapped_weights_ = false; + } + LOGS_DEFAULT(VERBOSE) << "User specified disable_file_mapped_weights: " << enable_file_mapped_weights_; + } + +#ifndef QNN_FILE_MAPPED_WEIGHTS_AVAILABLE + enable_file_mapped_weights_ = false; + 
LOGS_DEFAULT(WARNING) << "File mapped weights feature is only available on Windows arm64 devices for QNN API versions >= 2.32. " + << "Feature will be disabled by default"; +#endif + static const std::string QNN_DEVICE_ID = "device_id"; auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); if (dev_id_pos != provider_options_map.end()) { @@ -541,6 +556,8 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio model_settings_.offload_graph_io_quantization = ParseBoolOption("offload_graph_io_quantization", true, provider_options_map); + model_settings_.htp_bf16_enable = ParseBoolOption("htp_bf16_enable", false, provider_options_map); + if (disable_cpu_ep_fallback_ && model_settings_.offload_graph_io_quantization) { LOGS_DEFAULT(INFO) << "Fallback to CPU EP is disabled, but user tried to configure QNN EP to offload graph I/O " << "quantization/dequantization to another EP. These are conflicting options. Fallback to CPU " @@ -550,11 +567,24 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } static const std::string QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED = "enable_htp_shared_memory_allocator"; - if (ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map)) { + enable_htp_shared_mem_allocator_ = ParseBoolOption(QNN_HTP_SHARED_MEMORY_ALLOCATOR_ENABLED, false, provider_options_map); + if (enable_htp_shared_mem_allocator_) { // Initialize rpcmem_library_. // This is necessary for HtpSharedMemoryAllocator to function and also indicates that the allocator is available. rpcmem_library_ = std::make_shared(); - model_settings_.htp_shared_memory = true; + model_settings_.htp_shared_memory = enable_htp_shared_mem_allocator_; + } + + if (enable_file_mapped_weights_ && !rpcmem_library_) { + // Attempt to init rpcmem_library_ if needed. 
If this fails, then + // disable file mapped weights and proceed with normal operation + try { + rpcmem_library_ = std::make_shared(); + } catch (const std::exception& e) { + LOGS_DEFAULT(WARNING) << "Unable to load RPCMem library: " << e.what() + << " - Disabling file mapped weights."; + enable_file_mapped_weights_ = false; + } } dump_json_qnn_graph_ = ParseBoolOption("dump_json_qnn_graph", false, provider_options_map); @@ -572,6 +602,19 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string QNN_HTP_EXTENDED_UDMA_MODE = "extended_udma"; + auto htp_extended_udma_pos = provider_options_map.find(QNN_HTP_EXTENDED_UDMA_MODE); + if (htp_extended_udma_pos != provider_options_map.end()) { + if ("1" == htp_extended_udma_pos->second) { + enable_htp_extended_udma_mode_ = true; + } else if ("0" == htp_extended_udma_pos->second) { + enable_htp_extended_udma_mode_ = false; + } else { + LOGS_DEFAULT(WARNING) << "Invalid extended_udma mode: " << enable_htp_extended_udma_mode_ << " only 0 or 1 allowed. Set to 0."; + } + LOGS_DEFAULT(VERBOSE) << "User specified extended_udma mode: " << enable_htp_extended_udma_mode_; + } + // Option to skip QNN API interface version check to use other QNN library other than default. static const std::string SKIP_QNN_VERSION_CHECK = "skip_qnn_version_check"; auto skip_qnn_version_check = ParseBoolOption(SKIP_QNN_VERSION_CHECK, false, provider_options_map); @@ -906,6 +949,25 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer const size_t num_nodes_in_graph = static_cast(graph_viewer.NumberOfNodes()); const auto& logger = *GetLogger(); + + // Check BF16 compatibility early + if (model_settings_.htp_bf16_enable) { + // Check SoC model + uint32_t soc_model = qnn_backend_manager_->GetSocModel(); + if (soc_model == QNN_SOC_MODEL_UNKNOWN) { + LOGS(logger, WARNING) << "BF16 mode is enabled but soc_model is not specified. 
" + << "Both parameters must be set together for BF16 support. " + << "QNN EP will not handle any nodes."; + return result; // Empty result means QNN EP won't handle any nodes + } else if (soc_model < 88) { + LOGS(logger, WARNING) << "BF16 mode is enabled but SoC model is " << soc_model + << " (expected 88 and above). QNN EP will not handle any nodes."; + return result; // Empty result means QNN EP won't handle any nodes + } + + LOGS(logger, INFO) << "BF16 mode enabled with compatible hardware: SoC " << soc_model; + } + bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); const auto gen_metadef_name = [&]() { @@ -926,7 +988,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer } std::unordered_map>> context_bin_map; - if (enable_vtcm_backup_buffer_sharing_) { + if (enable_vtcm_backup_buffer_sharing_ || enable_file_mapped_weights_) { std::unordered_set ep_ctx_nodes; GetMainEPCtxNodes(graph_viewer, ep_ctx_nodes, logger); @@ -939,7 +1001,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer NodeAttrHelper node_helper(*ep_ctx_node); std::string context_bin_filepath(parent_path.string()); context_bin_filepath.append("/").append(node_helper.Get(qnn::EP_CACHE_CONTEXT, "")); - if (context_bin_map.find(context_bin_filepath) == context_bin_map.end()) { context_bin_map.emplace(context_bin_filepath, std::make_unique>()); // Push context bin filepath for lookup between sessions @@ -956,7 +1017,10 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer context_cache_enabled_ && enable_spill_fill_buffer_, share_ep_contexts_, enable_vtcm_backup_buffer_sharing_, - context_bin_map); + enable_file_mapped_weights_, + rpcmem_library_, + context_bin_map, + enable_htp_extended_udma_mode_); context_bin_map.clear(); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index dd301d7915935..c5d41789e7a1f 100644 --- 
a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -80,7 +80,7 @@ class QNNExecutionProvider : public IExecutionProvider { qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); - bool IsHtpSharedMemoryAllocatorAvailable() const { return rpcmem_library_ != nullptr; } + bool IsHtpSharedMemoryAllocatorAvailable() const { return enable_htp_shared_mem_allocator_ && rpcmem_library_ != nullptr; } private: // Will return true if any power config options need to be updated @@ -119,12 +119,15 @@ class QNNExecutionProvider : public IExecutionProvider { bool share_ep_contexts_ = false; bool stop_share_ep_contexts_ = false; bool enable_spill_fill_buffer_ = false; + bool enable_file_mapped_weights_ = true; + bool enable_htp_shared_mem_allocator_ = false; #if defined(_WIN32) onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback callback_ETWSink_provider_ = nullptr; #endif qnn::ModelSettings model_settings_ = {}; bool dump_json_qnn_graph_ = false; std::string json_qnn_graph_dir_ = ""; + bool enable_htp_extended_udma_mode_ = false; // Whether this is set depends on a session option enabling it and if the RPCMEM dynamic library is available. // This is potentially shared with HtpSharedMemoryAllocator which may be returned by CreatePreferredAllocators(). 
diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.cc b/onnxruntime/core/providers/qnn/rpcmem_library.cc index 20918f8bc6de1..f89a15157ddf4 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.cc +++ b/onnxruntime/core/providers/qnn/rpcmem_library.cc @@ -165,6 +165,8 @@ RpcMemApi CreateApi(void* library_handle) { ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "rpcmem_to_fd", (void**)&api.to_fd)); + ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(library_handle, "remote_register_buf_attr2", (void**)&api.register_buf)); + return api; } diff --git a/onnxruntime/core/providers/qnn/rpcmem_library.h b/onnxruntime/core/providers/qnn/rpcmem_library.h index 2746e147373bb..0f4b5b5391f59 100644 --- a/onnxruntime/core/providers/qnn/rpcmem_library.h +++ b/onnxruntime/core/providers/qnn/rpcmem_library.h @@ -24,6 +24,9 @@ constexpr uint32_t RPCMEM_DEFAULT_FLAGS = 1; constexpr int RPCMEM_HEAP_ID_SYSTEM = 25; +constexpr int RPCMEM_ATTR_IMPORT_BUFFER = 256; +constexpr int RPCMEM_ATTR_READ_ONLY = 512; + /** * Allocate a zero-copy buffer for size upto 2 GB with the FastRPC framework. * Buffers larger than 2 GB must be allocated with rpcmem_alloc2 @@ -46,6 +49,17 @@ using FreeFnPtr = void (*)(void* po); */ using ToFdFnPtr = int (*)(void* po); +/** + * Registers and maps a CPU buffer to RPC memory space + * @param[in] buff Data pointer for a CPU-allocated buffer + * @param[in] size Size of the buffer in bytes + * @param[in] fd File descriptor for a CPU-allocated buffer + * Note: Can be NULL if N/A or -1 to signal deregistration + * @param[in] attr Specified attributes for the buffer + * @return Data pointer for an RPCMEM-allocated buffer + */ +using RegisterBufFnPtr = void (*)(void* buff, size_t size, int fd, int attr); + } // namespace rpcmem // RPCMEM API function pointers. 
@@ -53,6 +67,7 @@ struct RpcMemApi { rpcmem::AllocFnPtr alloc; rpcmem::FreeFnPtr free; rpcmem::ToFdFnPtr to_fd; + rpcmem::RegisterBufFnPtr register_buf; }; // Loads and provides access to the RPCMEM API functions from a dynamically loaded library. diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 2a9a8127874ee..1ed78c89e722d 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -30,6 +30,7 @@ #include "core/common/float8.h" #include "core/common/float16.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" #include "core/framework/float4.h" #include "core/framework/tensor_shape.h" #include "core/providers/providers.h" @@ -80,6 +81,8 @@ enum TensorProto_DataType : int { TensorProto_DataType_INT4 = 22, TensorProto_DataType_FLOAT4E2M1 = 23, TensorProto_DataType_FLOAT8E8M0 = 24, + TensorProto_DataType_UINT2 = 25, + TensorProto_DataType_INT2 = 26, }; enum TensorProto_DataLocation : int { @@ -410,6 +413,15 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4; } +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2; +} +template <> +constexpr ONNXTensorElementDataType GetONNXTensorElementDataType() { + return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2; +} + inline std::vector> CreateSupportedPartitions(const GraphViewer& graph_viewer, const std::unordered_set& supported_nodes, @@ -441,11 +453,6 @@ inline bool HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto return g_host->Utils__HasExternalDataInMemory(ten_proto); } -inline Status ValidateExternalDataPath(const std::filesystem::path& base_dir, - const std::filesystem::path& location) { - return g_host->Utils__ValidateExternalDataPath(base_dir, location); -} - } // namespace utils namespace 
graph_utils { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 5732984af29b4..ee00a06751d0a 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -172,6 +172,10 @@ template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Int4x2(); } template <> MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_UInt4x2(); } +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_Int2x4(); } +template <> +MLDataType DataTypeImpl::GetType() { return Provider_GetHost()->DataTypeImpl__GetType_UInt2x4(); } #if !defined(DISABLE_FLOAT4_TYPES) template <> @@ -222,6 +226,10 @@ template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int4x2(); } template <> MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt4x2(); } +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_Int2x4(); } +template <> +MLDataType DataTypeImpl::GetTensorType() { return Provider_GetHost()->DataTypeImpl__GetTensorType_UInt2x4(); } #if !defined(DISABLE_FLOAT4_TYPES) template <> diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 51bd2c467acec..9cbbc6234a99b 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -775,6 +775,8 @@ struct ProviderHost { #endif virtual MLDataType DataTypeImpl__GetType_Int4x2() = 0; virtual MLDataType DataTypeImpl__GetType_UInt4x2() = 0; + virtual MLDataType DataTypeImpl__GetType_Int2x4() = 0; + virtual 
MLDataType DataTypeImpl__GetType_UInt2x4() = 0; virtual MLDataType DataTypeImpl__GetTensorTypeFromOnnxType(int) = 0; virtual MLDataType DataTypeImpl__GetTensorType_bool() = 0; @@ -802,6 +804,8 @@ struct ProviderHost { virtual MLDataType DataTypeImpl__GetTensorType_Int4x2() = 0; virtual MLDataType DataTypeImpl__GetTensorType_UInt4x2() = 0; + virtual MLDataType DataTypeImpl__GetTensorType_Int2x4() = 0; + virtual MLDataType DataTypeImpl__GetTensorType_UInt2x4() = 0; #if !defined(DISABLE_SPARSE_TENSORS) virtual MLDataType DataTypeImpl__GetSparseTensorType_bool() = 0; @@ -1000,9 +1004,6 @@ struct ProviderHost { virtual bool Utils__HasExternalDataInMemory(const ONNX_NAMESPACE::TensorProto& ten_proto) = 0; - virtual Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path, - const std::filesystem::path& location) = 0; - // Model virtual std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1260,6 +1261,8 @@ struct ProviderHost { #endif virtual Int4x2* Tensor__MutableData_Int4x2(Tensor* p) = 0; virtual UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) = 0; + virtual Int2x4* Tensor__MutableData_Int2x4(Tensor* p) = 0; + virtual UInt2x4* Tensor__MutableData_UInt2x4(Tensor* p) = 0; virtual const bool* Tensor__Data_bool(const Tensor* p) = 0; virtual const int8_t* Tensor__Data_int8(const Tensor* p) = 0; @@ -1286,6 +1289,8 @@ struct ProviderHost { #endif virtual const Int4x2* Tensor__Data_Int4x2(const Tensor* p) = 0; virtual const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) = 0; + virtual const Int2x4* Tensor__Data_Int2x4(const Tensor* p) = 0; + virtual const UInt2x4* Tensor__Data_UInt2x4(const Tensor* p) = 0; virtual gsl::span Tensor__DataAsSpan_int64(const Tensor* p) = 0; @@ -1322,6 +1327,8 @@ struct ProviderHost { #endif virtual bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept = 0; virtual bool Tensor__IsDataType_UInt4x2(const Tensor* 
p) noexcept = 0; + virtual bool Tensor__IsDataType_Int2x4(const Tensor* p) noexcept = 0; + virtual bool Tensor__IsDataType_UInt2x4(const Tensor* p) noexcept = 0; virtual const TensorShape& Tensor__Shape(const Tensor* p) = 0; virtual void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 0ab7ee0aedd1a..041cf764e7ede 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1524,6 +1524,10 @@ inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataTy template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_UInt4x2(this); } template <> +inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_Int2x4(this); } +template <> +inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_UInt2x4(this); } +template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_int8(this); } template <> inline bool Tensor::IsDataType() const { return g_host->Tensor__IsDataType_uint8(this); } @@ -1571,6 +1575,10 @@ inline Int4x2* Tensor::MutableData() { return g_host->Tensor__MutableDat template <> inline UInt4x2* Tensor::MutableData() { return g_host->Tensor__MutableData_UInt4x2(this); } template <> +inline Int2x4* Tensor::MutableData() { return g_host->Tensor__MutableData_Int2x4(this); } +template <> +inline UInt2x4* Tensor::MutableData() { return g_host->Tensor__MutableData_UInt2x4(this); } +template <> inline int8_t* Tensor::MutableData() { return g_host->Tensor__MutableData_int8(this); } template <> inline uint8_t* Tensor::MutableData() { return g_host->Tensor__MutableData_uint8(this); } @@ -1618,6 +1626,10 @@ inline const Int4x2* Tensor::Data() const { return g_host->Tensor__Data_ template <> inline const UInt4x2* 
Tensor::Data() const { return g_host->Tensor__Data_UInt4x2(this); } template <> +inline const Int2x4* Tensor::Data() const { return g_host->Tensor__Data_Int2x4(this); } +template <> +inline const UInt2x4* Tensor::Data() const { return g_host->Tensor__Data_UInt2x4(this); } +template <> inline const int8_t* Tensor::Data() const { return g_host->Tensor__Data_int8(this); } template <> inline const uint8_t* Tensor::Data() const { return g_host->Tensor__Data_uint8(this); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index ec529c2ad1fc2..0c1c930132da3 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -386,7 +386,6 @@ void deinitialize_vitisai_ep() { s_domains_vitisaiep.clear(); s_library_vitisaiep.Clear(); - s_kernel_registry_vitisaiep.reset(); } static void set_version_info(vaip_core::OrtApiForVaip& api) { diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc index 84a0afd873d23..c3842a5c875e3 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc @@ -57,6 +57,11 @@ Status ConvTranspose::ComputeInternal(ComputeContext& context) bool has_bias = context.InputCount() > 2; const auto* bias = has_bias ? 
context.Input(2) : nullptr; + // Validate bias shape if provided + if (has_bias && (bias->Shape().NumDimensions() != 1 || bias->Shape()[0] != num_output_channels)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "invalid bias"); + } + if (input_shape.NumDimensions() == 3 && filter_shape.NumDimensions() == 3) { // ConvTranspose1D TensorShapeVector input_shape_vector = input_shape.AsShapeVector(); diff --git a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template index 2f64525469561..b7313903897e1 100644 --- a/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template +++ b/onnxruntime/core/providers/webgpu/nn/im2col_matmul.wgsl.template @@ -33,7 +33,7 @@ fn load_src(batch : u32, m : u32, k_packed_idx : u32) -> src_value_t { // 4. Calculate the coordinate in the original input tensor let src_h_coord : i32 = i32(src_h_coord_padded) - i32(uniforms.pads.x); - let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.z); + let src_w_coord : i32 = i32(src_w_coord_padded) - i32(uniforms.pads.y); // 5. Check for padding/out-of-bounds if (src_h_coord < 0 || src_h_coord >= i32(uniforms.src_h) || diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 7cb6a852e8d7e..8b8d884a35281 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -138,10 +138,10 @@ void WebGpuContext::Initialize(const WebGpuContextConfig& config) { config.buffer_cache_config.uniform.mode, config.buffer_cache_config.query_resolve.mode); - // create initializer buffer manager. cache is always disabled for initializer buffer manager + // create initializer buffer manager. 
initializer_buffer_mgr_ = BufferManagerFactory::Create(*this, - BufferCacheMode::Disabled, - BufferCacheMode::Disabled, + BufferCacheMode::LazyRelease, + BufferCacheMode::LazyRelease, BufferCacheMode::Disabled); // create program manager diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc index 8303d2ff4293f..8a52b7a188fd5 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.cc @@ -49,6 +49,12 @@ Status WebGpuKernel::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr / Status s = PrePackInternal(context, tensor, input_idx, ep_.PrepackAllocator(), is_packed); + if (is_packed) { + // Flush pending commands to ensure GPU buffer creations are completed. + // This allows the initializer buffer manager to release temporary buffers and reduce memory usage. + webgpu_context_.Flush(webgpu_context_.InitializerBufferManager()); + } + if (webgpu_context_.ValidationMode() >= ValidationMode::Full) { ORT_RETURN_IF_ERROR(webgpu_context_.PopErrorScope()); } diff --git a/onnxruntime/core/session/ep_graph_assignment_info.h b/onnxruntime/core/session/ep_graph_assignment_info.h new file mode 100644 index 0000000000000..a3b98b2e6315a --- /dev/null +++ b/onnxruntime/core/session/ep_graph_assignment_info.h @@ -0,0 +1,34 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#if !defined(ORT_MINIMAL_BUILD) +#include +#include +#include "core/common/common.h" + +/// +/// Contains information about a node assigned to an EP. This is the definition of an opaque struct in the C API. +/// +struct OrtEpAssignedNode { + std::string name; + std::string domain; + std::string op_type; +}; + +/// +/// Contains information about a subgraph assigned to an EP by the session graph partitioner. +/// This is the definition of an opaque struct in the C API. 
+/// +struct OrtEpAssignedSubgraph { + OrtEpAssignedSubgraph() = default; + OrtEpAssignedSubgraph(OrtEpAssignedSubgraph&&) = default; + OrtEpAssignedSubgraph& operator=(OrtEpAssignedSubgraph&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(OrtEpAssignedSubgraph); + + std::string ep_name; + std::vector> nodes_storage; + std::vector nodes; +}; +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e9b0c2f230263..0c9b3c0663b5c 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -34,6 +34,7 @@ #include "core/framework/tensor_type_and_shape.h" #include "core/framework/op_kernel_context_internal.h" #include "core/framework/ort_value_pattern_planner.h" +#include "core/framework/plugin_ep_stream.h" #include "core/framework/transform_layout_functions.h" #include "core/framework/utils.h" #include "core/graph/graph_viewer.h" @@ -1287,6 +1288,32 @@ common::Status InferenceSession::ApplyUpdates(const OrtModel& model_editor_api_m return model_->MainGraph().UpdateUsingModelEditorApiModel(model_editor_api_model); } +#if !defined(ORT_MINIMAL_BUILD) +static std::unique_ptr CreateEpAssignedSubgraph(const Graph& graph, + const ComputeCapability& capability, + const std::string& ep_name) { + auto assigned_subgraph = std::make_unique(); + assigned_subgraph->ep_name = ep_name; + + gsl::span node_indices = capability.sub_graph->nodes; + + for (NodeIndex node_index : node_indices) { + const Node* node = graph.GetNode(node_index); + if (node != nullptr) { + auto assigned_node = std::make_unique(); + assigned_node->name = node->Name(); + assigned_node->domain = node->Domain(); + assigned_node->op_type = node->OpType(); + + assigned_subgraph->nodes.push_back(assigned_node.get()); + assigned_subgraph->nodes_storage.push_back(std::move(assigned_node)); + } + } + + return assigned_subgraph; +} +#endif // !defined(ORT_MINIMAL_BUILD) + 
common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. @@ -1301,12 +1328,29 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // 8. Repeat steps 5 to 7 depending on the graph optimizations loop level. // 9. insert copy nodes (required transformer). + OnPartitionAssignmentFunction on_partition_assignment_fn; +#if !defined(ORT_MINIMAL_BUILD) + bool record_ep_graph_assignment = + session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "0") == "1"; + if (record_ep_graph_assignment) { + on_partition_assignment_fn = [this](const Graph& graph, const ComputeCapability& compute_capability, + const std::string& ep_name) { + std::unique_ptr assigned_subgraph = CreateEpAssignedSubgraph(graph, + compute_capability, + ep_name); + + this->ep_graph_assignment_info_.push_back(assigned_subgraph.get()); + this->ep_graph_assignment_info_storage_.push_back(std::move(assigned_subgraph)); + }; + } +#endif // !defined(ORT_MINIMAL_BUILD) + // Create GraphOptimizerRegistry instance for providing predefined graph optimizers and selection functions for EPs to lookup auto graph_optimizer_registry = std::make_unique(&session_options_, execution_providers_.Get(onnxruntime::kCpuExecutionProvider), session_logger_); GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_, std::move(graph_optimizer_registry), - check_load_cancellation_fn_); + check_load_cancellation_fn_, on_partition_assignment_fn); // Run Ahead Of time function inlining if (const bool disable_aot_function_inlining = @@ -3050,6 +3094,15 @@ Status InferenceSession::Run(const RunOptions& run_options, #ifdef ORT_ENABLE_STREAM DeviceStreamCollectionHolder device_stream_collection_holder(session_state_.get()); + if (run_options.sync_stream != nullptr) { + 
if (session_options_.execution_mode != ExecutionMode::ORT_SEQUENTIAL) { + // XXX: Not tested in Parallel execution mode and disabled at this time. + LOGS(*session_logger_, WARNING) << "Setting sync stream is not supported in parallel execution mode."; + } else { + ORT_RETURN_IF_ERROR_SESSIONID_( + device_stream_collection_holder.p_->SetStreamOverride(run_options.sync_stream)); + } + } #endif if (retval.IsOK()) { diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ff54f6fa7bca0..1dbf0318c988c 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -30,6 +30,7 @@ #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" +#include "core/session/ep_graph_assignment_info.h" #include #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" @@ -504,6 +505,19 @@ class InferenceSession { */ const std::vector& GetRegisteredProviderTypes() const; + /** + * Get the registered Execution Providers. + * + * This method can be called after EP registration but before Initialize() completes. + * Used only for early validation of compiled model compatibility where accessing + * EPs through session state is not yet possible. + * + * @return const reference to the ExecutionProviders collection. + */ + const ExecutionProviders& GetExecutionProviders() const noexcept { + return execution_providers_; + } + /* * Get the options this session was initialized with. 
*/ @@ -662,6 +676,12 @@ class InferenceSession { return session_id_; } +#if !defined(ORT_MINIMAL_BUILD) + const std::vector& GetEpGraphAssignmentInfo() const { + return this->ep_graph_assignment_info_; + } +#endif // !defined(ORT_MINIMAL_BUILD) + protected: #if !defined(ORT_MINIMAL_BUILD) @@ -1054,6 +1074,13 @@ class InferenceSession { // Enable nodestats collection std::optional node_stats_recorder_; #endif + +#if !defined(ORT_MINIMAL_BUILD) + // Information about the subgraphs/nodes assigned to each EP. + // A user gets this information via the OrtApi::GetEpGraphAssignmentInfo C API function. + std::vector> ep_graph_assignment_info_storage_; + std::vector ep_graph_assignment_info_; +#endif // !defined(ORT_MINIMAL_BUILD) }; struct SessionIOBinding { diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index afb17f867fc00..aa6a69adb549c 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -3,11 +3,12 @@ #include #include +#include #include #include #include -#include #include +#include #include "core/common/common.h" #include "core/common/logging/logging.h" @@ -30,15 +31,19 @@ #include "core/framework/utils.h" #include "core/graph/constants.h" #include "core/graph/graph.h" +#include "core/graph/model.h" #include "core/graph/model_editor_api_types.h" #include "core/graph/ep_api_types.h" +#include "core/graph/onnx_protobuf.h" #include "core/providers/get_execution_providers.h" #include "core/session/abi_devices.h" #include "core/session/abi_session_options_impl.h" #include "core/session/allocator_adapters.h" #include "core/session/compile_api.h" #include "core/session/environment.h" +#include "core/session/ep_graph_assignment_info.h" #include "core/session/interop_api.h" +#include "core/session/onnxruntime_ep_device_ep_metadata_keys.h" #include "core/session/plugin_ep/ep_api.h" #include "core/session/plugin_ep/ep_library_internal.h" #include 
"core/session/inference_session.h" @@ -47,6 +52,7 @@ #include "core/session/lora_adapters.h" #include "core/session/model_editor_api.h" #include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "core/session/ort_apis.h" #include "core/session/ort_env.h" #include "core/session/utils.h" @@ -927,6 +933,153 @@ ORT_API_STATUS_IMPL(OrtApis::RunAsync, _Inout_ OrtSession* sess, _In_opt_ const API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::Session_GetEpGraphAssignmentInfo, _In_ const OrtSession* session, + _Outptr_ const OrtEpAssignedSubgraph* const** ep_subgraphs, + _Out_ size_t* num_ep_subgraphs) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + const auto* inference_session = reinterpret_cast(session); + const auto& session_options = inference_session->GetSessionOptions(); + bool is_enabled = + session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "0") == "1"; + + if (!is_enabled) { + std::ostringstream oss; + oss << "Session configuration entry '" << kOrtSessionOptionsRecordEpGraphAssignmentInfo + << "' must be set to \"1\" to retrieve EP graph assignment information."; + return OrtApis::CreateStatus(ORT_FAIL, oss.str().c_str()); + } + + if (ep_subgraphs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'ep_subgraphs' argument is null"); + } + + if (num_ep_subgraphs == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "'num_ep_subgraphs' argument is null"); + } + + const std::vector& ep_assignment_info = inference_session->GetEpGraphAssignmentInfo(); + + *ep_subgraphs = ep_assignment_info.data(); + *num_ep_subgraphs = ep_assignment_info.size(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(session); + ORT_UNUSED_PARAMETER(ep_subgraphs); + ORT_UNUSED_PARAMETER(num_ep_subgraphs); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // 
!defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const char** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedSubgraph_GetEpName requires a valid (non-null) `out` output parameter " + "into which to store the EP name string."); + } + + *out = ep_subgraph->ep_name.c_str(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_subgraph); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const OrtEpAssignedNode* const** ep_nodes, _Out_ size_t* num_ep_nodes) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (ep_nodes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedSubgraph_GetNodes requires a valid (non-null) `ep_nodes` output parameter " + "into which to store the pointer to the node array."); + } + + if (num_ep_nodes == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedSubgraph_GetNodes requires a valid (non-null) `num_ep_nodes` " + "output parameter into which to store the number of nodes."); + } + + *ep_nodes = ep_subgraph->nodes.data(); + *num_ep_nodes = ep_subgraph->nodes.size(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_subgraph); + ORT_UNUSED_PARAMETER(ep_nodes); + ORT_UNUSED_PARAMETER(num_ep_nodes); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, + _Outptr_ const char** out) { + 
API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedNode_GetName requires a valid (non-null) `out` output parameter " + "into which to store the name string."); + } + + *out = ep_node->name.c_str(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_node); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, + _Outptr_ const char** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedNode_GetDomain requires a valid (non-null) `out` output parameter " + "into which to store the domain string."); + } + + *out = ep_node->domain.c_str(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_node); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtApis::EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, + _Outptr_ const char** out) { + API_IMPL_BEGIN +#if !defined(ORT_MINIMAL_BUILD) + if (out == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "EpAssignedNode_GetOperatorType requires a valid (non-null) `out` output parameter " + "into which to store the operator type string."); + } + + *out = ep_node->op_type.c_str(); + return nullptr; +#else + ORT_UNUSED_PARAMETER(ep_node); + ORT_UNUSED_PARAMETER(out); + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "EP graph assignment information is not supported in this build"); +#endif // !defined(ORT_MINIMAL_BUILD) + API_IMPL_END +} + struct OrtIoBinding { 
std::unique_ptr<::onnxruntime::IOBinding> binding_; explicit OrtIoBinding(std::unique_ptr<::onnxruntime::IOBinding>&& binding) : binding_(std::move(binding)) {} @@ -2874,72 +3027,242 @@ ORT_API_STATUS_IMPL(OrtApis::Graph_GetGraphView, _In_ const OrtGraph* src_graph, const EpGraph* ep_graph = EpGraph::ToInternal(src_graph); if (ep_graph == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "src_graph is a ModelEditorGraph which doesn't support Graph_GetSubGraph."); + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "src_graph is a ModelEditorGraph which doesn't support Graph_GetGraphView."); } - const Graph& graph = ep_graph->GetGraphViewer().GetGraph(); + const GraphViewer& graph_viewer = ep_graph->GetGraphViewer(); + const Graph& graph = graph_viewer.GetGraph(); // Create a GraphViewer with filtered info + // TODO: Investigate whether utils::MakeComputeCapability can be extended and reused instead std::unique_ptr indexed_sub_graph = std::make_unique(); - std::unique_ptr metadef = std::make_unique(); - metadef->name = "sub_graph"; - metadef->since_version = 1; - std::unordered_set outputs; - std::unordered_set initializers; - - auto add_inputs = [&](ConstPointerContainer> defs) { - for (const auto* def : defs) { - if (def->Exists()) { - // not the output of a previous node - if (outputs.count(def->Name()) == 0) { - metadef->inputs.push_back(def->Name()); - } else { - // consumed by node so no longer subgraph output - // NOTE: Ignoring edge case where a node output is an overall graph output AND a node input - outputs.erase(def->Name()); - } - if (graph.IsInitializedTensor(def->Name())) { - initializers.insert(def); - } + // Following data structures help determine the final inputs/outputs of the subgraph. + // Note: The 'subgraph' here refers to a graph contains a subset of nodes in the 'src_graph'. 
+ + // Subgraph's node set + const std::unordered_set node_set = [&]() { + std::unordered_set node_set; + for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; + const EpNode* ep_node = EpNode::ToInternal(ort_node); + if (ep_node != nullptr) { + node_set.insert(ep_node->GetInternalNode().Index()); } } - }; - auto add_node = [&](const Node& node) { - indexed_sub_graph->nodes.push_back(node.Index()); - add_inputs(node.InputDefs()); - add_inputs(node.ImplicitInputDefs()); + return node_set; + }(); - for (const auto* def : node.OutputDefs()) { - outputs.insert(def->Name()); - } - }; + // Source graph output names + std::unordered_set graph_output_names; + for (const auto* output_arg : graph_viewer.GetOutputs()) { + graph_output_names.insert(output_arg->Name()); + } + + // These maps store the inputs and outputs of the subgraph. + // Please note that the inputs and outputs of the maps will be dynamically updated during node iteration + // to determine the final inputs and outputs of the subgraph. + std::unordered_map subgraph_inputs, subgraph_outputs; + + // This map stores the node's output that will be consumed by another node outside of this subgraph. + // So the node's output should be put into the subgraph's output list. + std::unordered_map subgraph_outputs_to_add; + + // This map stores the node's output that is original graph's output. + // So the node's output should be put into the subgraph's output list. + std::unordered_map graph_outputs_to_add; + + std::unordered_set erased; + + // This is the relative ordering that ensures node's input or output being added to the 'subgraph_inputs', + // 'subgraph_outputs', 'subgraph_outputs_to_add' and 'graph_outputs_to_add' maps is associated with a relative order index. + // Items added earlier receive a smaller order index than items added later. 
+ // When constructing the final subgraph's input or output lists, entries with smaller + // order indices will appear before those with larger indices. + int input_order = 0; + int output_order = 0; + + // node arg to its consumer nodes. + // Note: graph.GetConsumerNodes() is not available in minimal build, in order to use unified implementation across + // all builds, this map is needed to determine if node arg is consumed by other nodes. + std::unordered_map> node_arg_to_consumer_nodes; + + std::vector initializers; // Add nodes - for (size_t node_idx = 0; node_idx < num_nodes; node_idx++) { - const OrtNode* ort_node = nodes[node_idx]; + for (size_t i = 0; i < num_nodes; i++) { + const OrtNode* ort_node = nodes[i]; const EpNode* ep_node = EpNode::ToInternal(ort_node); if (ep_node == nullptr) { - return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, "node is a ModelEditorNode which doesn't support Graph_GetSubGraph."); + return OrtApis::CreateStatus(OrtErrorCode::ORT_INVALID_ARGUMENT, + "node is a ModelEditorNode which doesn't support Graph_GetGraphView."); + } + const Node& node = ep_node->GetInternalNode(); + indexed_sub_graph->nodes.push_back(node.Index()); + + for (const auto& input : node.InputDefs()) { + if (!input->Exists()) { + continue; + } + + if (graph_viewer.IsConstantInitializer(input->Name(), true)) { + initializers.push_back(input->Name()); + continue; + } + const auto& it = subgraph_outputs.find(input); + if (it != subgraph_outputs.end()) { + subgraph_outputs.erase(it); + erased.insert(input); + } else if (erased.find(input) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + subgraph_inputs.insert({input, input_order++}); + } + } + + for (const auto& input : node.ImplicitInputDefs()) { + if (!input->Exists()) { + continue; + } + + if (graph_viewer.IsConstantInitializer(input->Name(), true)) { + initializers.push_back(input->Name()); + continue; + } + const auto& it = 
subgraph_outputs.find(input); + if (it != subgraph_outputs.end()) { + subgraph_outputs.erase(it); + erased.insert(input); + } else if (erased.find(input) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + subgraph_inputs.insert({input, input_order++}); + } + } + + // For output searching, there are two special cases, + // One is, if subgraph's node output is parent graph's output. the node output should + // be also added to the subgraph's output list + // The other one is, if node's OutputEdges are more than its outputs, meaning certain output is used more than once, + // if the output is connected to nodes that don't belong to the subgraph, the output need to be added + // to the output list + for (const auto& output : node.OutputDefs()) { + if (!output->Exists()) { + continue; + } + + const auto& it = subgraph_inputs.find(output); + if (it != subgraph_inputs.end()) { + subgraph_inputs.erase(it); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + auto has_consumer_nodes = [&](const std::string& node_arg_str) -> bool { + // Same implementation as Graph::PopulateNodeArgToProducerConsumerLookupsFromNodes() + if (node_arg_to_consumer_nodes.empty()) { + for (const auto& node : graph.Nodes()) { + node.ForEachDef([&](const NodeArg& node_arg, bool is_input) { + if (is_input) { + node_arg_to_consumer_nodes[node_arg.Name()].insert(node.Index()); + } + }); + } + } + return node_arg_to_consumer_nodes.find(node_arg_str) != node_arg_to_consumer_nodes.end(); + }; + + if (has_consumer_nodes(output->Name())) { + // Only when output is neither in input list nor erased list, + // and the output is consumed by another node, add the output to output list + subgraph_outputs.insert({output, output_order++}); + } + } + + if (graph_output_names.find(output->Name()) != graph_output_names.end()) { + // This output is the graph's output. + // So the output should be put into the subgraph's output list. 
+ graph_outputs_to_add.insert({output, output_order++}); + } + } + + if (node.GetOutputEdgesCount() > node.OutputDefs().size()) { + for (auto it = node.OutputEdgesBegin(), end = node.OutputEdgesEnd(); it != end; ++it) { + const auto& node_idx = it->GetNode().Index(); + + if (node_set.find(node_idx) == node_set.end()) { + // This output will be consumed by another node outside of this subgraph. + // So the output should be put into the subgraph's output list. + const NodeArg* output = nullptr; + + // The dst_arg_index from GetDstArgIndex() could be the index for explicit/implicit input defs of the node. + // We need to get the correct input index accordingly. (See Graph::BuildConnections() in graph.cc for more details) + if (it->GetDstArgIndex() < static_cast(it->GetNode().InputDefs().size())) { + output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; + } else { + output = (it->GetNode()).ImplicitInputDefs()[it->GetDstArgIndex() - it->GetNode().InputDefs().size()]; + } + subgraph_outputs_to_add.insert({output, output_order++}); + } + } } - add_node(ep_node->GetInternalNode()); } - // Add initializers - for (auto& initializer : initializers) { - metadef->constant_initializers.push_back(initializer->Name()); + subgraph_outputs.insert(subgraph_outputs_to_add.begin(), subgraph_outputs_to_add.end()); + subgraph_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); + + std::multimap inputs, outputs; + + // Get the input order of the original graph + std::unordered_map original_inputs; + int order = 0; + for (const auto* input : graph_viewer.GetInputs()) { + original_inputs[input] = order++; } - // Add outputs - for (auto& output : outputs) { - metadef->outputs.push_back(output); + // input order needs to be consistent with original graph's input order + for (const auto& [node_arg, subgraph_input_order] : subgraph_inputs) { + const auto& original_input_it = original_inputs.find(node_arg); + + if (original_input_it != original_inputs.end()) { + 
inputs.insert(std::make_pair( + original_input_it->second, // input order from original graph + node_arg)); + } else { + inputs.insert(std::make_pair( + subgraph_input_order, // input order from subgraph + node_arg)); + } + } + + // Sort outputs by the order they were added + for (auto it = subgraph_outputs.begin(), end = subgraph_outputs.end(); it != end; ++it) { + outputs.insert(std::pair(it->second, it->first)); } - indexed_sub_graph->SetMetaDef(std::move(metadef)); - auto graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); + std::unique_ptr meta_def = std::make_unique(); + meta_def->name = "sub_graph"; + meta_def->since_version = 1; + + // Assign inputs and outputs to subgraph's meta_def + for (const auto& input : inputs) { + if (input.second->Exists()) { + meta_def->inputs.push_back(input.second->Name()); + } + } + + for (const auto& initializer : initializers) { + meta_def->constant_initializers.push_back(initializer); + } + + for (const auto& output : outputs) { + if (output.second->Exists()) { + meta_def->outputs.push_back(output.second->Name()); + } + } + + indexed_sub_graph->SetMetaDef(std::move(meta_def)); + auto new_graph_viewer = std::make_unique(graph, *indexed_sub_graph.get()); std::unique_ptr result; - ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(graph_viewer), std::move(indexed_sub_graph), result)); + ORT_API_RETURN_IF_STATUS_NOT_OK(EpGraph::Create(std::move(new_graph_viewer), std::move(indexed_sub_graph), result)); *dst_graph = result.release(); @@ -3390,6 +3713,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtS ep_option_vals_span, session_options->value)); + ORT_API_RETURN_IF_STATUS_NOT_OK(AddEpCustomDomainsToSessionOptions( + ep_devices_span, + *session_options)); + session_options->provider_factories.push_back(std::move(provider_factory)); return nullptr; @@ -3620,6 +3947,93 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } +// Helper function to 
extract compatibility info from model metadata +static OrtStatus* ExtractCompatibilityInfoFromModelProto( + const ONNX_NAMESPACE::ModelProto& model_proto, + const char* ep_type, + OrtAllocator* allocator, + char** compatibility_info) { + // Build the key we're looking for + std::string target_key = std::string(kOrtModelMetadata_EpCompatibilityInfoPrefix) + ep_type; + + // Search through metadata_props for the matching key + for (const auto& prop : model_proto.metadata_props()) { + if (prop.key() == target_key) { + // Found it - allocate and copy the value using the provided allocator + *compatibility_info = onnxruntime::StrDup(prop.value(), allocator); + if (*compatibility_info == nullptr) { + return OrtApis::CreateStatus(ORT_FAIL, "Failed to allocate memory for compatibility info."); + } + return nullptr; + } + } + + // Key not found - return nullptr (not an error, just means no compat info for this EP) + *compatibility_info = nullptr; + return nullptr; +} + +// Extract EP compatibility info from a model file +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModel, + _In_ const ORTCHAR_T* model_path, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info) { + API_IMPL_BEGIN + if (model_path == nullptr || ep_type == nullptr || ep_type[0] == '\0' || + allocator == nullptr || compatibility_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid argument provided to GetCompatibilityInfoFromModel."); + } + + *compatibility_info = nullptr; + + // Use Model::Load for proper cross-platform path handling via file descriptor + ONNX_NAMESPACE::ModelProto model_proto; + auto status = Model::Load(PathString(model_path), model_proto); + if (!status.IsOK()) { + if (status.Code() == common::NO_SUCHFILE) { + return OrtApis::CreateStatus(ORT_NO_SUCHFILE, status.ErrorMessage().c_str()); + } + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, status.ErrorMessage().c_str()); + } + + return 
ExtractCompatibilityInfoFromModelProto(model_proto, ep_type, allocator, compatibility_info); + API_IMPL_END +} + +// Extract EP compatibility info from model bytes in memory +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModelBytes, + _In_reads_(model_data_length) const void* model_data, + _In_ size_t model_data_length, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info) { + API_IMPL_BEGIN + if (model_data == nullptr || model_data_length == 0 || ep_type == nullptr || ep_type[0] == '\0' || + allocator == nullptr || compatibility_info == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Invalid argument provided to GetCompatibilityInfoFromModelBytes."); + } + + *compatibility_info = nullptr; + + // Explicit check for size limit - Model::LoadFromBytes uses int for size due to protobuf API + if (model_data_length > static_cast(INT_MAX)) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, + "Model data size exceeds maximum supported size (2GB). 
Use GetCompatibilityInfoFromModel with a file path instead."); + } + + ONNX_NAMESPACE::ModelProto model_proto; + auto status = Model::LoadFromBytes(static_cast(model_data_length), model_data, model_proto); + if (!status.IsOK()) { + return OrtApis::CreateStatus(ORT_INVALID_GRAPH, status.ErrorMessage().c_str()); + } + + return ExtractCompatibilityInfoFromModelProto(model_proto, ep_type, allocator, compatibility_info); + API_IMPL_END +} + // GetInteropApi - returns the Interop API struct ORT_API(const OrtInteropApi*, OrtApis::GetInteropApi) { return OrtInteropAPI::GetInteropApi(); @@ -3658,6 +4072,29 @@ ORT_API_STATUS_IMPL(OrtApis::GetModelCompatibilityForEpDevices, API_IMPL_END } +// Minimal build stub for GetCompatibilityInfoFromModel +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModel, + _In_ const ORTCHAR_T* /*model_path*/, + _In_ const char* /*ep_type*/, + _Inout_ OrtAllocator* /*allocator*/, + _Outptr_result_maybenull_ char** /*compatibility_info*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetCompatibilityInfoFromModel is not supported in a minimal build."); + API_IMPL_END +} + +// Minimal build stub for GetCompatibilityInfoFromModelBytes +ORT_API_STATUS_IMPL(OrtApis::GetCompatibilityInfoFromModelBytes, + _In_reads_(model_data_length) const void* /*model_data*/, + _In_ size_t /*model_data_length*/, + _In_ const char* /*ep_type*/, + _Inout_ OrtAllocator* /*allocator*/, + _Outptr_result_maybenull_ char** /*compatibility_info*/) { + API_IMPL_BEGIN + return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "GetCompatibilityInfoFromModelBytes is not supported in a minimal build."); + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOptions* /*session_options*/, _In_ OrtEnv* /*env*/, _In_reads_(num_ep_devices) const OrtEpDevice* const* /*ep_devices*/, @@ -4354,8 +4791,19 @@ static constexpr OrtApi ort_api_1_to_24 = { &OrtApis::DeviceEpIncompatibilityDetails_GetNotes, 
&OrtApis::DeviceEpIncompatibilityDetails_GetErrorCode, &OrtApis::ReleaseDeviceEpIncompatibilityDetails, + &OrtApis::GetCompatibilityInfoFromModel, + &OrtApis::GetCompatibilityInfoFromModelBytes, &OrtApis::CreateEnvWithOptions, + &OrtApis::Session_GetEpGraphAssignmentInfo, + &OrtApis::EpAssignedSubgraph_GetEpName, + &OrtApis::EpAssignedSubgraph_GetNodes, + &OrtApis::EpAssignedNode_GetName, + &OrtApis::EpAssignedNode_GetDomain, + &OrtApis::EpAssignedNode_GetOperatorType, + &OrtApis::RunOptionsSetSyncStream, + &OrtApis::GetTensorElementTypeAndShapeDataReference, + // End of Version 24 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -4392,9 +4840,10 @@ static_assert(offsetof(OrtApi, SetEpDynamicOptions) / sizeof(void*) == 284, "Siz static_assert(offsetof(OrtApi, GetEpApi) / sizeof(void*) == 317, "Size of version 22 API cannot change"); static_assert(offsetof(OrtApi, CreateExternalInitializerInfo) / sizeof(void*) == 389, "Size of version 23 API cannot change"); +static_assert(offsetof(OrtApi, GetTensorElementTypeAndShapeDataReference) / sizeof(void*) == 414, "Size of version 24 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.24.0", +static_assert(std::string_view(ORT_VERSION) == "1.24.3", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. 
If there were any APIs added to ort_api_1_to_24 above: diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index a38ee0c1eab11..ab3dd45629777 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -122,6 +122,7 @@ ORT_API_STATUS_IMPL(RunOptionsGetRunTag, _In_ const OrtRunOptions*, _Out_ const ORT_API_STATUS_IMPL(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); ORT_API_STATUS_IMPL(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); +ORT_API(void, RunOptionsSetSyncStream, _Inout_ OrtRunOptions* options, _In_ OrtSyncStream* sync_stream); ORT_API_STATUS_IMPL(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, @@ -659,6 +660,17 @@ ORT_API_STATUS_IMPL(GetModelCompatibilityForEpDevices, _In_ size_t num_ep_devices, _In_ const char* compatibility_info, _Out_ OrtCompiledModelCompatibility* out_status); +ORT_API_STATUS_IMPL(GetCompatibilityInfoFromModel, + _In_ const ORTCHAR_T* model_path, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info); +ORT_API_STATUS_IMPL(GetCompatibilityInfoFromModelBytes, + _In_reads_(model_data_length) const void* model_data, + _In_ size_t model_data_length, + _In_ const char* ep_type, + _Inout_ OrtAllocator* allocator, + _Outptr_result_maybenull_ char** compatibility_info); ORT_API_STATUS_IMPL(Graph_GetModelPath, _In_ const OrtGraph* graph, _Outptr_ const ORTCHAR_T** model_path); ORT_API_STATUS_IMPL(Graph_GetOnnxIRVersion, _In_ const OrtGraph* graph, _Out_ int64_t* onnx_ir_version); ORT_API_STATUS_IMPL(Graph_GetNumOperatorSets, _In_ const OrtGraph* graph, _Out_ size_t* num_operator_sets); @@ -784,4 +796,21 @@ ORT_API_STATUS_IMPL(SessionGetEpDeviceForOutputs, _In_ const OrtSession* session // OrtEnv ORT_API_STATUS_IMPL(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out); + +// APIs 
to get EP graph assignment info +ORT_API_STATUS_IMPL(Session_GetEpGraphAssignmentInfo, _In_ const OrtSession* session, + _Outptr_ const OrtEpAssignedSubgraph* const** ep_subgraphs, + _Out_ size_t* num_ep_subgraphs); +ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetEpName, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const char** out); +ORT_API_STATUS_IMPL(EpAssignedSubgraph_GetNodes, _In_ const OrtEpAssignedSubgraph* ep_subgraph, + _Outptr_ const OrtEpAssignedNode* const** ep_nodes, _Out_ size_t* num_ep_nodes); +ORT_API_STATUS_IMPL(EpAssignedNode_GetName, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); +ORT_API_STATUS_IMPL(EpAssignedNode_GetDomain, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); +ORT_API_STATUS_IMPL(EpAssignedNode_GetOperatorType, _In_ const OrtEpAssignedNode* ep_node, _Outptr_ const char** out); + +ORT_API_STATUS_IMPL(GetTensorElementTypeAndShapeDataReference, _In_ const OrtValue* value, + _Out_ ONNXTensorElementDataType* elem_type, + _Outptr_result_maybenull_ const int64_t** shape_data, + _Out_ size_t* shape_data_count); } // namespace OrtApis diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h index f562ee73f2aaa..01f7bc67a522e 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_internal_impl.h @@ -97,6 +97,18 @@ class EpFactoryInternalImpl { return nullptr; } + virtual OrtStatus* GetNumCustomOpDomains(_Out_ size_t* num_domains) const noexcept { + *num_domains = 0; + return nullptr; + } + + virtual OrtStatus* GetCustomOpDomains(_Out_writes_all_(num_domains) OrtCustomOpDomain** domains, + _In_ size_t num_domains) const noexcept { + ORT_UNUSED_PARAMETER(domains); + ORT_UNUSED_PARAMETER(num_domains); + return nullptr; + } + // Function ORT calls to release an EP instance. 
void ReleaseEp(OrtEp* ep); diff --git a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h index 3a7a1b6504d12..eb1427db87463 100644 --- a/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h +++ b/onnxruntime/core/session/plugin_ep/ep_factory_provider_bridge.h @@ -86,6 +86,21 @@ class ProviderBridgeEpFactory : public EpFactoryInternalImpl { return ep_factory_.GetHardwareDeviceIncompatibilityDetails(&ep_factory_, hw, details); } + OrtStatus* ValidateCompiledModelCompatibilityInfo( + const OrtHardwareDevice* const* devices, + size_t num_devices, + const char* compatibility_info, + OrtCompiledModelCompatibility* model_compatibility) noexcept override { + // Forward to underlying factory if it supports validation + if (ep_factory_.ValidateCompiledModelCompatibilityInfo) { + return ep_factory_.ValidateCompiledModelCompatibilityInfo( + &ep_factory_, devices, num_devices, compatibility_info, model_compatibility); + } + // If not supported, return NOT_APPLICABLE + *model_compatibility = OrtCompiledModelCompatibility_EP_NOT_APPLICABLE; + return nullptr; + } + OrtEpFactory& ep_factory_; ProviderLibrary& provider_library_; std::optional library_path_; diff --git a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc index 4db8bb05f94de..f8cba9435a6bc 100644 --- a/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc +++ b/onnxruntime/core/session/plugin_ep/ep_plugin_provider_interfaces.cc @@ -61,8 +61,7 @@ PluginExecutionProviderFactory::CreateProvider(const OrtSessionOptions& session_ Status status = CreatePluginExecutionProvider(session_options, session_logger, plugin_ep); if (!status.IsOK()) { - LOGS(*session_logger.ToInternal(), ERROR) << "Error creating execution provider: " << status.ToString(); - return nullptr; + ORT_THROW("Error creating execution provider: ", status.ToString()); 
} return plugin_ep; @@ -209,7 +208,7 @@ PluginExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const logging::Logger& logger = GetLogger() != nullptr ? *GetLogger() : logging::LoggingManager::DefaultLogger(); std::unique_ptr ep_graph = nullptr; - if (Status status = EpGraph::Create(graph_viewer, ep_graph); !status.IsOK()) { + if (Status status = EpGraph::Create(graph_viewer, ep_graph, true); !status.IsOK()) { LOGS(logger, ERROR) << "Failed to create OrtGraph for " << Type() << ": " << status.ToString(); return {}; } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index a55ab38113a0f..3dc2df6d78ba1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1002,6 +1002,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetType_Int4x2() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetType_UInt4x2() override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetType_Int2x4() override { return DataTypeImpl::GetType(); } + MLDataType DataTypeImpl__GetType_UInt2x4() override { return DataTypeImpl::GetType(); } MLDataType DataTypeImpl__GetTensorTypeFromOnnxType(int onnx_type) override { return DataTypeImpl::TensorTypeFromONNXEnum(onnx_type)->AsTensorType(); } MLDataType DataTypeImpl__GetTensorType_bool() override { return DataTypeImpl::GetTensorType(); } @@ -1031,6 +1033,8 @@ struct ProviderHostImpl : ProviderHost { MLDataType DataTypeImpl__GetTensorType_Int4x2() override { return DataTypeImpl::GetTensorType(); } MLDataType DataTypeImpl__GetTensorType_UInt4x2() override { return DataTypeImpl::GetTensorType(); } + MLDataType DataTypeImpl__GetTensorType_Int2x4() override { return DataTypeImpl::GetTensorType(); } + MLDataType DataTypeImpl__GetTensorType_UInt2x4() override { return DataTypeImpl::GetTensorType(); } #if !defined(DISABLE_SPARSE_TENSORS) MLDataType 
DataTypeImpl__GetSparseTensorType_bool() override { return DataTypeImpl::GetSparseTensorType(); } @@ -1282,11 +1286,6 @@ struct ProviderHostImpl : ProviderHost { return onnxruntime::utils::HasExternalDataInMemory(ten_proto); } - Status Utils__ValidateExternalDataPath(const std::filesystem::path& base_path, - const std::filesystem::path& location) override { - return onnxruntime::utils::ValidateExternalDataPath(base_path, location); - } - // Model (wrapped) std::unique_ptr Model__construct(ONNX_NAMESPACE::ModelProto&& model_proto, const PathString& model_path, const IOnnxRuntimeOpSchemaRegistryList* local_registries, @@ -1680,6 +1679,8 @@ struct ProviderHostImpl : ProviderHost { Int4x2* Tensor__MutableData_Int4x2(Tensor* p) override { return p->MutableData(); } UInt4x2* Tensor__MutableData_UInt4x2(Tensor* p) override { return p->MutableData(); } + Int2x4* Tensor__MutableData_Int2x4(Tensor* p) override { return p->MutableData(); } + UInt2x4* Tensor__MutableData_UInt2x4(Tensor* p) override { return p->MutableData(); } const bool* Tensor__Data_bool(const Tensor* p) override { return p->Data(); } const int8_t* Tensor__Data_int8(const Tensor* p) override { return p->Data(); } @@ -1708,6 +1709,8 @@ struct ProviderHostImpl : ProviderHost { const Int4x2* Tensor__Data_Int4x2(const Tensor* p) override { return p->Data(); } const UInt4x2* Tensor__Data_UInt4x2(const Tensor* p) override { return p->Data(); } + const Int2x4* Tensor__Data_Int2x4(const Tensor* p) override { return p->Data(); } + const UInt2x4* Tensor__Data_UInt2x4(const Tensor* p) override { return p->Data(); } gsl::span Tensor__DataAsSpan_int64(const Tensor* p) override { return p->DataAsSpan(); } @@ -1744,6 +1747,8 @@ struct ProviderHostImpl : ProviderHost { bool Tensor__IsDataType_Int4x2(const Tensor* p) noexcept override { return p->IsDataType(); } bool Tensor__IsDataType_UInt4x2(const Tensor* p) noexcept override { return p->IsDataType(); } + bool Tensor__IsDataType_Int2x4(const Tensor* p) noexcept override { 
return p->IsDataType(); } + bool Tensor__IsDataType_UInt2x4(const Tensor* p) noexcept override { return p->IsDataType(); } const TensorShape& Tensor__Shape(const Tensor* p) override { return p->Shape(); } void Tensor__Reshape(Tensor* p, const TensorShape& new_shape) override { return p->Reshape(new_shape); } diff --git a/onnxruntime/core/session/utils.cc b/onnxruntime/core/session/utils.cc index 944e83d8cad66..a354cf26368d4 100644 --- a/onnxruntime/core/session/utils.cc +++ b/onnxruntime/core/session/utils.cc @@ -98,6 +98,58 @@ Status TestAutoSelectEPsImpl(const Environment& env, InferenceSession& sess, con return Status::OK(); } + +Status GetCustomOpDomainsFromEpDevice(const OrtEpDevice& ep_device, InlinedVector& domains_out) { + InlinedVector domains{}; + + // Get custom op domain provided by EP factory if any. + // OrtEpFactory::GetNumCustomOpDomains and OrtEpFactory::GetCustomOpDomains were added in ORT 1.24. + OrtEpFactory* ep_factory = ep_device.ep_factory; + if (ep_factory && + ep_factory->ort_version_supported >= 24 && + ep_factory->GetNumCustomOpDomains != nullptr && + ep_factory->GetCustomOpDomains != nullptr) { + size_t num_domains = 0; + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetNumCustomOpDomains(ep_factory, &num_domains))); + + domains.resize(num_domains); + ORT_RETURN_IF_ERROR(ToStatusAndRelease(ep_factory->GetCustomOpDomains(ep_factory, domains.data(), + domains.size()))); + } + + domains_out = std::move(domains); + return Status::OK(); +} + +bool DoesDomainWithNameExist(const std::string& domain_name, gsl::span domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; +} + +bool ShouldAddDomain(const OrtCustomOpDomain* domain_to_add, + gsl::span existing_domains) { + if (!domain_to_add) { + return false; + } + + if (domain_to_add->custom_ops_.size() == 0) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': custom ops is empty."; + return 
false; + } + + if (DoesDomainWithNameExist(domain_to_add->domain_, existing_domains)) { + LOGS_DEFAULT(WARNING) << "Skipping custom op domain '" << domain_to_add->domain_ + << "': domain already exists in session options."; + return false; + } + + return true; +} } // namespace #endif // !defined(ORT_MINIMAL_BUILD) @@ -195,6 +247,31 @@ static OrtStatus* CreateSessionAndLoadModelImpl(_In_ const OrtSessionOptions* op } #endif +#if !defined(ORT_MINIMAL_BUILD) + // Add custom domains for all OrtEpDevice instances to inference session. + // The custom domains should be registered before model load for ORT to validate the custom ops. + if (options != nullptr && + options->provider_factories.empty() && + options->value.ep_selection_policy.enable) { + InlinedVector all_ep_custom_op_domains; + + for (const OrtEpDevice* ep_device : env.GetOrtEpDevices()) { + InlinedVector domains; + ORT_API_RETURN_IF_STATUS_NOT_OK(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, options->custom_op_domains_)) { + all_ep_custom_op_domains.push_back(domain); + } + } + } + + if (!all_ep_custom_op_domains.empty()) { + ORT_API_RETURN_IF_STATUS_NOT_OK(sess->AddCustomOpDomains(all_ep_custom_op_domains)); + } + } +#endif + // Finish load if (load_config_from_model) { #if !defined(ORT_MINIMAL_BUILD) @@ -258,8 +335,10 @@ static Status ValidateCompiledModelCompatibility(InferenceSession& sess) { const auto& registered_provider_types = sess.GetRegisteredProviderTypes(); - // Access the execution providers through the session state (available after Initialize) - const auto& execution_providers = sess.GetSessionState().GetExecutionProviders(); + // Access the execution providers directly from the session. + // This allows validation to run before Initialize() completes, avoiding expensive + // graph transformations for incompatible models. EPs are fully registered at this point. 
+ const auto& execution_providers = sess.GetExecutionProviders(); for (const auto& ep_type : registered_provider_types) { // Construct the full metadata key using the prefix + EP type @@ -378,14 +457,20 @@ OrtStatus* InitializeSession(_In_ const OrtSessionOptions* options, reinterpret_cast(prepacked_weights_container))); } - ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); - #if !defined(ORT_MINIMAL_BUILD) - // Validate compiled model compatibility for all registered execution providers - // This must be done after Initialize() so the session state is available + // Validate compiled model compatibility for all registered execution providers BEFORE Initialize(). + // This is an optimization to fail fast for incompatible models, avoiding expensive graph transformations, + // partitioning, and kernel binding that occur during Initialize(). + // This is safe because: + // 1. Model metadata (containing compatibility strings) is available after Load() completes. + // 2. Compiling EPs are fully registered at this point. + // 3. Non-compiling EPs (like CPU EP, which may be implicitly added during Initialize()) don't participate + // in compatibility validation - they return NOT_APPLICABLE by default. ORT_API_RETURN_IF_STATUS_NOT_OK(ValidateCompiledModelCompatibility(sess)); #endif // !defined(ORT_MINIMAL_BUILD) + ORT_API_RETURN_IF_STATUS_NOT_OK(sess.Initialize()); + return nullptr; } @@ -546,5 +631,22 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic return Status::OK(); } + +Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options) { + for (const OrtEpDevice* ep_device : ep_devices) { + // Add custom domains if EP factory has any. 
+ InlinedVector domains; + ORT_RETURN_IF_ERROR(GetCustomOpDomainsFromEpDevice(*ep_device, domains)); + + for (auto domain : domains) { + if (ShouldAddDomain(domain, ort_session_options.custom_op_domains_)) { + ort_session_options.custom_op_domains_.push_back(domain); + } + } + } + + return Status::OK(); +} #endif // !defined(ORT_MINIMAL_BUILD) } // namespace onnxruntime diff --git a/onnxruntime/core/session/utils.h b/onnxruntime/core/session/utils.h index 2ccd4d464a261..59b4d9f0944c3 100644 --- a/onnxruntime/core/session/utils.h +++ b/onnxruntime/core/session/utils.h @@ -71,5 +71,9 @@ Status AddEpOptionsToSessionOptions(gsl::span ep_devic gsl::span ep_options_vals, SessionOptions& session_options); +// Adss EP specific custom domains to the OrtSessionOptions configuration. +Status AddEpCustomDomainsToSessionOptions(gsl::span ep_devices, + OrtSessionOptions& ort_session_options); + } // namespace onnxruntime #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/core/util/qmath.h b/onnxruntime/core/util/qmath.h index c0bb69eb732be..6abe3e7f5996f 100644 --- a/onnxruntime/core/util/qmath.h +++ b/onnxruntime/core/util/qmath.h @@ -229,6 +229,93 @@ ParQuantizeLinearStd(const float* Input, DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdS4, Int4x2, MlasQuantizeLinearS4) DEFINE_PAR_QUANT_LINEAR_STD_4BIT(ParQuantizeLinearStdU4, UInt4x2, MlasQuantizeLinearU4) +// TODO: add MLAS kernels for 2-bit types and generalize DEFINE_PAR_QUANT_LINEAR_STD_4BIT macro +// For 2-bit types, we need a generic implementation since MLAS kernels don't support 2-bit yet. +// Define a generic quantization function that doesn't rely on MLAS. 
+#define DEFINE_PAR_QUANT_LINEAR_STD_2BIT_GENERIC(FUNC_NAME, SUB_BYTE_TYPE) \ + inline void FUNC_NAME(const float* Input, \ + SUB_BYTE_TYPE* Output, \ + size_t out_start, \ + size_t out_end, \ + float Scale, \ + SUB_BYTE_TYPE ZeroPoint, \ + concurrency::ThreadPool* thread_pool) { \ + constexpr int32_t low = static_cast(SUB_BYTE_TYPE::min_val); \ + constexpr int32_t high = static_cast(SUB_BYTE_TYPE::max_val); \ + const int32_t zp = static_cast(ZeroPoint.GetElem(0)); \ + size_t inp_start = 0; \ + size_t inp_end = out_end - out_start; \ + \ + /* If starting at a 2-bit element not at the start of a byte, quantize those elements by themselves. */ \ + /* For 2-bit: 4 elements per byte, so check if out_start % 4 != 0 */ \ + size_t start_offset = out_start & 0x3; \ + if (start_offset != 0) { \ + size_t output_index = out_start >> 2; \ + size_t num_boundary = 4 - start_offset; /* Number of elements until byte boundary */ \ + num_boundary = std::min(num_boundary, inp_end - inp_start); \ + for (size_t i = 0; i < num_boundary; ++i) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_start + i] / Scale)) + zp; \ + SUB_BYTE_TYPE::UnpackedType quant_val = \ + static_cast(std::min(high, std::max(low, ival))); \ + Output[output_index].SetElem((start_offset + i) & 0x3, quant_val); \ + } \ + out_start += num_boundary; \ + inp_start += num_boundary; \ + } \ + \ + /* If ending at a 2-bit element not at the end of a byte, quantize those elements by themselves. 
*/ \ + size_t end_offset = out_end & 0x3; \ + if (end_offset != 0) { \ + size_t output_index = (out_end - end_offset) >> 2; \ + size_t num_boundary = end_offset; \ + for (size_t i = 0; i < num_boundary; ++i) { \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_end - num_boundary + i] / Scale)) + zp; \ + SUB_BYTE_TYPE::UnpackedType quant_val = \ + static_cast(std::min(high, std::max(low, ival))); \ + Output[output_index].SetElem(i, quant_val); \ + } \ + out_end -= num_boundary; \ + inp_end -= num_boundary; \ + } \ + \ + if (out_start == out_end) { \ + return; \ + } \ + \ + /* At this point, should only need to quantize a number of 2-bit elements that are multiples of 4 */ \ + /* and start/end at byte boundaries. This ensures no two threads write to the same byte. */ \ + size_t N = out_end - out_start; \ + assert(N % 4 == 0); /* Should be guaranteed by previous code that quantizes boundary elements. */ \ + \ + constexpr std::ptrdiff_t block_size = 128; \ + static_assert(block_size % 4 == 0, \ + "Block size must be a multiple of 4 to ensure no two threads write to the same byte."); \ + \ + const std::ptrdiff_t num_blocks = (N + block_size - 1) / block_size; \ + const TensorOpCost unit_cost{static_cast(block_size * sizeof(float)), \ + static_cast(block_size * sizeof(SUB_BYTE_TYPE::UnpackedType)) / 4.0, \ + static_cast(block_size) * 2.0}; \ + \ + concurrency::ThreadPool::TryParallelFor( \ + thread_pool, num_blocks, unit_cost, \ + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { \ + auto begin_idx = begin * block_size; \ + auto end_idx = std::min(static_cast(N), end * block_size); \ + \ + for (auto idx = begin_idx; idx < end_idx; ++idx) { \ + size_t inp_idx = inp_start + idx; \ + size_t out_idx = out_start + idx; \ + int32_t ival = static_cast(std::nearbyintf(Input[inp_idx] / Scale)) + zp; \ + SUB_BYTE_TYPE::UnpackedType quant_val = \ + static_cast(std::min(high, std::max(low, ival))); \ + Output[out_idx >> 2].SetElem(out_idx & 0x3, quant_val); \ + } \ + }); \ + } + 
+DEFINE_PAR_QUANT_LINEAR_STD_2BIT_GENERIC(ParQuantizeLinearStdS2, Int2x4) +DEFINE_PAR_QUANT_LINEAR_STD_2BIT_GENERIC(ParQuantizeLinearStdU2, UInt2x4) + // This implementation could be more efficient however the cast from float16 to other types // usually happens on GPU. template @@ -840,6 +927,195 @@ struct BlockedQuantizeLinear { } }; +// Template specializations for 2-bit types (group 3: Int2x4, UInt2x4) +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + ORT_UNUSED_PARAMETER(thread_block_size); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + // to avoid a byte being written from multiple threads, use 4 * N as thread block (4 elements per byte for 2-bit) + auto size_thread_block = 4 * N; + auto num_thread_block = (M * K + 3) / 4; + auto num_quant_block_K = (K + quant_block_size - 1) / quant_block_size; + auto num_quant_block_KN = num_quant_block_K * N; + auto MK = M * K; + const TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float) * 2), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + begin <<= 2, end = std::min(end << 2, MK); + auto output_idx = begin * N; + auto m = begin / K, k = begin % K; + auto zp_idx = m * num_quant_block_KN + k / quant_block_size * N; + + for (; begin < end; ++begin) { + auto zp_idx_t = zp_idx; + // auto output_idx_end = output_idx + N; + + for (; zp_idx_t < zp_idx + N; ++zp_idx_t, ++output_idx) { + auto zp = zero_point + ? 
static_cast(zero_point[zp_idx_t >> 2].GetElem(zp_idx_t & 0x3)) + : 0; + auto sc = scale[zp_idx_t]; + auto v = std::clamp(static_cast(std::nearbyint(input[output_idx] / sc)) + zp, low, high); + output[output_idx >> 2].SetElem(output_idx & 0x3, static_cast(v)); + } + + ++k; + if (k == K) { + k = 0; + zp_idx += N; + } else if (k % quant_block_size == 0) { + zp_idx += N; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const float* input, const float* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + // to avoid a byte being written from multiple threads, use 4 * K as thread block (4 elements per byte for 2-bit) + auto size_thread_block = 4 * K; + auto quant_block_num_K = (K + quant_block_size - 1) / quant_block_size; + auto num_thread_block = (M + 3) / 4; + TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(float)), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + begin <<= 2, end = std::min(end << 2, M); + auto output_idx = begin * K; + auto zp_idx = begin * quant_block_num_K; + + for (; begin < end; ++begin, output_idx += K) { + auto output_row_idx_start = output_idx; + auto output_row_idx_end = output_row_idx_start + K; + + for (; output_row_idx_start < output_row_idx_end; output_row_idx_start += quant_block_size, ++zp_idx) { + auto zp = zero_point ? 
static_cast(zero_point[zp_idx >> 2].GetElem(zp_idx & 0x3)) : 0; + auto sc = scale[zp_idx]; + + for (auto idx = output_row_idx_start; idx < std::min(output_row_idx_start + quant_block_size, output_row_idx_end); ++idx) { + auto v = std::clamp(static_cast(std::nearbyint(input[idx] / sc)) + zp, low, high); + output[idx >> 2].SetElem(idx & 0x3, static_cast(v)); + } + } + } + }); + } +}; + +template +struct BlockedQuantizeLinear { + static void opNotLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + std::ptrdiff_t N, const std::ptrdiff_t quant_block_size, + const std::ptrdiff_t thread_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + ORT_UNUSED_PARAMETER(thread_block_size); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + auto size_thread_block = 4 * N; + auto num_thread_block = (M * K + 3) / 4; + auto num_quant_block_K = (K + quant_block_size - 1) / quant_block_size; + auto num_quant_block_KN = num_quant_block_K * N; + auto MK = M * K; + const TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(MLFloat16) * 2), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; + + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + begin <<= 2, end = std::min(end << 2, MK); + auto output_idx = begin * N; + auto m = begin / K, k = begin % K; + auto zp_idx = m * num_quant_block_KN + k / quant_block_size * N; + + for (; begin < end; ++begin) { + auto zp_idx_t = zp_idx; + // auto output_idx_end = output_idx + N; + + for (; zp_idx_t < zp_idx + N; ++zp_idx_t, ++output_idx) { + auto zp = zero_point + ? 
static_cast(zero_point[zp_idx_t >> 2].GetElem(zp_idx_t & 0x3)) + : 0; + auto sc = scale[zp_idx_t].ToFloat(); + auto v = std::clamp( + static_cast(std::nearbyint(input[output_idx].ToFloat() / sc)) + zp, low, high); + output[output_idx >> 2].SetElem(output_idx & 0x3, static_cast(v)); + } + + ++k; + if (k == K) { + k = 0; + zp_idx += N; + } else if (k % quant_block_size == 0) { + zp_idx += N; + } + } + }); + } + + static void opLastAxis(concurrency::ThreadPool* thread_pool, const MLFloat16* input, const MLFloat16* scale, + const TOut* zero_point, TOut* output, std::ptrdiff_t M, std::ptrdiff_t K, + const std::ptrdiff_t quant_block_size, bool saturate) { + ORT_UNUSED_PARAMETER(saturate); + constexpr auto low = static_cast(TOut::min_val); + constexpr auto high = static_cast(TOut::max_val); + auto size_thread_block = 4 * K; + auto quant_block_num_K = (K + quant_block_size - 1) / quant_block_size; + auto num_thread_block = (M + 3) / 4; + TensorOpCost unit_cost{static_cast(size_thread_block * sizeof(MLFloat16)), + static_cast(size_thread_block * sizeof(typename TOut::UnpackedType)), + static_cast(size_thread_block) * 2.0}; + concurrency::ThreadPool::TryParallelFor( + thread_pool, + num_thread_block, + unit_cost, + [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + begin <<= 2, end = std::min(end << 2, M); + auto output_idx = begin * K; + auto zp_idx = begin * quant_block_num_K; + + for (; begin < end; ++begin, output_idx += K) { + auto output_row_idx_start = output_idx; + auto output_row_idx_end = output_row_idx_start + K; + + for (; output_row_idx_start < output_row_idx_end; output_row_idx_start += quant_block_size, ++zp_idx) { + auto zp = zero_point ? 
static_cast(zero_point[zp_idx >> 2].GetElem(zp_idx & 0x3)) : 0; + auto sc = scale[zp_idx].ToFloat(); + + for (auto idx = output_row_idx_start; idx < std::min(output_row_idx_start + quant_block_size, output_row_idx_end); ++idx) { + auto v = std::clamp( + static_cast(std::nearbyint(input[idx].ToFloat() / sc)) + zp, low, high); + output[idx >> 2].SetElem(idx & 0x3, static_cast(v)); + } + } + } + }); + } +}; + #if !defined(DISABLE_FLOAT8_TYPES) template diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 7b4f130cc2b93..1aa28cfd45873 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -219,6 +219,15 @@ def get_provider_options(self): "Return registered execution providers' configurations." return self._provider_options + def get_provider_graph_assignment_info(self) -> Sequence[onnxruntime.OrtEpAssignedSubgraph]: + """ + Get information about the subgraphs assigned to each execution provider and the nodes within. + + Application must enable the recording of graph assignment information by setting the session configuration + for the key "session.record_ep_graph_assignment_info" to "1". + """ + return self._sess.get_provider_graph_assignment_info() + def set_providers(self, providers=None, provider_options=None) -> None: """ Register the input list of execution providers. The underlying session is re-created. @@ -512,8 +521,25 @@ def __init__( def _create_inference_session(self, providers, provider_options, disabled_optimizers=None): available_providers = C.get_available_providers() - # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. 
- if "TensorrtExecutionProvider" in available_providers: + # Validate that TensorrtExecutionProvider and NvTensorRTRTXExecutionProvider are not both specified + if providers: + has_tensorrt = any( + provider == "TensorrtExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") + for provider in providers + ) + has_tensorrt_rtx = any( + provider == "NvTensorRTRTXExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") + for provider in providers + ) + if has_tensorrt and has_tensorrt_rtx: + raise ValueError( + "Cannot enable both 'TensorrtExecutionProvider' and 'NvTensorRTRTXExecutionProvider' " + "in the same session." + ) + # Tensorrt and TensorRT RTX can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. + if "NvTensorRTRTXExecutionProvider" in available_providers: if ( providers and any( @@ -522,15 +548,15 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi for provider in providers ) and any( - provider == "TensorrtExecutionProvider" - or (isinstance(provider, tuple) and provider[0] == "TensorrtExecutionProvider") + provider == "NvTensorRTRTXExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == "NvTensorRTRTXExecutionProvider") for provider in providers ) ): self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: self._fallback_providers = ["CPUExecutionProvider"] - if "NvTensorRTRTXExecutionProvider" in available_providers: + elif "TensorrtExecutionProvider" in available_providers: if ( providers and any( @@ -539,8 +565,8 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi for provider in providers ) and any( - provider == "NvTensorRTRTXExecutionProvider" - or (isinstance(provider, tuple) and provider[0] == "NvExecutionProvider") + provider == "TensorrtExecutionProvider" + or (isinstance(provider, tuple) and provider[0] == 
"TensorrtExecutionProvider") for provider in providers ) ): diff --git a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc index d96d229c942cb..89651c2d955de 100644 --- a/onnxruntime/python/onnxruntime_pybind_mlvalue.cc +++ b/onnxruntime/python/onnxruntime_pybind_mlvalue.cc @@ -794,7 +794,7 @@ std::string _get_type_name(std::string&) { #if !defined(DISABLE_ML_OPS) template static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string& name_input, PyObject*& value, - PyObject* item, std::map& current, + PyObject* item, bool owns_item_ref, std::map& current, KeyGetterType keyGetter, ValueGetterType valueGetter) { KeyType ckey; ValueType cvalue; @@ -806,7 +806,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string sType = spyType; Py_XDECREF(pStr); Py_XDECREF(pType); - Py_XDECREF(item); + if (owns_item_ref) { + Py_XDECREF(item); + } throw std::runtime_error(std::string("Unexpected key type ") + sType + std::string(", it cannot be linked to C type ") + _get_type_name(ckey) + std::string(" for input '") + @@ -820,7 +822,9 @@ static void CreateMapMLValue_LoopIntoMap(Py_ssize_t& pos, PyObject*& key, const std::string sType = spyType; Py_XDECREF(pStr); Py_XDECREF(pType); - Py_XDECREF(item); + if (owns_item_ref) { + Py_XDECREF(item); + } throw std::runtime_error(std::string("Unexpected value type ") + sType + std::string(", it cannot be linked to C type ") + _get_type_name(ckey) + std::string(" for input '") + @@ -836,7 +840,7 @@ static void CreateMapMLValue_Map(Py_ssize_t& pos, PyObject*& key, const std::str ValueGetterType valueGetter) { std::unique_ptr> dst; dst = std::make_unique>(); - CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, *dst, keyGetter, valueGetter); + CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, false, *dst, keyGetter, valueGetter); p_mlvalue->Init(dst.release(), DataTypeImpl::GetType>(), 
DataTypeImpl::GetType>()->GetDeleteFunc()); } @@ -850,7 +854,7 @@ void CreateMapMLValue_VectorMap(Py_ssize_t& pos, PyObject*& key, const std::stri int index = 0; do { dstVector->push_back(std::map()); - CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, (*dstVector)[index], keyGetter, valueGetter); + CreateMapMLValue_LoopIntoMap(pos, key, name_input, value, item, true, (*dstVector)[index], keyGetter, valueGetter); Py_DECREF(item); ++index; item = iterator == NULL ? NULL : PyIter_Next(iterator); diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 5a72ecb6849c3..fe30cc3f51d85 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -66,12 +66,12 @@ void QuantizeMatMulNBitsBlockwise( tp.get()); } -template -bool QuantizeQDQMatMul4BitsBlockwise( - py::array_t dst, // shape: [K, N / 2] - py::array_t src, // shape: [K, N] - py::array_t scale, // shape: [block_per_K, N] - py::array_t zero_points, // shape: [block_per_K, N / 2] +template +bool QuantizeQDQMatMulNBitsBlockwise( + py::array_t dst, + py::array_t src, + py::array_t scale, + py::array_t zero_points, int32_t quant_block_size, int32_t N, int32_t K, @@ -85,7 +85,7 @@ bool QuantizeQDQMatMul4BitsBlockwise( py::buffer_info scale_buf = scale.request(); py::buffer_info zp_buf = zero_points.request(); - return MlasQDQQuantizeBlockwise( + return MlasQDQQuantizeBlockwise( reinterpret_cast(src_buf.ptr), reinterpret_cast(scale_buf.ptr), is_symmetric ? 
nullptr : reinterpret_cast(zp_buf.ptr), @@ -97,6 +97,19 @@ bool QuantizeQDQMatMul4BitsBlockwise( tp.get()); } +template +bool QuantizeQDQMatMul4BitsBlockwise( + py::array_t dst, // shape: [K, N / 2] + py::array_t src, // shape: [K, N] + py::array_t scale, // shape: [block_per_K, N] + py::array_t zero_points, // shape: [block_per_K, N / 2] + int32_t quant_block_size, + int32_t N, + int32_t K, + bool is_symmetric) { + return QuantizeQDQMatMulNBitsBlockwise(dst, src, scale, zero_points, quant_block_size, N, K, is_symmetric); +} + template void QuantizeMatMulBnb4Blockwise( py::array_t dst, @@ -134,6 +147,8 @@ void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_8bits", &QuantizeMatMulNBitsBlockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_qdq_matmul_2bits", &QuantizeQDQMatMulNBitsBlockwise); + m.def("quantize_qdq_matmul_2bits", &QuantizeQDQMatMulNBitsBlockwise); m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); m.def("quantize_qdq_matmul_4bits", &QuantizeQDQMatMul4BitsBlockwise); } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f0d8906d99c14..39f2988a89b2f 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -1348,6 +1348,9 @@ static Status AddEpFactoryFromEpDevices(PySessionOptions& py_sess_options, ep_option_vals, py_sess_options.value)); + ORT_RETURN_IF_ERROR(AddEpCustomDomainsToSessionOptions(ep_devices, + py_sess_options)); + py_sess_options.provider_factories.push_back(std::move(provider_factory)); return Status::OK(); } @@ -1892,6 +1895,47 @@ for model inference.)pbdoc"); }, R"pbdoc(The OrtSyncStream instance for the OrtEpDevice.)pbdoc"); + py::class_ py_ep_node(m, "OrtEpAssignedNode", + R"pbdoc(Contains information about a node assigned to an execution +provider)pbdoc"); + py_ep_node + 
.def_property_readonly( + "name", + [](const OrtEpAssignedNode* ep_node) -> std::string { + return ep_node->name; + }, + R"pbdoc(The node's name)pbdoc") + .def_property_readonly( + "domain", + [](const OrtEpAssignedNode* ep_node) -> std::string { + return ep_node->domain; + }, + R"pbdoc(The node's domain)pbdoc") + .def_property_readonly( + "op_type", + [](const OrtEpAssignedNode* ep_node) -> std::string { + return ep_node->op_type; + }, + R"pbdoc(The node's operator type)pbdoc"); + + py::class_ py_ep_subgraph(m, "OrtEpAssignedSubgraph", + R"pbdoc(Contains information about a subgraph assigned to an +execution provider)pbdoc"); + py_ep_subgraph + .def_property_readonly( + "ep_name", + [](const OrtEpAssignedSubgraph* ep_subgraph) -> std::string { + return ep_subgraph->ep_name; + }, + R"pbdoc(The name of the execution provider to which this subgraph is assigned.)pbdoc") + .def( + "get_nodes", + [](const OrtEpAssignedSubgraph* ep_subgraph) -> const std::vector& { + return ep_subgraph->nodes; + }, + py::return_value_policy::reference_internal, + R"pbdoc(List of nodes in the subgraph.)pbdoc"); + py::class_ ort_arena_cfg_binding(m, "OrtArenaCfg"); // Note: Doesn't expose initial_growth_chunk_sizes_bytes/max_power_of_two_extend_bytes option. 
// This constructor kept for backwards compatibility, key-value pair constructor overload exposes all options @@ -2677,6 +2721,25 @@ including arg name, arg type (contains both type and shape).)pbdoc") }) .def("get_providers", [](const PyInferenceSession* sess) -> const std::vector& { return sess->GetSessionHandle()->GetRegisteredProviderTypes(); }, py::return_value_policy::reference_internal) .def("get_provider_options", [](const PyInferenceSession* sess) -> const ProviderOptionsMap& { return sess->GetSessionHandle()->GetAllProviderOptions(); }, py::return_value_policy::reference_internal) + .def("get_provider_graph_assignment_info", [](const PyInferenceSession* sess) -> const std::vector& { +#if !defined(ORT_MINIMAL_BUILD) + const auto* inference_session = sess->GetSessionHandle(); + const auto& sess_options = inference_session->GetSessionOptions(); + bool is_enabled = + sess_options.config_options.GetConfigOrDefault(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "0") == "1"; + + if (!is_enabled) { + OrtPybindThrowIfError(ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, !is_enabled, "Session configuration entry '", + kOrtSessionOptionsRecordEpGraphAssignmentInfo, + "' must be set to \"1\" to retrieve EP graph assignment information.")); + } + return inference_session->GetEpGraphAssignmentInfo(); +#else + ORT_UNUSED_PARAMETER(sess); + ORT_THROW("EP graph assignment information is not supported in this build"); +#endif + }, + py::return_value_policy::reference_internal, R"pbdoc(Returns information on the subgraph/nodes assigned to execution providers in the session.)pbdoc") .def_property_readonly("session_options", [](const PyInferenceSession* sess) -> PySessionOptions* { auto session_options = std::make_unique(); session_options->value = sess->GetSessionHandle()->GetSessionOptions(); diff --git a/onnxruntime/python/tools/pytorch_export_contrib_ops.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py index 1c5e31af99d82..0bd75e5c92e4c 100644 --- 
a/onnxruntime/python/tools/pytorch_export_contrib_ops.py +++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py @@ -6,6 +6,7 @@ PyTorch-ONNX exporter (torch.onnx.export). """ +import contextlib import typing try: @@ -22,7 +23,7 @@ _registered_ops: typing.AbstractSet[str] = set() -def _reg(symbolic_fn: typing.Callable, namespace: str = ""): +def _reg(symbolic_fn: typing.Callable, namespace: str = "aten"): name = f"{namespace}::{symbolic_fn.__name__}" torch.onnx.register_custom_op_symbolic(name, symbolic_fn, _OPSET_VERSION) _registered_ops.add(name) @@ -49,13 +50,6 @@ def grid_sampler(g, input, grid, mode, padding_mode, align_corners): padding_mode_str = ["zeros", "border", "reflection"][padding_mode] align_corners = int(symbolic_helper._maybe_get_const(align_corners, "b")) - # From opset v13 onward, the output shape can be specified with - # (N, C, H, W) (N, H_out, W_out, 2) => (N, C, H_out, W_out) - # input_shape = input.type().sizes() - # gird_shape = grid.type().sizes() - # output_shape = input_shape[:2] + gird_shape[1:3] - # g.op(...).setType(input.type().with_sizes(output_shape)) - return g.op( "com.microsoft::GridSample", input, @@ -71,15 +65,24 @@ def inverse(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type()) _reg(inverse) + torch.onnx.register_custom_op_symbolic("aten::linalg_inv", inverse, _OPSET_VERSION) + _registered_ops.add("aten::linalg_inv") + + def gelu(g, self: torch._C.Value, approximate="none"): + # PyTorch can emit aten::gelu with or without the optional approximate arg. + if not isinstance(approximate, str): + approximate = symbolic_helper._maybe_get_const(approximate, "s") - @torch.onnx.symbolic_helper.parse_args("v", "s") - def gelu(g, self: torch._C.Value, approximate: str = "none"): - # Use microsoft::Gelu for performance if possible. It only supports approximate == "none" + # Use microsoft::Gelu for performance if possible. It only supports approximate == "none". 
if approximate == "none": return g.op("com.microsoft::Gelu", self).setType(self.type()) return torch.onnx.symbolic_opset9.gelu(g, self, approximate) _reg(gelu) + # Some PyTorch versions dispatch GELU symbolic lookup by exporter opset. + # Registering across stable opsets keeps ORT Gelu fusion consistently enabled. + for opset in range(9, 21): + torch.onnx.register_custom_op_symbolic("aten::gelu", gelu, opset) def triu(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type()) @@ -127,3 +130,8 @@ def unregister(): for version in symbolic_helper._onnx_stable_opsets: if version >= _OPSET_VERSION and symbolic_registry.is_registered_op(kind, namespace, version): del symbolic_registry._registry[(namespace, version)][kind] + + # Also clean up gelu's multi-opset registrations (see register()). + for opset in range(9, 21): + with contextlib.suppress(Exception): + torch.onnx.unregister_custom_op_symbolic("aten::gelu", opset) diff --git a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py index 7d58c1c180822..a28d4a32778fc 100644 --- a/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py +++ b/onnxruntime/python/tools/quantization/fusions/fusion_layernorm.py @@ -33,6 +33,16 @@ def fuse( | | +-------------------------------------------------+ + Or, using Mul instead of Pow: + + +----------------------+ + | | + | v + [Root] --> ReduceMean --> Sub --> Mul --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add + (axis=2 or -1) | (in0=in1) (axis=2 or -1) (E-6 or E-12 or 0) ^ + | | + +-------------------------------------------------+ + It also handles cases of duplicated sub nodes exported from older version of PyTorch: +----------------------+ @@ -40,7 +50,7 @@ def fuse( | +-------> Sub-----------------------------------------------+ | | | | | v - [Root] --> ReduceMean --> Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add 
+ [Root] --> ReduceMean --> Sub --> (Pow or Mul) --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add | ^ | | +----------------------+ @@ -70,10 +80,9 @@ def fuse( div_node, [ (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), - ( - ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], - [1, 0, 0, 0, 0, 0], - ), + (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]), + (["Sqrt", "Add", "ReduceMean", "Mul", "Sub"], [1, 0, 0, 0, 0]), + (["Sqrt", "Add", "ReduceMean", "Mul", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]), ], output_name_to_node, ) @@ -90,8 +99,10 @@ def fuse( # Skip fusion since epsilon value is not expected. return - pow_node = parent_nodes[3] - if self.find_constant_input(pow_node, 2.0) != 1: + pow_or_mul_node = parent_nodes[3] + if pow_or_mul_node.op_type == "Pow" and self.find_constant_input(pow_or_mul_node, 2.0) != 1: + return + elif pow_or_mul_node.op_type == "Mul" and pow_or_mul_node.input[0] != pow_or_mul_node.input[1]: return mul_node = input_name_to_nodes[div_node.output[0]][0] diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index 66ab0c44f8814..2017cf154f21e 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -12,7 +12,6 @@ import timeit from datetime import datetime -import coloredlogs import numpy as np from perf_utils import ( acl, @@ -2259,12 +2258,13 @@ def parse_arguments(): def setup_logger(verbose): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + level=logging.DEBUG, + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + force=True, ) else: - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s", force=True) logging.getLogger("transformers").setLevel(logging.WARNING) diff --git 
a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py index 204fe61396663..7bfe25b1549cf 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py @@ -11,7 +11,6 @@ import pprint import re -import coloredlogs # noqa: F401 from benchmark import * # noqa: F403 from perf_utils import * # noqa: F403 diff --git a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py index 8d2f4b07b7984..4b83e1a8fc41f 100644 --- a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py +++ b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py @@ -5,8 +5,6 @@ import subprocess import sys -import coloredlogs # noqa: F401 - debug = False debug_verbose = False diff --git a/onnxruntime/python/tools/tensorrt/perf/requirements.txt b/onnxruntime/python/tools/tensorrt/perf/requirements.txt index 0afbf47e88307..2a4b319cfc57e 100644 --- a/onnxruntime/python/tools/tensorrt/perf/requirements.txt +++ b/onnxruntime/python/tools/tensorrt/perf/requirements.txt @@ -1,4 +1,3 @@ onnxconverter-common onnxmltools pandas -coloredlogs \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 8055e5e4ae876..56b670e8f2306 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -18,7 +18,6 @@ from time import sleep from typing import Any -import coloredlogs import numpy import torch import transformers @@ -147,12 +146,12 @@ def create_onnxruntime_session( def setup_logger(verbose=True): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + level=logging.DEBUG, ) else: - 
coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s", level=logging.INFO) logging.getLogger("transformers").setLevel(logging.WARNING) diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 9a6388b3f350d..d8177fcd3cb02 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -7,7 +7,6 @@ import logging import os -import coloredlogs from constants import ( AttentionInputIDs, AttentionOutputIDs, @@ -358,12 +357,12 @@ def _parse_arguments(): def _setup_logger(verbose): if verbose: - coloredlogs.install( - level="DEBUG", - fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + logging.basicConfig( + format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + level=logging.DEBUG, ) else: - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) def main(): diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 08f8691d8b2b5..de7f0a044c118 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -1112,11 +1112,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if ( (mul_val is None) or not (isinstance(mul_val, np.ndarray) and mul_val.size == 1) - or (float(mul_val) >= 0) + or (mul_val.item() >= 0) ): return - if float(mul_val) != -10000: - self.mask_filter_value = float(mul_val) + if mul_val.item() != -10000: + self.mask_filter_value = mul_val.item() if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input: mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) if not is_no_mask_attention else None diff --git 
a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py index 29829a6c475d9..f4d9e28d4ecb2 100644 --- a/onnxruntime/python/tools/transformers/large_model_exporter.py +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -290,6 +290,7 @@ def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tupl input_names=onnx_inp_names, output_names=onnx_out_names, dynamic_axes=onnx_dynamic_axes, + dynamo=False, ) onnx_path.unlink(missing_ok=True) diff --git a/onnxruntime/python/tools/transformers/machine_info.py b/onnxruntime/python/tools/transformers/machine_info.py index 55f71278dd458..f5c7a03fae91c 100644 --- a/onnxruntime/python/tools/transformers/machine_info.py +++ b/onnxruntime/python/tools/transformers/machine_info.py @@ -6,6 +6,7 @@ # It is used to dump machine information for Notebooks import argparse +import importlib.metadata import json import logging import platform @@ -122,10 +123,7 @@ def get_gpu_info_by_nvml(self) -> dict: return result def get_related_packages(self) -> list[str]: - import pkg_resources # noqa: PLC0415 - - installed_packages = pkg_resources.working_set - related_packages = [ + related_packages = { "onnxruntime-gpu", "onnxruntime", "onnx", @@ -137,8 +135,12 @@ def get_related_packages(self) -> list[str]: "flatbuffers", "numpy", "onnxconverter-common", - ] - related_packages_list = {i.key: i.version for i in installed_packages if i.key in related_packages} + } + related_packages_list = {} + for dist in importlib.metadata.distributions(): + if dist.metadata["Name"].lower() in related_packages: + related_packages_list[dist.metadata["Name"].lower()] = dist.version + return related_packages_list def get_onnxruntime_info(self) -> dict: diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index a4015f50fdc13..841421a353b07 100644 --- 
a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -21,6 +21,7 @@ import os import shutil import sys +import warnings from pathlib import Path import numpy @@ -243,6 +244,13 @@ def get_latency_name(batch_size, sequence_length, past_sequence_length): def main(argv=None, experiment_name: str = "", run_id: str = "0", csv_filename: str = "gpt2_parity_results.csv"): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) + result = {} if version.parse(transformers_version) < version.parse( "3.1.0" diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py index b405c19b04689..0b86d5f038cd8 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py @@ -473,7 +473,7 @@ def export_onnx( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=11, + opset_version=14, do_constant_folding=True, use_external_data_format=True, verbose=verbose, diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index cd8a8756d681e..eccfb46582fbc 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. 
Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Contents - [LLaMA-2](#llama-2) - [Prerequisites](#prerequisites) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index 6411dca00b5de..17a4ef58914d6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -12,6 +12,7 @@ import subprocess import sys import tempfile +import warnings from itertools import chain import onnx @@ -234,6 +235,7 @@ def run_torchscript_separate_export( opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, + dynamo=False, ) # Check decoder_model.onnx and save all external data to one file @@ -293,6 +295,7 @@ def run_torchscript_separate_export( opset_version=torch_export_onnx_opset_version, do_constant_folding=True, verbose=args.verbose, + dynamo=False, ) # Check decoder_with_past_model.onnx and save all external data to one file @@ -801,6 +804,12 @@ def get_args(): def main(): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) if version.parse(torch.__version__) < version.parse("2.2.0"): logger.error(f"Detected PyTorch version {torch.__version__}. 
Please upgrade and use v2.2.0 or newer.") return diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index 21848deaf99fe..674dc831d70f9 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -11,7 +11,7 @@ # conda create -n gpu_env python=3.8 # conda activate gpu_env # pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113 -# pip3 install onnx transformers onnxruntime-gpu numpy sympy coloredlogs psutil py3nvml +# pip3 install onnx transformers onnxruntime-gpu numpy sympy psutil py3nvml # python benchmark_longformer.py # # When there is no parameter, pre-defined tests will run on the longformer-base-4096 model. diff --git a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py index b80feec892994..513a115352556 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py @@ -18,7 +18,7 @@ # conda create -n longformer python=3.8 # conda activate longformer # python3 -m pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html -# python3 -m pip install coloredlogs flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0 +# python3 -m pip install flatbuffers numpy packaging sympy protobuf==3.20.1 onnx==1.12.0 transformers==4.18.0 # python3 -m pip install -i https://test.pypi.org/simple/ ort-nightly-gpu # cd ./torch_extensions # rm -rf build diff --git a/onnxruntime/python/tools/transformers/models/phi2/README.md b/onnxruntime/python/tools/transformers/models/phi2/README.md index 
da62bba0f02fb..eab31680e64c7 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/README.md +++ b/onnxruntime/python/tools/transformers/models/phi2/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Phi2 Optimizations ## Prerequisites A Linux machine for [TorchDynamo-based ONNX Exporter](https://pytorch.org/docs/stable/onnx.html#torchdynamo-based-onnx-exporter)\ diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index dd0accc5dd9e8..ebdb5e32b7184 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -7,6 +7,7 @@ import argparse import logging import os +import warnings from pathlib import Path import onnx @@ -375,6 +376,12 @@ def parse_arguments(): def main(): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments() device = torch.device("cuda", args.device_id) if torch.cuda.is_available() else torch.device("cpu") diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 12e6df53de577..4afede881fb93 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. 
Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Stable Diffusion GPU Optimization ONNX Runtime uses the following optimizations to speed up Stable Diffusion in CUDA: diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index ed2e346972a6c..e90af970032e5 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -5,14 +5,13 @@ import argparse import csv +import logging import os import statistics import sys import time from pathlib import Path -import coloredlogs - # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch from benchmark_helper import measure_memory @@ -1332,7 +1331,7 @@ def main(): if version.parse(ort_version) < version.parse("1.16"): raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO, force=True) memory_monitor_type = "cuda" diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py index a3caba138f44a..d851e785e8d84 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -20,7 +20,8 @@ # limitations under the License. 
# -------------------------------------------------------------------------- -import coloredlogs +import logging + from cuda import cudart from demo_utils import ( add_controlnet_arguments, @@ -86,7 +87,7 @@ def run_inference(warmup=False): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) parser = arg_parser("Options for Stable Diffusion Demo") add_controlnet_arguments(parser) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py index c3e91a405b53f..739f3cb5025e7 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -20,7 +20,8 @@ # limitations under the License. # -------------------------------------------------------------------------- -import coloredlogs +import logging + from cuda import cudart from demo_utils import ( add_controlnet_arguments, @@ -252,7 +253,7 @@ def main(args): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) parser = arg_parser("Options for Stable Diffusion XL Demo") add_controlnet_arguments(parser) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index eb4d7242f72fc..25c034f7b70b5 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -20,9 +20,9 @@ import os import shutil import tempfile +import warnings from pathlib import Path -import coloredlogs import onnx from fusion_options import FusionOptions 
from onnx_model_clip import ClipOnnxModel @@ -569,6 +569,12 @@ def parse_arguments(argv: list[str] | None = None): def main(argv: list[str] | None = None): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments(argv) logger.info("Arguments: %s", str(args)) @@ -580,5 +586,5 @@ def main(argv: list[str] | None = None): if __name__ == "__main__": - coloredlogs.install(fmt="%(funcName)20s: %(message)s") + logging.basicConfig(format="%(funcName)20s: %(message)s", level=logging.INFO) main() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt index 73929214b22ea..e7852f7478db8 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements/requirements.txt @@ -4,7 +4,6 @@ transformers==4.50.0 numpy>=1.24.1 accelerate onnx==1.18.0 -coloredlogs packaging # Use newer version of protobuf might cause crash protobuf==4.25.8 diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 9056ac07cc286..44a041d789b5d 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -1,3 +1,5 @@ +> **Deprecated:** This example is deprecated. 
Use the Olive recipes instead: https://github.com/microsoft/olive-recipes/tree/main + # Whisper ## Prerequisites diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index 79b508047da55..93b509eec6982 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -7,6 +7,7 @@ import argparse import logging import os +import warnings import onnx import torch @@ -493,6 +494,12 @@ def export_onnx_models( def main(argv=None): + warnings.warn( + "This example is deprecated. Use the Olive recipe instead: " + "https://github.com/microsoft/olive-recipes/tree/main", + DeprecationWarning, + stacklevel=2, + ) args = parse_arguments(argv) setup_logger(args.verbose) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py index e10e616d35d38..31fb60f86faf1 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_decoder.py @@ -391,8 +391,9 @@ def export_onnx( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=17, + opset_version=18, do_constant_folding=True, + dynamo=False, verbose=verbose, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py index 851f641442016..48d4e12a38a43 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder.py @@ -110,8 +110,9 @@ def export_onnx( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=17, + opset_version=18, do_constant_folding=True, + dynamo=False, 
verbose=verbose, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py index cd81edc1001be..35ec59b2bca69 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_encoder_decoder_init.py @@ -293,8 +293,9 @@ def export_onnx( input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=17, + opset_version=18, do_constant_folding=True, + dynamo=False, verbose=verbose, ) diff --git a/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb b/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb index 5e81e754e1109..6603c9c387517 100644 --- a/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb +++ b/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb @@ -52,7 +52,7 @@ "else:\n", " !{sys.executable} -m pip install install torch --index-url https://download.pytorch.org/whl/cpu -q\n", "\n", - "!{sys.executable} -m pip install onnxruntime transformers==4.18 onnx psutil pandas py-cpuinfo py3nvml netron coloredlogs --no-warn-script-location -q" + "!{sys.executable} -m pip install onnxruntime transformers==4.18 onnx psutil pandas py-cpuinfo py3nvml netron --no-warn-script-location -q" ] }, { @@ -719,4 +719,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb index 7295ae1436c99..76458ca3220c9 100644 --- a/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb +++ 
b/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb @@ -59,7 +59,7 @@ "\n", "if sys.platform in ['linux', 'win32']: # Linux or Windows\n", " !{sys.executable} -m pip install torch --index-url https://download.pytorch.org/whl/cu118 -q\n", - " !{sys.executable} -m pip install onnxruntime-gpu onnx transformers psutil pandas py-cpuinfo py3nvml coloredlogs wget netron sympy protobuf==3.20.3 -q\n", + " !{sys.executable} -m pip install onnxruntime-gpu onnx transformers psutil pandas py-cpuinfo py3nvml wget netron sympy protobuf==3.20.3 -q\n", "else: # Mac\n", " print(\"CUDA is not available on MacOS\")" ] @@ -196,9 +196,9 @@ "Some weights of the model checkpoint at bert-large-uncased-whole-word-masking-finetuned-squad were not used when initializing BertForQuestionAnswering: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']\n", "- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", - "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 48/48 [00:02<00:00, 16.27it/s]\n", - "convert squad examples to features: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:03<00:00, 256.11it/s]\n", - "add example index and unique id: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:00= 1.8 numpy >= 1.19.0 -coloredlogs psutil py-cpuinfo py3nvml diff --git a/onnxruntime/python/tools/transformers/run_benchmark.sh b/onnxruntime/python/tools/transformers/run_benchmark.sh index 25997f40d348f..c16d60d0d5046 100755 --- a/onnxruntime/python/tools/transformers/run_benchmark.sh +++ b/onnxruntime/python/tools/transformers/run_benchmark.sh @@ -95,7 +95,7 @@ if [ "$run_install" = true ] ; then else pip install onnxruntime-gpu fi - pip install --upgrade onnx coloredlogs packaging psutil py3nvml numpy transformers sympy + pip install --upgrade onnx packaging psutil py3nvml numpy transformers sympy fi if [ "$use_package" = true ] ; then diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index 66f24c47f6cdb..a8c2ad1967acb 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -49,6 +49,7 @@ def torch_onnx_export( 
keep_initializers_as_inputs=keep_initializers_as_inputs, custom_opsets=custom_opsets, export_modules_as_functions=export_modules_as_functions, + dynamo=False, ) else: torch.onnx.export( diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc index bce9b59ff0ea4..76b2502da5c3c 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.cc @@ -15,117 +15,97 @@ #include "ep_factory.h" #include "ep_stream_support.h" -/// -/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. -/// -struct MulKernel { - MulKernel(const OrtApi& ort_api, const OrtLogger& logger, - const std::unordered_map& float_initializers, - std::string input0_name, std::string input1_name) - : ort_api(ort_api), - logger(logger), - float_initializers(float_initializers), - input0_name(input0_name), - input1_name(input1_name) {} - - const FloatInitializer* TryGetSavedInitializer(const std::string& name) const { - auto iter = float_initializers.find(name); - return iter != float_initializers.end() ? &iter->second : nullptr; - } +const FloatInitializer* MulKernel::TryGetSavedInitializer(const std::string& name) const { + auto iter = float_initializers.find(name); + return iter != float_initializers.end() ? 
&iter->second : nullptr; +} - void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, - /*out*/ gsl::span& data, - /*out*/ std::vector& shape) const { - Ort::ConstValue input = kernel_context.GetInput(index); - auto type_shape = input.GetTensorTypeAndShapeInfo(); +void MulKernel::GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const { + Ort::ConstValue input = kernel_context.GetInput(index); + auto type_shape = input.GetTensorTypeAndShapeInfo(); - ONNXTensorElementDataType elem_type = type_shape.GetElementType(); - if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) - throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); + ONNXTensorElementDataType elem_type = type_shape.GetElementType(); + if (elem_type != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) + throw Ort::Exception("EP Expected float32 inputs", ORT_EP_FAIL); - const float* float_data = input.GetTensorData(); - size_t num_elems = type_shape.GetElementCount(); - data = gsl::span(float_data, num_elems); - shape = type_shape.GetShape(); - } - - OrtStatus* Compute(OrtKernelContext* kernel_ctx) { - RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, - OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, - "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); - Ort::KernelContext kernel_context(kernel_ctx); - try { - gsl::span input0; - gsl::span input1; - std::vector shape0; - std::vector shape1; - - size_t num_inputs = kernel_context.GetInputCount(); - - if (num_inputs == 2) { - // Both inputs are non-constant. Get them from ORT's KernelContext. - GetInputDataAndShape(kernel_context, 0, input0, shape0); - GetInputDataAndShape(kernel_context, 1, input1, shape1); - } else if (num_inputs == 1) { - // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. - // Get the constant input from the initializers saved by the EP. 
- // Refer to "NodeFusionOptions_DropConstantInitializers()". - - if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input1, shape1); - input0 = gsl::span(const_input0->data); - shape0 = const_input0->shape; - } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { - GetInputDataAndShape(kernel_context, 0, input0, shape0); - input1 = gsl::span(const_input1->data); - shape1 = const_input1->shape; - } - } else { - // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) - // are disabled. - const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); - const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); - RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, - "Expected 2 initializer inputs to be saved by EP"); + const float* float_data = input.GetTensorData(); + size_t num_elems = type_shape.GetElementCount(); + data = gsl::span(float_data, num_elems); + shape = type_shape.GetShape(); +} +OrtStatus* MulKernel::Compute(OrtKernelContext* kernel_ctx) { + RETURN_IF_ERROR(ort_api.Logger_LogMessage(&logger, + OrtLoggingLevel::ORT_LOGGING_LEVEL_INFO, + "MulKernel::Compute", ORT_FILE, __LINE__, __FUNCTION__)); + Ort::KernelContext kernel_context(kernel_ctx); + try { + gsl::span input0; + gsl::span input1; + std::vector shape0; + std::vector shape1; + + size_t num_inputs = kernel_context.GetInputCount(); + + if (num_inputs == 2) { + // Both inputs are non-constant. Get them from ORT's KernelContext. + GetInputDataAndShape(kernel_context, 0, input0, shape0); + GetInputDataAndShape(kernel_context, 1, input1, shape1); + } else if (num_inputs == 1) { + // ORT is only providing one non-constant input because this EP chose not to request constant initializer inputs. 
+ // Get the constant input from the initializers saved by the EP. + // Refer to "NodeFusionOptions_DropConstantInitializers()". + + if (const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); const_input0 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input1, shape1); input0 = gsl::span(const_input0->data); - input1 = gsl::span(const_input1->data); shape0 = const_input0->shape; + } else if (const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); const_input1 != nullptr) { + GetInputDataAndShape(kernel_context, 0, input0, shape0); + input1 = gsl::span(const_input1->data); shape1 = const_input1->shape; } + } else { + // Both inputs are constant. Should never happen unless all ORT optimizations (specifically constant-folding) + // are disabled. + const FloatInitializer* const_input0 = TryGetSavedInitializer(input0_name); + const FloatInitializer* const_input1 = TryGetSavedInitializer(input1_name); + RETURN_IF(const_input0 == nullptr || const_input1 == nullptr, ort_api, + "Expected 2 initializer inputs to be saved by EP"); + + input0 = gsl::span(const_input0->data); + input1 = gsl::span(const_input1->data); + shape0 = const_input0->shape; + shape1 = const_input1->shape; + } - if (shape0 != shape1) { - throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); - } + if (shape0 != shape1) { + throw Ort::Exception("Expected same dimensions for both inputs", ORT_INVALID_ARGUMENT); + } - size_t num_outputs = kernel_context.GetOutputCount(); - if (num_outputs != 1) { - throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); - } + size_t num_outputs = kernel_context.GetOutputCount(); + if (num_outputs != 1) { + throw Ort::Exception("Expected 1 output for MulKernel", ORT_INVALID_ARGUMENT); + } - auto output = kernel_context.GetOutput(0, shape0); - float* output_data = output.GetTensorMutableData(); + auto output = kernel_context.GetOutput(0, shape0); + float* output_data = 
output.GetTensorMutableData(); - for (size_t i = 0; i < input0.size(); ++i) { - output_data[i] = input0[i] * input1[i]; - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status.release(); - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status.release(); + for (size_t i = 0; i < input0.size(); ++i) { + output_data[i] = input0[i] * input1[i]; } - - return nullptr; + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status.release(); + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status.release(); } - const OrtApi& ort_api; - const OrtLogger& logger; - const std::unordered_map& float_initializers; - std::string input0_name; - std::string input1_name; -}; + return nullptr; +} /// /// Example OrtNodeComputeInfo that represents the computation function for a compiled OrtGraph. @@ -230,6 +210,7 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG for (const auto& node : nodes) { auto op_type = node.GetOperatorType(); + auto domain = node.GetDomain(); if (op_type == "Mul") { // Check that Mul has inputs/output of type float @@ -262,6 +243,8 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG supported_nodes.push_back(node); // Only support a single Mul for now. break; + } else if (op_type == "Custom_Mul" && domain == "test") { + supported_nodes.push_back(node); } } @@ -269,19 +252,26 @@ OrtStatus* ORT_API_CALL ExampleEp::GetCapabilityImpl(OrtEp* this_ptr, const OrtG return nullptr; } - // Create (optional) fusion options for the supported nodes to fuse. - OrtNodeFusionOptions node_fusion_options = {}; - node_fusion_options.ort_version_supported = ORT_API_VERSION; - - // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers - // as inputs to the fused/compiled node at inference time. 
This allows ORT to release unused initializers. - // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use - // during inference. - node_fusion_options.drop_constant_initializers = true; - RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, - reinterpret_cast(supported_nodes.data()), - supported_nodes.size(), - &node_fusion_options)); + if (supported_nodes[0].GetOperatorType() == "Mul") { + // Create (optional) fusion options for the supported nodes to fuse. + OrtNodeFusionOptions node_fusion_options = {}; + node_fusion_options.ort_version_supported = ORT_API_VERSION; + + // Set "drop constant initializers" to true if the compiling EP doesn't need ORT to provide constant initializers + // as inputs to the fused/compiled node at inference time. This allows ORT to release unused initializers. + // This example EP sets this to true and saves initializers during the call to OrtEp::Compile for use + // during inference. + node_fusion_options.drop_constant_initializers = true; + RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddNodesToFuse(graph_support_info, + reinterpret_cast(supported_nodes.data()), + supported_nodes.size(), + &node_fusion_options)); + } else if (supported_nodes[0].GetOperatorType() == "Custom_Mul") { + // Calls EpGraphSupportInfo_AddSingleNode() to inform ORT that the custom node should NOT be fused or compiled, + // as CustomMul has the concrete kernel implementation. 
+ RETURN_IF_ERROR(ep->ep_api.EpGraphSupportInfo_AddSingleNode(graph_support_info, supported_nodes[0])); + } + } catch (const Ort::Exception& ex) { Ort::Status status(ex); return status.release(); diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h index 7e96a523cf285..5d4788ed76bf2 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep.h @@ -8,7 +8,34 @@ #include "../plugin_ep_utils.h" class ExampleEpFactory; -struct MulKernel; + +/// +/// Example implementation of ONNX Mul. Does not handle many things like broadcasting. +/// +struct MulKernel { + MulKernel(const OrtApi& ort_api, const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, std::string input1_name) + : ort_api(ort_api), + logger(logger), + float_initializers(float_initializers), + input0_name(input0_name), + input1_name(input1_name) {} + + const FloatInitializer* TryGetSavedInitializer(const std::string& name) const; + + void GetInputDataAndShape(Ort::KernelContext kernel_context, size_t index, + /*out*/ gsl::span& data, + /*out*/ std::vector& shape) const; + + OrtStatus* Compute(OrtKernelContext* kernel_ctx); + + const OrtApi& ort_api; + const OrtLogger& logger; + const std::unordered_map& float_initializers; + std::string input0_name; + std::string input1_name; +}; /// /// Example EP that can compile a single Mul operator. diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h new file mode 100644 index 0000000000000..c37038a727067 --- /dev/null +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_custom_op.h @@ -0,0 +1,93 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "onnxruntime_c_api.h" +#include "ep.h" + +// Plugin EPs can provide two types of custom ops: +// +// 1. A full OrtCustomOp with a concrete kernel implementation +// - This Example EP demonstrates this approach. +// - In GetCapability(), it calls EpGraphSupportInfo_AddSingleNode() to inform ORT +// that the custom node should NOT be fused or compiled. Instead, ORT should invoke +// the custom node's Compute() function at runtime. +// +// 2. A "placeholder" OrtCustomOp with an empty kernel implementation +// - A compile-based Plugin EP can supply an OrtCustomOp whose CustomKernel::Compute() +// does nothing. The purpose is to satisfy model validation during model loading by +// registering the custom op as a valid operator in the session. +// - In GetCapability(), the EP should call EpGraphSupportInfo_AddNodesToFuse() to +// notify ORT that this custom node should be fused and compiled by the EP. +// - In Compile(), the EP executes its compiled bits to perform inference for +// the fused custom node. +// +// Note: Approach #2 is suitable for plugin TRT RTX EP to support TRT plugins. 
+ +struct CustomMulKernel : MulKernel { + CustomMulKernel(const OrtApi& ort_api, + const OrtLogger& logger, + const std::unordered_map& float_initializers, + std::string input0_name, + std::string input1_name) : MulKernel(ort_api, logger, float_initializers, + input0_name, input1_name) { + } + + OrtStatusPtr ComputeV2(OrtKernelContext* kernel_ctx) { + return MulKernel::Compute(kernel_ctx); + } +}; + +struct ExampleEpCustomOp : Ort::CustomOpBase { + explicit ExampleEpCustomOp(const char* provider, ExampleEpFactory* factory) : provider_(provider), + factory_(factory) { + } + + OrtStatusPtr CreateKernelV2(const OrtApi& api, const OrtKernelInfo* info, void** op_kernel) const; + + OrtStatusPtr KernelComputeV2(void* op_kernel, OrtKernelContext* context) const; + + const char* GetName() const { return name_; }; + + void SetName(const char* name) { name_ = name; }; + + const char* GetExecutionProviderType() const { return provider_; }; + + size_t GetInputTypeCount() const { return num_inputs_; }; + + void SetInputTypeCount(size_t num) { num_inputs_ = num; }; + + ONNXTensorElementDataType GetInputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + size_t GetOutputTypeCount() const { return num_outputs_; }; + + void SetOutputTypeCount(size_t num) { num_outputs_ = num; }; + + ONNXTensorElementDataType GetOutputType(size_t /*index*/) const { return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; }; + + OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t) const { + return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC; + }; + + bool GetVariadicInputHomogeneity() const { + return false; // heterogenous + } + + bool GetVariadicOutputHomogeneity() const { + return false; // heterogeneous + } + + private: + const char* provider_ = nullptr; + const char* name_ = nullptr; + 
size_t num_inputs_ = 1; // set to 1 to match with default min_arity for variadic input + size_t num_outputs_ = 1; // set to 1 to match with default min_arity for variadic output + ExampleEpFactory* factory_ = nullptr; + std::unordered_map float_initializers_; +}; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc index 5230064138d03..198f98243af2f 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.cc @@ -65,7 +65,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( // For testing purposes, we simulate this by allocating CPU memory // that mirrors the size of the external allocation. - auto* handle = new (std::nothrow) ExampleExternalMemoryHandle(); + auto* handle = new (std::nothrow) ExampleExternalMemoryHandle(*desc); if (handle == nullptr) { return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external memory handle"); } @@ -74,10 +74,6 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportMemoryImpl( size_t effective_size = desc->size_bytes - desc->offset_bytes; handle->simulated_ptr = std::make_unique(effective_size); - handle->size_bytes = desc->size_bytes; - handle->offset_bytes = desc->offset_bytes; - handle->handle_type = desc->handle_type; - *out_handle = handle; return nullptr; } @@ -132,7 +128,7 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::CreateTensorFromMemoryI // 1. Calculate actual tensor size from shape + element_type // 2. Validate it fits within available memory region // 3. 
Use that validated size rather than subtracting offsets - size_t buffer_size = handle->size_bytes - handle->offset_bytes - tensor_desc->offset_bytes; + size_t buffer_size = handle->descriptor.size_bytes - handle->descriptor.offset_bytes - tensor_desc->offset_bytes; // Create tensor with pre-allocated memory status = impl.apis_.ort_api.CreateTensorWithDataAsOrtValue( @@ -180,12 +176,11 @@ OrtStatus* ORT_API_CALL ExampleExternalResourceImporter::ImportSemaphoreImpl( // // For testing purposes, we create a simulated semaphore using an atomic counter - auto* handle = new (std::nothrow) ExampleExternalSemaphoreHandle(); + auto* handle = new (std::nothrow) ExampleExternalSemaphoreHandle(*desc); if (handle == nullptr) { return impl.apis_.ort_api.CreateStatus(ORT_FAIL, "Failed to allocate external semaphore handle"); } - handle->type = desc->type; handle->value.store(0); *out_handle = handle; diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h index 4721367c68963..e33fbcc25f826 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_external_resource_importer.h @@ -19,14 +19,12 @@ struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { std::unique_ptr simulated_ptr; ///< Simulated mapped pointer (CPU memory for testing) - ExampleExternalMemoryHandle() + ExampleExternalMemoryHandle(const OrtExternalMemoryDescriptor& descriptor_in) : simulated_ptr(nullptr) { // Initialize base struct fields version = ORT_API_VERSION; ep_device = nullptr; - handle_type = ORT_EXTERNAL_MEMORY_HANDLE_TYPE_D3D12_RESOURCE; - size_bytes = 0; - offset_bytes = 0; + descriptor = descriptor_in; Release = ReleaseCallback; } @@ -48,12 +46,12 @@ struct ExampleExternalMemoryHandle : OrtExternalMemoryHandle { struct ExampleExternalSemaphoreHandle : OrtExternalSemaphoreHandle { 
std::atomic value; ///< Simulated fence value for testing - ExampleExternalSemaphoreHandle() + ExampleExternalSemaphoreHandle(const OrtExternalSemaphoreDescriptor& descriptor_in) : value(0) { // Initialize base struct fields version = ORT_API_VERSION; ep_device = nullptr; - type = ORT_EXTERNAL_SEMAPHORE_D3D12_FENCE; + descriptor = descriptor_in; Release = ReleaseCallback; } diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc index 79ec3fe3a3780..c56f0f74ab74a 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.cc @@ -40,6 +40,9 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL CreateExternalResourceImporterForDevice = CreateExternalResourceImporterForDeviceImpl; + GetNumCustomOpDomains = GetNumCustomOpDomainsImpl; + GetCustomOpDomains = GetCustomOpDomainsImpl; + // setup the OrtMemoryInfo instances required by the EP. // We pretend the device the EP is running on is GPU. 
default_memory_info_ = Ort::MemoryInfo{"ExampleEP GPU", @@ -71,6 +74,22 @@ ExampleEpFactory::ExampleEpFactory(const char* ep_name, ApiPtrs apis, const OrtL OrtDeviceMemoryType_HOST_ACCESSIBLE, /*alignment*/ 0, OrtAllocatorType::OrtDeviceAllocator}; + // Custom Op Domains + custom_op_domains_[0] = Ort::CustomOpDomain{"test"}; + custom_op_domains_[1] = Ort::CustomOpDomain{"test2"}; + + std::vector> created_custom_op_list; + created_custom_op_list.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list.back().get()->SetName("Custom_Mul"); + custom_op_domains_[0].Add(created_custom_op_list.back().get()); + + std::vector> created_custom_op_list_2; + created_custom_op_list_2.push_back(std::make_unique(ep_name_.c_str(), this)); + created_custom_op_list_2.back().get()->SetName("Custom_Mul2"); + custom_op_domains_[1].Add(created_custom_op_list_2.back().get()); + + created_custom_op_lists_[0] = std::move(created_custom_op_list); + created_custom_op_lists_[1] = std::move(created_custom_op_list_2); } /*static*/ @@ -313,6 +332,48 @@ OrtStatus* ORT_API_CALL ExampleEpFactory::CreateSyncStreamForDeviceImpl(OrtEpFac } /*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept { + auto* factory = static_cast(this_ptr); + *num_domains = factory->custom_op_domains_.size(); + + return nullptr; +} + +/*static*/ +OrtStatus* ORT_API_CALL ExampleEpFactory::GetCustomOpDomainsImpl( + OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept { + auto* factory = static_cast(this_ptr); + + // The `num_domains` should be 2 as ORT calls GetNumCustomOpDomainsImpl() to get the number prior to + // call this function. 
+ gsl::span domains_span(domains, num_domains); + domains_span[0] = factory->custom_op_domains_[0]; + domains_span[1] = factory->custom_op_domains_[1]; + + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::CreateKernelV2(const OrtApi& /*api*/, + const OrtKernelInfo* /*info*/, + void** op_kernel) const { + std::string node_input_0 = "X"; + std::string node_input_1 = "W"; + auto custom_kernel_op = std::make_unique(factory_->ort_api, + factory_->default_logger_, + float_initializers_, + node_input_0, + node_input_1); + *op_kernel = custom_kernel_op.release(); + return nullptr; +} + +OrtStatusPtr ExampleEpCustomOp::KernelComputeV2(void* op_kernel, OrtKernelContext* context) const { + return static_cast(op_kernel)->ComputeV2(context); +} + OrtStatus* ORT_API_CALL ExampleEpFactory::GetHardwareDeviceIncompatibilityDetailsImpl( OrtEpFactory* this_ptr, const OrtHardwareDevice* hw, diff --git a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h index 9306b0fc88ec9..244051dd5e4d0 100644 --- a/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h +++ b/onnxruntime/test/autoep/library/example_plugin_ep/ep_factory.h @@ -9,6 +9,8 @@ #include "ep_data_transfer.h" #include "ep_external_resource_importer.h" #include "../plugin_ep_utils.h" +#include "ep.h" +#include "ep_custom_op.h" /// /// Example EP factory that can create an OrtEp and return information about the supported hardware devices. 
@@ -26,6 +28,8 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { return arena_allocator_.get(); } + const OrtLogger& default_logger_; // default logger for the EP factory + private: static const char* ORT_API_CALL GetNameImpl(const OrtEpFactory* this_ptr) noexcept; @@ -78,7 +82,13 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { const OrtHardwareDevice* hw, OrtDeviceEpIncompatibilityDetails* details) noexcept; - const OrtLogger& default_logger_; // default logger for the EP factory + static OrtStatus* ORT_API_CALL GetNumCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Out_ size_t* num_domains) noexcept; + + static OrtStatus* ORT_API_CALL GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, + _Outptr_result_maybenull_ OrtCustomOpDomain** domains, + _Out_ size_t num_domains) noexcept; + const std::string ep_name_; // EP name const std::string vendor_{"Contoso"}; // EP vendor name const uint32_t vendor_id_{0xB357}; // EP vendor ID @@ -94,4 +104,7 @@ class ExampleEpFactory : public OrtEpFactory, public ApiPtrs { std::mutex mutex_; // mutex to protect arena_allocator_ and num_arena_users_ std::unique_ptr data_transfer_impl_; // data transfer implementation for this factory + + std::vector custom_op_domains_{2}; + std::vector>> created_custom_op_lists_{2}; }; diff --git a/onnxruntime/test/autoep/test_execution.cc b/onnxruntime/test/autoep/test_execution.cc index 0970654b48ca1..a3cca42d81c6e 100644 --- a/onnxruntime/test/autoep/test_execution.cc +++ b/onnxruntime/test/autoep/test_execution.cc @@ -7,6 +7,7 @@ #include #include +#include "core/graph/constants.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -48,6 +49,35 @@ void RunMulModelWithPluginEp(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2, 4, 6, 8, 10, 12)); } +void RunCustomMulModelWithPluginEp(const Ort::SessionOptions& session_options) { + Ort::Session session(*ort_env, 
ORT_TSTR("testdata/custom_mul.onnx"), session_options); + + // Create two inputs with same values + Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); + std::vector shape = {3, 2}; + std::vector input0_data{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + std::vector ort_inputs; + std::vector ort_input_names; + + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_inputs.emplace_back(Ort::Value::CreateTensor( + memory_info, input0_data.data(), input0_data.size(), shape.data(), shape.size())); + ort_input_names.push_back("X"); + ort_input_names.push_back("W"); + + // Run session and get outputs + std::array output_names{"Y"}; + std::vector ort_outputs = session.Run(Ort::RunOptions{nullptr}, ort_input_names.data(), ort_inputs.data(), + ort_inputs.size(), output_names.data(), output_names.size()); + + // Check expected output values + Ort::Value& ort_output = ort_outputs[0]; + const float* output_data = ort_output.GetTensorData(); + gsl::span output_span(output_data, 6); + EXPECT_THAT(output_span, ::testing::ElementsAre(1, 4, 9, 16, 25, 36)); +} + void RunSqueezeMulReluModel(const Ort::SessionOptions& session_options) { Ort::Session session(*ort_env, ORT_TSTR("testdata/squeeze_mul_relu.onnx"), session_options); @@ -219,10 +249,17 @@ void RunScanMulModel(const Ort::SessionOptions& session_options) { EXPECT_THAT(output_span, ::testing::ElementsAre(2.f, 4.f, 6.f, 20.f, 40.f, 60.f, 200.f, 400.f, 600.f)); } -void RunPartiallySupportedModelWithPluginEp(const Ort::SessionOptions& session_options) { +using CheckEpNodeAssignmentFunc = std::function; + +void RunAddMulAddModel(const Ort::SessionOptions& session_options, + CheckEpNodeAssignmentFunc check_ep_node_assignment_func = {}) { // This model has Add -> Mul -> Add. The example plugin EP supports Mul but not Add. 
Ort::Session session(*ort_env, ORT_TSTR("testdata/add_mul_add.onnx"), session_options); + if (check_ep_node_assignment_func) { + ASSERT_NO_FATAL_FAILURE(check_ep_node_assignment_func(session)); + } + // Create inputs Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); std::vector shape = {3, 2}; @@ -289,9 +326,46 @@ TEST(OrtEpLibrary, PluginEp_AppendV2_PartiallySupportedModelInference) { // Create session with example plugin EP Ort::SessionOptions session_options; std::unordered_map ep_options; + + // Create session that enables recording of EP-graph assignment info + session_options.AddConfigEntry(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "1"); session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); - RunPartiallySupportedModelWithPluginEp(session_options); + // Function that checks the ep graph/node assignment (Mul on plugin EP, others on CPU EP). + // Model has 3 subgraphs (in no particular order): + // - Subgraph 1: Add assigned to CPU EP. + // - Subgraph 2: Mul assigned to plugin EP. + // - Subgraph 3: Add assigned to CPU EP. + auto check_ep_node_assignment = [](const Ort::Session& session) -> void { + std::vector ep_subgraphs = session.GetEpGraphAssignmentInfo(); + ASSERT_EQ(ep_subgraphs.size(), 3); + + for (Ort::ConstEpAssignedSubgraph ep_subgraph : ep_subgraphs) { + std::string ep_name = ep_subgraph.GetEpName(); + ASSERT_TRUE(ep_name == Utils::example_ep_info.ep_name || ep_name == kCpuExecutionProvider); + + const std::vector ep_nodes = ep_subgraph.GetNodes(); + + ASSERT_GE(ep_nodes.size(), 1); // All of these subgraphs just have one node. + std::string domain = ep_nodes[0].GetDomain(); + std::string op_type = ep_nodes[0].GetOperatorType(); + std::string node_name = ep_nodes[0].GetName(); + + ASSERT_EQ(domain, kOnnxDomain); // All node ops should have the ONNX domain + + // Check that CPU EP has the Adds and that the example EP has the Mul. 
+ if (ep_name == kCpuExecutionProvider) { + ASSERT_EQ(op_type, "Add"); + ASSERT_TRUE(node_name == "add_0" || node_name == "add_1"); + } else { + ASSERT_TRUE(ep_name == Utils::example_ep_info.ep_name); + ASSERT_EQ(op_type, "Mul"); + ASSERT_EQ(node_name, "mul_0"); + } + } + }; + + RunAddMulAddModel(session_options, check_ep_node_assignment); } // Generate an EPContext model with a plugin EP. @@ -551,6 +625,36 @@ TEST(OrtEpLibrary, KernelPluginEp_ControlFlow_Scan) { } } +// Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. +// Uses AppendExecutionProvider_V2 to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Explicit_Ep) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + // Create session with example plugin EP + Ort::SessionOptions session_options; + std::unordered_map ep_options; + session_options.AppendExecutionProvider_V2(*ort_env, {plugin_ep_device}, ep_options); + + RunCustomMulModelWithPluginEp(session_options); +} + +// Creates a session with the example plugin EP and runs a model with a single Costom_Mul node. +// Uses the PREFER_CPU policy to append the example plugin EP to the session. +TEST(OrtEpLibrary, PluginEp_Custom_Op_Inference_With_Prefer_Cpu) { + RegisteredEpDeviceUniquePtr example_ep; + ASSERT_NO_FATAL_FAILURE(Utils::RegisterAndGetExampleEp(*ort_env, Utils::example_ep_info, example_ep)); + Ort::ConstEpDevice plugin_ep_device(example_ep.get()); + + { + // PREFER_CPU pick our example EP over ORT CPU EP. TODO: Actually assert this. + Ort::SessionOptions session_options; + session_options.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_CPU); + RunCustomMulModelWithPluginEp(session_options); + } +} + // Tests the GetHardwareDeviceEpIncompatibilityDetails C API with the example plugin EP. 
// The example plugin EP supports CPU devices, so this test verifies that a CPU device // is reported as compatible (reasons_bitmask == 0). @@ -646,6 +750,5 @@ TEST(OrtEpLibrary, PluginEp_GpuDevice_ReturnsInCompatible) { api->ReleaseDeviceEpIncompatibilityDetails(details); } - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/common/random_generator.h b/onnxruntime/test/common/random_generator.h index 7c95b4d10a872..1fd3f8eb76e5f 100644 --- a/onnxruntime/test/common/random_generator.h +++ b/onnxruntime/test/common/random_generator.h @@ -15,6 +15,7 @@ #include "core/common/type_utils.h" #include "core/common/float16.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" #include "test/util/include/test_random_seed.h" namespace onnxruntime { @@ -126,6 +127,22 @@ class RandomValueGenerator { return data; } + template + typename std::enable_if< + std::is_same_v || std::is_same_v, + std::vector>::type + Uniform(gsl::span dims, TInt2 min, TInt2 max) { + using UnpackedType = typename TInt2::UnpackedType; + std::vector data_int8 = Uniform(dims, min.GetElem(0), max.GetElem(0)); + std::vector data(TInt2::CalcNumInt2Quads(data_int8.size())); + for (size_t i = 0; i < data_int8.size(); i++) { + size_t r = i >> 2; + size_t c = i & 0x3; + data[r].SetElem(c, data_int8[i]); + } + return data; + } + // Gaussian distribution for float template typename std::enable_if< diff --git a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc index 3aafd413486c1..fb64d6fa9b66d 100644 --- a/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc +++ b/onnxruntime/test/contrib_ops/dynamic_quantize_matmul_test.cc @@ -10,6 +10,7 @@ #include "test/unittest_util/graph_transform_test_builder.h" #include "test/util/include/default_providers.h" #include "core/util/qmath.h" +#include "core/mlas/lib/mlasi.h" // for MLAS_CPUIDINFO #include #include @@ -263,6 +264,165 @@ 
TEST(DynamicQuantizeMatMul, UInt8_test_with_empty_input) { test.Run(); } +#if defined(USE_KLEIDIAI) + +static bool HasArmSME() { + return (MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME() || MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2()); +} + +// Helper to build a tiny 2x3×4 case we reuse. +struct KleidiDynMatMulData { + static constexpr int64_t M = 2; + static constexpr int64_t K = 4; + static constexpr int64_t N = 3; + + std::vector a = { + 1.f, 2.f, 3.f, 4.f, + -1.f, -2.f, -3.f, -4.f}; + std::vector b = { + 1, 0, -1, + 2, -1, 0, + 0, 1, 2, + -2, 0, 1}; + std::vector b_scale = {0.5f, 0.25f, 0.125f}; + std::vector b_zp = {0, 0, 0}; + + std::vector Reference(float bias0, float bias1, float bias2) const { + std::vector out(M * N, 0.f); + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.f; + for (int64_t k = 0; k < K; ++k) { + const float b_val = (static_cast(b[k * N + n]) - b_zp[n]) * b_scale[n]; + sum += a[m * K + k] * b_val; + } + const float bias = (n == 0 ? bias0 : n == 1 ? bias1 + : bias2); + out[m * N + n] = sum + bias; + } + } + return out; + } + + std::vector Reference3D(float bias0, float bias1, float bias2, int64_t leading = 1) const { + auto base = Reference(bias0, bias1, bias2); + std::vector out; + out.reserve(leading * M * N); + for (int64_t i = 0; i < leading; ++i) { + out.insert(out.end(), base.begin(), base.end()); + } + return out; + } +}; + +// 1. Bias provided as initializer -> Kleidi packs bias and skips runtime add. 
+TEST(DynamicQuantizeMatMul, KleidiBiasInitializer) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + const std::vector bias = {0.25f, -0.5f, 1.125f}; + auto expected = data.Reference(bias[0], bias[1], bias[2]); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true /*initializer*/); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true /*initializer*/); + test.AddInput("bias", {data.N}, bias, true /*initializer*/); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.03f); + test.Run(); +} + +// 2. Bias as runtime tensor -> exercise deferred bias add branch. +TEST(DynamicQuantizeMatMul, KleidiBiasRuntime) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + const std::vector bias = {1.0f, 0.0f, -0.75f}; + auto expected = data.Reference(bias[0], bias[1], bias[2]); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddInput("bias", {data.N}, bias, false /*runtime*/); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.03f); + test.Run(); +} + +// 3. Non-zero zero-points -> Kleidi pack rejected, falls back to generic path. 
+TEST(DynamicQuantizeMatMul, KleidiRejectsNonZeroZeroPoint) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + data.b_zp = {1, 0, 0}; // violates symmetry, Kleidi path disabled + auto expected = data.Reference(0.f, 0.f, 0.f); // still compare to reference + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp); + test.AddOptionalInputEdge(); // no bias + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.03f); + test.Run(); // succeeds, but exercises the “fallback” branch +} + +// 4. Invalid scales -> Kleidi pack rejected. +TEST(DynamicQuantizeMatMul, KleidiRejectsInvalidScale) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + data.b_scale[1] = 0.f; // invalid + auto expected = data.Reference(0.f, 0.f, 0.f); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", {data.K, data.N}, data.b, true); + test.AddInput("b_scale", {data.N}, data.b_scale, true); + test.AddInput("b_zero_point", {data.N}, data.b_zp, true); + test.AddOptionalInputEdge(); + test.AddOutput("Y", {data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.03f); + test.Run(); +} + +// 5. Unsupported B-shape (e.g., 3D) -> Kleidi pack rejected. 
+TEST(DynamicQuantizeMatMul, KleidiRejectsUnsupportedBShape) { + if (!HasArmSME()) GTEST_SKIP(); + KleidiDynMatMulData data; + std::vector b_3d; + b_3d.reserve(2 * data.b.size()); + b_3d.insert(b_3d.end(), data.b.begin(), data.b.end()); + b_3d.insert(b_3d.end(), data.b.begin(), data.b.end()); + std::vector b_shape = {2, data.K, data.N}; + + std::vector b_scale_3d; + b_scale_3d.reserve(2 * data.N); + b_scale_3d.insert(b_scale_3d.end(), data.b_scale.begin(), data.b_scale.end()); + b_scale_3d.insert(b_scale_3d.end(), data.b_scale.begin(), data.b_scale.end()); + + std::vector b_zp_3d; + b_zp_3d.reserve(2 * data.N); + b_zp_3d.insert(b_zp_3d.end(), data.b_zp.begin(), data.b_zp.end()); + b_zp_3d.insert(b_zp_3d.end(), data.b_zp.begin(), data.b_zp.end()); + + auto expected = data.Reference3D(0.f, 0.f, 0.f, /*leading=*/2); + + OpTester test("DynamicQuantizeMatMul", 1, kMSDomain); + test.AddInput("A", {data.M, data.K}, data.a); + test.AddInput("B", b_shape, b_3d, true); + test.AddInput("b_scale", {2, 1, data.N}, b_scale_3d, true); + test.AddInput("b_zero_point", {2, 1, data.N}, b_zp_3d, true); + + test.AddOptionalInputEdge(); + test.AddOutput("Y", {2, data.M, data.N}, expected); + test.SetOutputAbsErr("Y", 0.03f); + test.Run(); +} + +#endif // USE_KLEIDIAI + TEST(DynamicQuantizeMatMul, B_PerColumn_ND) { auto test_case = [&](const std::vector& input_shape, const std::vector& weights_shape, diff --git a/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc b/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc index 8b15ac5300a82..84ecadf7c2d74 100644 --- a/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/fused_matmul_op_test.cc @@ -213,7 +213,7 @@ void RunFusedMatMulTest(const char* op_name, int32_t opset_version = 7, bool tra test.AddOutput("Y", t.expected_dims, t.expected_vals); // Disable OpenVINO, TensorRT because of unsupported data type - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, 
kOpenVINOExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider}); } } diff --git a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc index cfdce9479843c..3d5e3e5f360b4 100644 --- a/onnxruntime/test/contrib_ops/matmul_2bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_2bits_test.cc @@ -15,6 +15,7 @@ #include "core/mlas/inc/mlas_q4.h" #include "core/mlas/inc/mlas.h" #include "core/session/inference_session.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/common/cuda_op_test_utils.h" #include "test/common/tensor_op_test_utils.h" #include "test/unittest_util/framework_test_utils.h" @@ -249,45 +250,203 @@ void TestMatMul2BitsTyped(float abs_error = 0.1f, float rel_error = 0.02f) { } // namespace -template -struct TypedTestParams { - static constexpr int batch_size = BatchSize; - static constexpr int M = MVal; - static constexpr int N = NVal; - static constexpr int K = KVal; -}; +template +void TestMatMul2BitsLutGemm(int64_t M, int64_t N, int64_t K, int64_t block_size, + bool has_zero_point, float abs_error = 0.15f, float rel_error = 0.05f) { + if (K % 32 != 0 || N % 128 != 0 || block_size % 32 != 0) { + GTEST_SKIP() << "LUT GEMM requires K multiple of 32, N multiple of 128, block_size multiple of 32"; + } -using TestTypes = ::testing::Types< - TypedTestParams<1, 1, 16, 16>, - TypedTestParams<1, 2, 16, 16>, - TypedTestParams<1, 32, 16, 16>, - TypedTestParams<1, 32, 32, 16>, - TypedTestParams<1, 32, 16, 128>, - TypedTestParams<1, 288, 16, 16>, - TypedTestParams<4, 1, 16, 16>, - TypedTestParams<4, 2, 16, 16>, - TypedTestParams<4, 32, 16, 16>, - TypedTestParams<4, 32, 32, 16>, - TypedTestParams<4, 32, 16, 128>, - TypedTestParams<4, 288, 16, 16>>; - -template -class MatMulNBits : public ::testing::Test { - public: - static constexpr int batch_size = T::batch_size; - static constexpr 
int M = T::M; - static constexpr int N = T::N; - static constexpr int K = T::K; -}; + if (!MlasIsLutGemmAvailable(static_cast(N), static_cast(K), 2, static_cast(block_size))) { + GTEST_SKIP() << "LUT GEMM not available on this platform"; + } + + RandomValueGenerator random{1234}; + std::vector input0_fp32_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); + std::vector input1_fp32_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape(static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_rows, q_cols); + + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes(static_cast(block_size), /* columnwise */ true, + static_cast(K), static_cast(N), + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + std::vector input1_vals(q_data_size_in_bytes); + std::vector scales(q_scale_size); + std::vector zp(q_zp_size_in_bytes); + + auto& ortenv = **ort_env.get(); + onnxruntime::concurrency::ThreadPool* tp = ortenv.GetEnvironment().GetIntraOpThreadPool(); + + MlasQuantizeBlockwise( + input1_vals.data(), + scales.data(), + has_zero_point ? zp.data() : nullptr, + input1_fp32_vals.data(), + static_cast(block_size), + true, + static_cast(K), + static_cast(N), + static_cast(N), + tp); + + // Dequantize for reference computation + MlasDequantizeBlockwise( + input1_fp32_vals.data(), + input1_vals.data(), + scales.data(), + has_zero_point ? 
zp.data() : nullptr, + static_cast(block_size), + true, + static_cast(K), + static_cast(N), + tp); -TYPED_TEST_SUITE(MatMulNBits, TestTypes); + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_fp32_vals[m * K + k] * input1_fp32_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", static_cast(0)); + + if constexpr (std::is_same::value) { + test.AddInput("A", {M, K}, input0_fp32_vals, false); + } + + int64_t k_blocks = (K + block_size - 1) / block_size; + test.AddInput("B", {q_cols, k_blocks, q_rows / k_blocks}, input1_vals, true); + + if constexpr (std::is_same::value) { + test.AddInput("scales", {N, static_cast(q_scale_size) / N}, scales, true); + } + + if (has_zero_point) { + test.AddInput("zero_points", {N, static_cast(q_zp_size_in_bytes) / N}, zp, true); + } else { + test.AddOptionalInputEdge(); + } + + test.AddOptionalInputEdge(); + test.AddOptionalInputEdge(); + + if constexpr (std::is_same::value) { + test.AddOutput("Y", {M, N}, expected_vals); + } + + test.SetOutputAbsErr("Y", abs_error); + test.SetOutputRelErr("Y", rel_error); + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsMlasLutGemm, "1")); + + test.Config(so) + .ConfigEp(DefaultCpuExecutionProvider()) + .RunWithConfig(); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x128) { + TestMatMul2BitsLutGemm(1, 128, 128, 32, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x128) { + TestMatMul2BitsLutGemm(1, 128, 128, 32, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256) { + TestMatMul2BitsLutGemm(1, 256, 256, 32, false); +} + +// This test was previously 
disabled due to accuracy issues related to non-deterministic +// gather operations. It is now re-enabled after replacing gather with deterministic +// load+shuffle operations to improve determinism and stability. +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256) { + TestMatMul2BitsLutGemm(1, 256, 256, 32, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_256x256_BlkLen64) { + TestMatMul2BitsLutGemm(1, 256, 256, 64, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_256x256_BlkLen64) { + TestMatMul2BitsLutGemm(1, 256, 256, 64, true); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_128x256_BlkLen128) { + TestMatMul2BitsLutGemm(1, 128, 256, 128, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_128x256_BlkLen128) { + TestMatMul2BitsLutGemm(1, 128, 256, 128, true); +} + +// Batch tests (M > 1) +TEST(MatMulNBitsLutGemm, Float32_2Bits_Symmetric_Batch32_128x128) { + TestMatMul2BitsLutGemm(32, 128, 128, 32, false); +} + +TEST(MatMulNBitsLutGemm, Float32_2Bits_Asymmetric_Batch32_256x256) { + TestMatMul2BitsLutGemm(32, 256, 256, 32, true); +} -TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy0) { - TestMatMul2BitsTyped(); +TEST(MatMul2Bits, Float32_2b_Accuracy0) { + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); } -TYPED_TEST(MatMulNBits, Float32_2Bits_Accuracy4) { - TestMatMul2BitsTyped(); +TEST(MatMul2Bits, Float32_2b_Accuracy4) { + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); 
+ TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); + TestMatMul2BitsTyped(); } } // namespace test diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index f0d5ddb422404..66f87142d3a34 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -495,6 +495,16 @@ TEST(MatMulNBits, Float16_4b_Accuracy4) { TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); TestMatMulNBitsTyped(); + + // See PR #27412 for details on the following test case, + // which is added to cover a specific failure case in the past. + // 6144, 2048 + + // Since K is larger (more change of larger error), + // and N is larger (more chance of havinga value with larger error), + // we set a higher tolerance for this case to avoid false positives + // and flaky failures. + TestMatMulNBitsTyped(0.2f, 0.03f); } TEST(MatMulNBits, LegacyShape_4b) { @@ -814,6 +824,62 @@ TEST(MatMulNBits, BFloat16_Int4_NoZeroPoint) { #endif #endif // defined(USE_CUDA) || defined(USE_DML) + +#if defined(USE_QNN) && defined(_M_ARM64) + +namespace { +// QNN-EP Test Function +// This has too many parameters of the same type that must be specified in the correct order. +// Consider using the overload with a TestOptions parameter. 
+void RunQnnEpTest(int64_t M, int64_t N, int64_t K, bool has_zeropoint = true, float abs_error = 0.05f) { + TestOptions opts{}; + opts.M = M; + opts.N = N; + opts.K = K; + opts.block_size = 32; + opts.accuracy_level = 4; + opts.has_zero_point = has_zeropoint; + opts.zp_is_4bit = true; + opts.has_g_idx = false; + opts.has_bias = false; + + if (abs_error > 0.f) { + opts.output_abs_error = abs_error; + } + + std::vector> execution_providers; + ProviderOptions provider_options; + provider_options["backend_type"] = "gpu"; + provider_options["offload_graph_io_quantization"] = "0"; + execution_providers.push_back(QnnExecutionProviderWithOptions(provider_options)); + + RunTest(opts, std::move(execution_providers)); +} +} // namespace + +// QNN GPU Only support FP16 activations and Q4_0 weights, with zero_points = 8 +// Accumulation with larger channel accumulates more error. Set higher abs_error with respect to K. +TEST(MatMulNBits, Basic_M1_N128_K512_withZp) { + constexpr float abs_error = 0.05f; + RunQnnEpTest(1, 128, 512, true, abs_error); +} + +TEST(MatMulNBits, Basic_M1_N128_K512) { + constexpr float abs_error = 0.05f; + RunQnnEpTest(1, 128, 512, false, abs_error); +} + +TEST(MatMulNBits, Basic_M10_N128_K512_withZp) { + constexpr float abs_error = 0.05f; + RunQnnEpTest(10, 128, 512, true, abs_error); +} + +TEST(MatMulNBits, Basic_M10_N128_K512) { + constexpr float abs_error = 0.05f; + RunQnnEpTest(10, 128, 512, false, abs_error); +} +#endif + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/contrib_ops/quantize_ops_test.cc b/onnxruntime/test/contrib_ops/quantize_ops_test.cc index db685967ae5ff..f3bf09c9533d1 100644 --- a/onnxruntime/test/contrib_ops/quantize_ops_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_ops_test.cc @@ -287,9 +287,46 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int8) { 127, -127, 127, -128, 127, -128}); + std::unordered_set excluded_providers; // Disable Tensorrt EP due to error: 
node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); +} + +#ifdef USE_OPENVINO +TEST(QuantizeLinearContribOpTest, OVEPQuantizeLinear_per_tensor_float_int8) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{16}; + test.AddInput("x", dims, { + 0.f, 2.f, // + 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // low case + 3.1f, -3.1f, // up case + 254.f, -256.f, // critical point + 255.f, -257.f, // critical point + 256.f, -258.f, // critical point + 1000.f, -1000.f // saturate case + }); + test.AddInput("y_scale", {}, {2.0f}); + test.AddInput("y_zero_point", {}, {1}); + test.AddOutput("y", dims, + {1, 2, + 2, 0, + 2, 0, + 3, -1, + 127, -127, + 127, -128, + 127, -128, + 127, -128}); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); } +#endif // USE_OPENVINO // Test uint16 com.microsoft.QuantizeLinear (per tensor) TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) { @@ -311,10 +348,41 @@ TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_uint16) { 32769, 32765, 65535, 0, 65535, 0}); - + std::unordered_set excluded_providers; // Disable Tensorrt EP due to error: unsupported data type - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); +} + 
+#ifdef USE_OPENVINO +TEST(QuantizeLinearContribOpTest, OVEPQuantizeLinear_per_tensor_float_uint16) { + OpTester test("QuantizeLinear", 1, onnxruntime::kMSDomain); + std::vector dims{12}; + test.AddInput("x", dims, { + 0.f, -128.f, 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // round < .5 + 3.1f, -3.1f, // round > .5 + 65536.f, -65534.f, // critical point + 70000.f, -70000.f // saturate case + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {32767}, true); + test.AddOutput("y", dims, + {32767, 32703, + 32768, 32766, + 32768, 32766, + 32769, 32765, + 65535, 0, + 65535, 0}); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); } +#endif // USE_OPENVINO // Test int16 com.microsoft.QuantizeLinear (per tensor) TEST(QuantizeLinearContribOpTest, QuantizeLinear_per_tensor_float_int16) { diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index c33e9e19e1858..055b2551328d9 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -34,7 +34,9 @@ namespace test { // forward-declaration for utility that uses public C APIs to check that an OrtGraph is equivalent // to a graph represented by the internal ORT GraphViewer class. 
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph); -static void Check_Graph_GetSubgraph(const OrtGraph& api_graph); +static void CheckGetSubGraph(const OrtGraph& api_graph); +static void CheckGetSubGraphForSpecificModel(const OrtGraph& api_graph); +static void CheckGraphWithDFSTraversal(const GraphViewer& graph_viewer); // // Tests @@ -58,6 +60,23 @@ TEST(EpGraphTest, CheckModelWithSubgraphs) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +// Use public C APIs to check that the OrtGraph from a subset of nodes from another OrtGraph is correct. +TEST(EpGraphTest, CheckModelWithGetGraphFromSubsetOfNodes) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/topk_and_multiple_graph_outputs.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGetSubGraphForSpecificModel(test_graph->GetOrtGraph()); +} + +// Use public C APIs to check that the OrtGraph for a model with subgraphs is correct. +// Subgraph inside the control flow op first, then the op itself. Simialr to EP's GetCapability() bottom-up approach. +TEST(EpGraphTest, CheckModelWithSubgraphsWithDFSTraversal) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/scan_1.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphWithDFSTraversal(test_graph->GetGraphViewer()); +} + // Use public C APIs to check that the OrtGraph for bart_tiny.onnx is correct. // This model is used in an example topological sort implementation. 
TEST(EpGraphTest, CheckModelBartTiny) { @@ -834,8 +853,8 @@ static void CheckValueInfosCApi(const GraphViewer& graph_viewer, gsl::span nodes = ort_graph.GetNodes(); @@ -867,6 +886,196 @@ static void Check_Graph_GetSubgraph(const OrtGraph& api_graph) { // model_proto->SerializeToOstream(&dump); } +static void CheckSubGraphTopoSort(const OrtGraph& api_graph) { + /* + * topk_and_multiple_graph_outputs.onnx: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + + Ort::ConstGraph ort_graph{&api_graph}; + auto nodes = ort_graph.GetNodes(); + + // Select three nodes from four nodes to create a OrtGraph + size_t num_selected_nodes = 3; + std::vector selected_nodes(num_selected_nodes); + + // The subgraph contains Less, Div and Mod ops. + selected_nodes[0] = nodes[1]; + selected_nodes[1] = nodes[2]; + selected_nodes[2] = nodes[3]; + + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); + + // When doing Kahns's Topo sort, it will try to get the producer node outside of the subgraph, + // i.e. auto producer_info = input.GetProducerNode(). + // Here is to check that Topo sort's implementation will return nullptr for the outside node and won't hit assert. 
+ std::vector nodes_with_priority; + Ort::Status status(KahnsTopologicalSort( + *sub_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ORT_ENFORCE(status.IsOK()); + + nodes_with_priority.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); +} + +// Checks the Graph_GetGraphView C API +static void CheckGetSubGraphForSpecificModel(const OrtGraph& api_graph) { + /* + * topk_and_multiple_graph_outputs.onnx: + * + * "input" ---> TopK --- + * |---> "scores" + * |--- Less ---> "Less_output_0" + * |--- Div ---> "Div_output_0" + * |--- Mod ---> "labels" + */ + + Ort::ConstGraph ort_graph{&api_graph}; + + // The node order returning from Graph_GetNumNodes() is using ORT's default topological sort. + // For this model, the node order in onnx GraphProto is not the same as the node order in "nodes", + // So here we sort OrtGraph with a custom Kahn's topological sorting algorithm. + // i.e. 
+ // onnx GraphProto: TopK, Less, Div, Mod + // Graph_GetNumNodes(): TopK, Mode, Div, Less + // priority-based sort: TopK, Less, Div, Mod + std::vector nodes; + Ort::Status status(KahnsTopologicalSort( + api_graph, + [&](const OrtNode* node) { + size_t node_id = 0; + Ort::Status status(Ort::GetApi().Node_GetId(node, &node_id)); + ORT_ENFORCE(status.IsOK()); + + nodes.push_back(Ort::ConstNode(node)); + }, + PriorityNodeCompare())); + ASSERT_TRUE(status.IsOK()) << status.GetErrorMessage(); + + // Select three nodes from four nodes to create a OrtGraph + size_t num_selected_nodes = 3; + std::vector selected_nodes(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + /* + * After calling Graph_GetGraphView(), the graph should be: + * + * "input" ---> TopK --- + * |---> "scores" + * |---> "topk_indices" (Note: This output will be consumbed by node not in this subgraph) + * |--- Less---> "Less_output_0" + * |--- Div ---> "Div_output_0" + */ + Ort::Graph sub_graph = ort_graph.GetGraphView(selected_nodes); + const GraphViewer& sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + + ASSERT_EQ(sub_graph.GetNodes().size(), 3); + + ASSERT_EQ(sub_graph_viewer.GetInputs().size(), 1); + const auto* input = sub_graph_viewer.GetInputs()[0]; + ASSERT_TRUE(input->Name() == "input"); + + ASSERT_EQ(sub_graph_viewer.GetOutputs().size(), 4); + const auto* output_1 = sub_graph_viewer.GetOutputs()[0]; + ASSERT_TRUE(output_1->Name() == "scores"); + const auto* output_2 = sub_graph_viewer.GetOutputs()[1]; + ASSERT_TRUE(output_2->Name() == "topk_indices"); + const auto* output_3 = sub_graph_viewer.GetOutputs()[2]; + ASSERT_TRUE(output_3->Name() == "Less_output_0"); + const auto* output_4 = sub_graph_viewer.GetOutputs()[3]; + ASSERT_TRUE(output_4->Name() == "Div_17_output_0"); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. 
+ // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. + std::unique_ptr model = std::make_unique(sub_graph_viewer.Name(), true, sub_graph_viewer.GetGraph().GetLogger()); + auto model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + // Test the case where the subgraph equals to srouce graph + num_selected_nodes = 4; + selected_nodes.resize(num_selected_nodes); + + for (size_t i = 0; i < num_selected_nodes; i++) { + selected_nodes[i] = nodes[i]; + } + + sub_graph = ort_graph.GetGraphView(selected_nodes); + const GraphViewer& new_sub_graph_viewer = EpGraph::ToInternal(sub_graph)->GetGraphViewer(); + + ASSERT_EQ(sub_graph.GetNodes().size(), 4); + + ASSERT_EQ(new_sub_graph_viewer.GetInputs().size(), 1); + input = new_sub_graph_viewer.GetInputs()[0]; + ASSERT_TRUE(input->Name() == "input"); + + ASSERT_EQ(new_sub_graph_viewer.GetOutputs().size(), 4); + output_1 = new_sub_graph_viewer.GetOutputs()[0]; + ASSERT_TRUE(output_1->Name() == "scores"); + output_2 = new_sub_graph_viewer.GetOutputs()[1]; + ASSERT_TRUE(output_2->Name() == "Less_output_0"); + output_3 = new_sub_graph_viewer.GetOutputs()[2]; + ASSERT_TRUE(output_3->Name() == "Div_17_output_0"); + output_4 = new_sub_graph_viewer.GetOutputs()[3]; + ASSERT_TRUE(output_4->Name() == "labels"); + + // Convert OrtGraph/GraphViewer to ModelProto and dump it to disk. + // If the GraphViewer associated with the OrtGraph somehow is incorrect, GraphViewerToProto() will throw. 
+ model = std::make_unique(new_sub_graph_viewer.Name(), true, new_sub_graph_viewer.GetGraph().GetLogger()); + model_proto = std::make_unique(model->ToProto()); + GraphViewerToProto(new_sub_graph_viewer, *model_proto->mutable_graph(), true, true, static_cast(1)); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + + // Dump the graph for debugging + // auto graph_name = ort_graph.GetName(); + // std::string name = graph_name; + // name += "_half.onnx"; + // std::fstream dump(name, std::ios::out | std::ios::trunc | std::ios::binary); + // model_proto->SerializeToOstream(&dump); + + CheckSubGraphTopoSort(api_graph); +} + +static void CheckGraphWithDFSTraversal(const GraphViewer& graph_viewer) { + std::vector node_indices = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); + for (const auto& node_idx : node_indices) { + const Node* node = graph_viewer.GetNode(node_indices[node_idx]); + + // Check node subgraphs + std::unordered_map> node_subgraphs_map = + node->GetAttributeNameToSubgraphMap(); + + if (!node_subgraphs_map.empty()) { + for (const auto& name_subgraph : node_subgraphs_map) { + auto subgraph_viewer = std::make_unique(*name_subgraph.second); + CheckGraphWithDFSTraversal(*subgraph_viewer); + } + } + } + + std::unique_ptr ep_graph = nullptr; + ORT_ENFORCE(EpGraph::Create(graph_viewer, ep_graph, true).IsOK()); + + if (graph_viewer.ParentNode()) { + const OrtNode* parent_node = nullptr; + ORT_ENFORCE(ep_graph->GetParentNode(parent_node).IsOK()); + ASSERT_NE(parent_node, nullptr); + } +} + // Checks that the contents of the original GraphViewer matches the contents of the OrtGraph. // Uses the public C APIs to traverse the OrtGraph. 
static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_graph) { @@ -1048,7 +1257,7 @@ static void CheckGraphCApi(const GraphViewer& graph_viewer, const OrtGraph& api_ } // Check creating an OrtGraph from a subset of nodes in an OrtGraph - Check_Graph_GetSubgraph(api_graph); + CheckGetSubGraph(api_graph); } } // namespace test diff --git a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc index 2e2bce97f0cb9..3bc2bd9052fa1 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_topo_sort.cc @@ -16,220 +16,8 @@ // Test implementation of Kahn's Topological sort using public C graph APIs and C++ STL. // -#define RETURN_IF_API_ERROR(fn) \ - do { \ - Ort::Status status(fn); \ - if (!status.IsOK()) { \ - return status; \ - } \ - } while (0) - namespace onnxruntime { namespace test { -template -struct VisitorPriorityQueue { - using ComparatorType = std::function; - std::list list_; - const ComparatorType comparator_ = nullptr; - VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} - - void push(T node) { - list_.insert( - std::upper_bound(list_.begin(), list_.end(), node, comparator_), - node); - } - bool empty() { return list_.empty(); } - T top() { return list_.back(); } - void pop() { list_.pop_back(); } -}; - -// Get the number of input edges that come from another node upstream. -static Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges) { - const OrtApi& ort_api = Ort::GetApi(); - - size_t num_inputs = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetNumInputs(node, &num_inputs)); - - std::vector inputs(num_inputs); - RETURN_IF_API_ERROR(ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); - - // Sum the number of inputs with a producer node. 
- num_input_edges = 0; - - for (const OrtValueInfo* ort_input : inputs) { - Ort::ConstValueInfo input{ort_input}; - if (input == nullptr) continue; // Skip missing optional input - - auto producer_info = input.GetProducerNode(); - num_input_edges += static_cast(producer_info.node != nullptr); - } - - return Ort::Status{nullptr}; -} - -// Get all output nodes that consume an output from the given node. -static Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { - const OrtApi& ort_api = Ort::GetApi(); - - size_t num_outputs = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetNumOutputs(node, &num_outputs)); - - std::vector outputs(num_outputs); - RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); - - std::vector output_nodes; - output_nodes.reserve(num_outputs); // May have more than `num_outputs` - - // Gather the OrtNode consumers of every output. - for (const OrtValueInfo* ort_output : outputs) { - Ort::ConstValueInfo output{ort_output}; - if (output == nullptr) continue; // Skip missing optional output - - auto consumers_info = output.GetConsumers(); - for (const auto& consumer : consumers_info) { - output_nodes.push_back(consumer.node); - } - } - - result = std::move(output_nodes); - return Ort::Status{nullptr}; -} - -// Kahn's topological sort. -// Adapted from onnxruntime/core/graph/graph.cc to use public C API graph types. -static Ort::Status KahnsTopologicalSort(const OrtGraph& graph, - const std::function& enter, - const std::function& comp) { - const OrtApi& ort_api = Ort::GetApi(); - - try { - // Get all nodes - size_t num_nodes = 0; - RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); - - if (num_nodes == 0) { - return Ort::Status{nullptr}; // Nothing to sort. - } - - std::vector nodes(num_nodes); - RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); - - // Get the maximum node ID. 
Not really required if we chose to represent the `in_degree` as a map instead of vector. - size_t max_node_id = 0; - for (const OrtNode* node : nodes) { - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - max_node_id = std::max(max_node_id, node_id); - } - - std::vector in_degree(max_node_id + 1, 0); - std::vector topo_order; - VisitorPriorityQueue to_visit(comp); - - topo_order.reserve(num_nodes); - - // Initialize in_degree and initial nodes to visit first. - for (const OrtNode* node : nodes) { - size_t input_edge_count = 0; - RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); - - size_t node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); - - in_degree[node_id] = input_edge_count; - if (input_edge_count == 0) { - to_visit.push(node); - } - } - - while (!to_visit.empty()) { - const OrtNode* current_node = to_visit.top(); - to_visit.pop(); - - if (!current_node) continue; - - if (enter) { - enter(current_node); - } - - std::vector output_nodes; - RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); - - for (const auto& output_node : output_nodes) { - size_t output_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); - - auto& node_in_degree = in_degree[output_node_id]; - node_in_degree--; - - if (node_in_degree == 0) { - to_visit.push(output_node); - } - } - - size_t current_node_id = 0; - RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); - topo_order.push_back(current_node_id); - } - - if (num_nodes != topo_order.size()) { - return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); - } - } catch (const Ort::Exception& ex) { - Ort::Status status(ex); - return status; - } catch (const std::exception& ex) { - Ort::Status status(ex.what(), ORT_EP_FAIL); - return status; - } - - return Ort::Status{nullptr}; -} - -// Node comparison functor copied from onnxruntime/core/graph/graph.cc -struct 
PriorityNodeCompare { - inline bool IsHighPri(const OrtNode* n) const { - // local statics so we can compare std::strings in the checks - static constexpr std::string_view shape_op("Shape"); - static constexpr std::string_view size_op("Size"); - - const char* op_type = nullptr; - Ort::Status status(Ort::GetApi().Node_GetOperatorType(n, &op_type)); - ORT_ENFORCE(status.IsOK()); - - return shape_op == op_type || size_op == op_type; - } - - // Used for std::priority_queue - // If return false, n1 will be output first - // If return true, n2 will be output first - bool operator()(const OrtNode* n1, const OrtNode* n2) const { - // nodes in global high priority list will be output first - const bool isN1HighPri = IsHighPri(n1); - const bool isN2HighPri = IsHighPri(n2); - if (isN1HighPri != isN2HighPri) { - return isN2HighPri; - } - - // nodes with lower priority value will be output first - const auto n1_priority = 0; // n1->Priority(); // Looks to always be 0 inside ORT? - const auto n2_priority = 0; // n2->Priority(); // Looks to always be 0 inside ORT? 
- if (n1_priority != n2_priority) { - return n1_priority > n2_priority; - } - - // otherwise, nodes with lower index will be output first - size_t n1_id = 0; - Ort::Status status1(Ort::GetApi().Node_GetId(n1, &n1_id)); - ORT_ENFORCE(status1.IsOK()); - - size_t n2_id = 0; - Ort::Status status2(Ort::GetApi().Node_GetId(n2, &n2_id)); - ORT_ENFORCE(status2.IsOK()); - - return n1_id > n2_id; - } -}; TEST(EpGraphTest, BasicKahnTopoSort) { auto test_graph = TestGraph::Load(ORT_TSTR("testdata/bart_tiny.onnx")); diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc index 3b3bc4c6da911..11133213d0746 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.cc @@ -90,5 +90,148 @@ Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_ } return Status::OK(); } + +// Get the number of input edges that come from another node upstream. +Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_inputs = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetNumInputs(node, &num_inputs)); + + std::vector inputs(num_inputs); + RETURN_IF_API_ERROR(ort_api.Node_GetInputs(node, inputs.data(), inputs.size())); + + // Sum the number of inputs with a producer node. + num_input_edges = 0; + + for (const OrtValueInfo* ort_input : inputs) { + Ort::ConstValueInfo input{ort_input}; + if (input == nullptr) continue; // Skip missing optional input + + auto producer_info = input.GetProducerNode(); + num_input_edges += static_cast(producer_info.node != nullptr); + } + + return Ort::Status{nullptr}; +} + +// Get all output nodes that consume an output from the given node. 
+Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result) { + const OrtApi& ort_api = Ort::GetApi(); + + size_t num_outputs = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetNumOutputs(node, &num_outputs)); + + std::vector outputs(num_outputs); + RETURN_IF_API_ERROR(ort_api.Node_GetOutputs(node, outputs.data(), outputs.size())); + + std::vector output_nodes; + output_nodes.reserve(num_outputs); // May have more than `num_outputs` + + // Gather the OrtNode consumers of every output. + for (const OrtValueInfo* ort_output : outputs) { + Ort::ConstValueInfo output{ort_output}; + if (output == nullptr) continue; // Skip missing optional output + + auto consumers_info = output.GetConsumers(); + for (const auto& consumer : consumers_info) { + output_nodes.push_back(consumer.node); + } + } + + result = std::move(output_nodes); + return Ort::Status{nullptr}; +} + +// Kahn's topological sort. +// Adapted from onnxruntime/core/graph/graph.cc to use public C API graph types. +Ort::Status KahnsTopologicalSort(const OrtGraph& graph, + const std::function& enter, + const std::function& comp) { + const OrtApi& ort_api = Ort::GetApi(); + + try { + // Get all nodes + size_t num_nodes = 0; + RETURN_IF_API_ERROR(ort_api.Graph_GetNumNodes(&graph, &num_nodes)); + + if (num_nodes == 0) { + return Ort::Status{nullptr}; // Nothing to sort. + } + + std::vector nodes(num_nodes); + RETURN_IF_API_ERROR(ort_api.Graph_GetNodes(&graph, nodes.data(), nodes.size())); + + // Get the maximum node ID. Not really required if we chose to represent the `in_degree` as a map instead of vector. + size_t max_node_id = 0; + for (const OrtNode* node : nodes) { + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + max_node_id = std::max(max_node_id, node_id); + } + + std::vector in_degree(max_node_id + 1, 0); + std::vector topo_order; + VisitorPriorityQueue to_visit(comp); + + topo_order.reserve(num_nodes); + + // Initialize in_degree and initial nodes to visit first. 
+ for (const OrtNode* node : nodes) { + size_t input_edge_count = 0; + RETURN_IF_API_ERROR(GetNodeInputEdgeCount(node, input_edge_count)); + + size_t node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(node, &node_id)); + + in_degree[node_id] = input_edge_count; + if (input_edge_count == 0) { + to_visit.push(node); + } + } + + while (!to_visit.empty()) { + const OrtNode* current_node = to_visit.top(); + to_visit.pop(); + + if (!current_node) continue; + + if (enter) { + enter(current_node); + } + + std::vector output_nodes; + RETURN_IF_API_ERROR(GetOutputNodes(current_node, output_nodes)); + + for (const auto& output_node : output_nodes) { + size_t output_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(output_node, &output_node_id)); + + auto& node_in_degree = in_degree[output_node_id]; + node_in_degree--; + + if (node_in_degree == 0) { + to_visit.push(output_node); + } + } + + size_t current_node_id = 0; + RETURN_IF_API_ERROR(ort_api.Node_GetId(current_node, ¤t_node_id)); + topo_order.push_back(current_node_id); + } + + if (num_nodes != topo_order.size()) { + return Ort::Status("Some nodes are not included in the topological sort: graph has a cycle", ORT_FAIL); + } + } catch (const Ort::Exception& ex) { + Ort::Status status(ex); + return status; + } catch (const std::exception& ex) { + Ort::Status status(ex.what(), ORT_EP_FAIL); + return status; + } + + return Ort::Status{nullptr}; +} } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/ep_graph/test_ep_graph_utils.h b/onnxruntime/test/ep_graph/test_ep_graph_utils.h index 2aebd75e0aaac..85ad113ea62e7 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph_utils.h +++ b/onnxruntime/test/ep_graph/test_ep_graph_utils.h @@ -13,6 +13,14 @@ #include "test/util/include/test_environment.h" +#define RETURN_IF_API_ERROR(fn) \ + do { \ + Ort::Status status(fn); \ + if (!status.IsOK()) { \ + return status; \ + } \ + } while (0) + struct OrtGraph; namespace onnxruntime { namespace test { @@ -72,5 
+80,80 @@ Status GetNodeArgConsumers(const GraphViewer& graph_viewer, const NodeArg& node_ // Get output index for the given NodeArg name. Returns error if the node does not produce that node arg as an output. Status GetOutputIndex(const Node& producer_node, const std::string& name, /*out*/ size_t& index); + +template +struct VisitorPriorityQueue { + using ComparatorType = std::function; + std::list list_; + const ComparatorType comparator_ = nullptr; + VisitorPriorityQueue(const ComparatorType& comp) : comparator_(comp) {} + + void push(T node) { + list_.insert( + std::upper_bound(list_.begin(), list_.end(), node, comparator_), + node); + } + bool empty() { return list_.empty(); } + T top() { return list_.back(); } + void pop() { list_.pop_back(); } +}; + +// Get the number of input edges that come from another node upstream. +Ort::Status GetNodeInputEdgeCount(const OrtNode* node, size_t& num_input_edges); + +// Get all output nodes that consume an output from the given node. +Ort::Status GetOutputNodes(const OrtNode* node, std::vector& result); + +// Kahn's topological sort. +// Adapted from onnxruntime/core/graph/graph.cc to use public C API graph types. 
+Ort::Status KahnsTopologicalSort(const OrtGraph& graph, + const std::function& enter, + const std::function& comp); + +// Node comparison functor copied from onnxruntime/core/graph/graph.cc +struct PriorityNodeCompare { + inline bool IsHighPri(const OrtNode* n) const { + // local statics so we can compare std::strings in the checks + static constexpr std::string_view shape_op("Shape"); + static constexpr std::string_view size_op("Size"); + + const char* op_type = nullptr; + Ort::Status status(Ort::GetApi().Node_GetOperatorType(n, &op_type)); + ORT_ENFORCE(status.IsOK()); + + return shape_op == op_type || size_op == op_type; + } + + // Used for std::priority_queue + // If return false, n1 will be output first + // If return true, n2 will be output first + bool operator()(const OrtNode* n1, const OrtNode* n2) const { + // nodes in global high priority list will be output first + const bool isN1HighPri = IsHighPri(n1); + const bool isN2HighPri = IsHighPri(n2); + if (isN1HighPri != isN2HighPri) { + return isN2HighPri; + } + + // nodes with lower priority value will be output first + const auto n1_priority = 0; // n1->Priority(); // Looks to always be 0 inside ORT? + const auto n2_priority = 0; // n2->Priority(); // Looks to always be 0 inside ORT? 
+ if (n1_priority != n2_priority) { + return n1_priority > n2_priority; + } + + // otherwise, nodes with lower index will be output first + size_t n1_id = 0; + Ort::Status status1(Ort::GetApi().Node_GetId(n1, &n1_id)); + ORT_ENFORCE(status1.IsOK()); + + size_t n2_id = 0; + Ort::Status status2(Ort::GetApi().Node_GetId(n2, &n2_id)); + ORT_ENFORCE(status2.IsOK()); + + return n1_id > n2_id; + } +}; + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc index 15bce163ba16a..55e0660622f87 100644 --- a/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc +++ b/onnxruntime/test/ep_weight_sharing_ctx_gen/command_args_parser.cc @@ -73,6 +73,8 @@ namespace qnnctxgen { "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '1' (another EP (typically CPU EP) handles the graph I/O quantization and dequantization). \n" "\t [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill file buffer, used while generating QNN context binary.\n" + "\t [QNN only] [extended_udma]: Enable HTP extended UDMA mode for better performance on supported hardware, options: \n" + "\t '0' (disabled), '1' (enabled). Default: '0'. \n" "\t [Example] -i \"vtcm_mb|8 htp_arch|73\" \n" "\n" "\t-h: help\n"); @@ -253,7 +255,7 @@ static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfi ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. 
select from: " + str); } } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || - key == "enable_htp_spill_fill_buffer") { + key == "enable_htp_spill_fill_buffer" || key == "extended_udma") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -266,7 +268,7 @@ static bool ParsePluginEpConfig(const std::string& json_file_path, PluginEpConfi ORT_THROW( "Wrong key type entered. Choose from options: ['backend_type', 'backend_path', 'vtcm_mb', " "'htp_performance_mode', 'htp_graph_finalization_optimization_mode', 'soc_model', 'htp_arch', " - "'enable_htp_fp16_precision', 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer']"); + "'enable_htp_fp16_precision', 'offload_graph_io_quantization', 'enable_htp_spill_fill_buffer', 'extended_udma']"); } test_config.run_config.provider_options[key] = value; diff --git a/onnxruntime/test/framework/ep_compatibility_test.cc b/onnxruntime/test/framework/ep_compatibility_test.cc index 0ae3fb746dd24..288023a130529 100644 --- a/onnxruntime/test/framework/ep_compatibility_test.cc +++ b/onnxruntime/test/framework/ep_compatibility_test.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "gtest/gtest.h" #include "gmock/gmock.h" @@ -90,6 +92,63 @@ class TestCompatibilityExecutionProvider : public IExecutionProvider { bool should_fail_validation_ = false; }; +// Test execution provider that tracks whether GetCapability is called. +// This is used to verify that early validation fails BEFORE Initialize() does expensive work. 
+class TestEarlyValidationExecutionProvider : public IExecutionProvider { + public: + static constexpr const char* kTestEarlyValidationExecutionProviderType = "TestEarlyValidationExecutionProvider"; + + TestEarlyValidationExecutionProvider() : IExecutionProvider(kTestEarlyValidationExecutionProviderType) { + } + + std::shared_ptr GetKernelRegistry() const override { + return std::make_shared(); + } + + std::vector CreatePreferredAllocators() override { + return {}; + } + + // Override GetCapability to track if it's called (happens during Initialize()) + std::vector> GetCapability( + const onnxruntime::GraphViewer& graph_viewer, + const IKernelLookup& kernel_lookup, + const GraphOptimizerRegistry& graph_optimizer_registry, + IResourceAccountant* resource_accountant = nullptr) const override { + ORT_UNUSED_PARAMETER(graph_viewer); + ORT_UNUSED_PARAMETER(kernel_lookup); + ORT_UNUSED_PARAMETER(graph_optimizer_registry); + ORT_UNUSED_PARAMETER(resource_accountant); + get_capability_called_ = true; + return {}; // Return empty - we don't actually want to handle any nodes + } + + // Configurable mock behavior for validation + void SetMockCompatibilityStatus(OrtCompiledModelCompatibility status) { + mock_compatibility_status_ = status; + } + + common::Status ValidateCompiledModelCompatibilityInfo(const std::string& compatibility_info, + OrtCompiledModelCompatibility& model_compatibility) const override { + ORT_UNUSED_PARAMETER(compatibility_info); + model_compatibility = mock_compatibility_status_; + return Status::OK(); + } + + // Query whether GetCapability was called + bool WasGetCapabilityCalled() const { + return get_capability_called_; + } + + void ResetGetCapabilityCalled() { + get_capability_called_ = false; + } + + private: + OrtCompiledModelCompatibility mock_compatibility_status_ = OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL; + mutable bool get_capability_called_ = false; +}; + // Helper class to create test models class ModelBuilderWithCompatibility { 
public: @@ -388,6 +447,72 @@ TEST_F(EpCompatibilityTest, TestEpValidationFailure) { EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Mock validation failure")); } +// Test that early validation optimization works: when a model is incompatible, +// validation should fail BEFORE Initialize() performs expensive graph partitioning. +// We verify this by checking that GetCapability() is NOT called when validation fails. +TEST_F(EpCompatibilityTest, TestEarlyValidation_FailsBeforeGetCapability) { + const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType; + const std::string compatibility_string = "test_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_UNSUPPORTED); + + // Verify GetCapability hasn't been called yet + EXPECT_FALSE(test_ep->WasGetCapabilityCalled()); + + // Create model with compatibility metadata for this EP + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + + // Keep a raw pointer to check state after move + auto* test_ep_ptr = test_ep.get(); + + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Initialization should fail due to incompatible model + auto status = InitializeSessionWithValidation(*session); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("not supported")); + + // CRITICAL: GetCapability should NOT have been called because validation failed early, + // before Initialize() could perform graph partitioning + EXPECT_FALSE(test_ep_ptr->WasGetCapabilityCalled()) + << "GetCapability was called, indicating validation did not fail early before Initialize()"; +} + +// Test that when validation succeeds, 
GetCapability IS called (normal flow) +TEST_F(EpCompatibilityTest, TestEarlyValidation_SucceedsAndProceedsToGetCapability) { + const std::string ep_type = TestEarlyValidationExecutionProvider::kTestEarlyValidationExecutionProviderType; + const std::string compatibility_string = "test_compatibility_v1.0"; + + auto test_ep = std::make_unique(); + test_ep->SetMockCompatibilityStatus(OrtCompiledModelCompatibility_EP_SUPPORTED_OPTIMAL); + + // Verify GetCapability hasn't been called yet + EXPECT_FALSE(test_ep->WasGetCapabilityCalled()); + + // Create model with compatibility metadata for this EP + std::map compatibility_info = {{ep_type, compatibility_string}}; + auto model_with_metadata = ModelBuilderWithCompatibility::CreateModelWithCompatibilityMetadata(compatibility_info); + + auto session = SessionBuilderWithCompatibility::CreateTestSession(std::move(model_with_metadata)); + + // Keep a raw pointer to check state after move + auto* test_ep_ptr = test_ep.get(); + + ASSERT_STATUS_OK(session->RegisterExecutionProvider(std::move(test_ep))); + + // Initialization should succeed + ASSERT_STATUS_OK(InitializeSessionWithValidation(*session)); + + // GetCapability SHOULD have been called because validation succeeded and + // Initialize() proceeded normally with graph partitioning + EXPECT_TRUE(test_ep_ptr->WasGetCapabilityCalled()) + << "GetCapability was not called, but it should have been after successful validation"; +} + // Test session option configuration for fail on suboptimal TEST_F(EpCompatibilityTest, TestSessionOptionConfiguration) { SessionOptions so; @@ -528,3 +653,246 @@ TEST(EpCompatibilityCxxApiTest, SingleDeviceCpuProvider) { ASSERT_TRUE(status == OrtCompiledModelCompatibility_EP_NOT_APPLICABLE); } + +// ----------------------------- +// GetCompatibilityInfoFromModel Tests +// ----------------------------- + +TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModel_InvalidArgs) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + 
ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + char* compat_info = nullptr; + + // model_path == nullptr + OrtStatus* st = api->GetCompatibilityInfoFromModel(nullptr, "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // ep_type == nullptr + st = api->GetCompatibilityInfoFromModel(ORT_TSTR("test.onnx"), nullptr, allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // ep_type == empty string + st = api->GetCompatibilityInfoFromModel(ORT_TSTR("test.onnx"), "", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // allocator == nullptr + st = api->GetCompatibilityInfoFromModel(ORT_TSTR("test.onnx"), "TestEP", nullptr, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // compatibility_info == nullptr + st = api->GetCompatibilityInfoFromModel(ORT_TSTR("test.onnx"), "TestEP", allocator, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); +} + +TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModel_FileNotFound) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + char* compat_info = nullptr; + OrtStatus* st = api->GetCompatibilityInfoFromModel(ORT_TSTR("nonexistent_model.onnx"), "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_NO_SUCHFILE); + api->ReleaseStatus(st); +} + 
+TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModelBytes_InvalidArgs) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + char* compat_info = nullptr; + const char dummy_data[] = "dummy"; + + // model_data == nullptr + OrtStatus* st = api->GetCompatibilityInfoFromModelBytes(nullptr, 10, "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // model_data_length == 0 + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, 0, "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // ep_type == nullptr + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, sizeof(dummy_data), nullptr, allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // ep_type == empty string + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, sizeof(dummy_data), "", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // allocator == nullptr + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, sizeof(dummy_data), "TestEP", nullptr, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // compatibility_info == nullptr + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, sizeof(dummy_data), "TestEP", allocator, nullptr); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); + + // model_data_length > INT_MAX (should return error, not crash) + // We can't actually allocate this much memory, but 
we can pass the size + // The API should validate the size before attempting to use the data + size_t oversized_length = static_cast(INT_MAX) + 1; + st = api->GetCompatibilityInfoFromModelBytes(dummy_data, oversized_length, "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_ARGUMENT); + api->ReleaseStatus(st); +} + +TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModelBytes_InvalidModelData) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + char* compat_info = nullptr; + const char invalid_data[] = "this is not a valid ONNX model"; + + OrtStatus* st = api->GetCompatibilityInfoFromModelBytes(invalid_data, sizeof(invalid_data), "TestEP", allocator, &compat_info); + ASSERT_NE(st, nullptr); + EXPECT_EQ(api->GetErrorCode(st), ORT_INVALID_GRAPH); + api->ReleaseStatus(st); +} + +// Test extracting compatibility info from a model with metadata +TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModelBytes_WithMetadata) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + // Create a minimal ModelProto with compatibility metadata + ONNX_NAMESPACE::ModelProto model_proto; + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + model_proto.mutable_graph()->set_name("test_graph"); + + // Add an opset import (required) + auto* opset = model_proto.add_opset_import(); + opset->set_domain(""); + opset->set_version(13); + + // Add compatibility metadata + const std::string ep_type = "TestCompatEP"; + const std::string expected_compat_info = "test_compat_v1.0_driver_123"; + auto* prop = model_proto.add_metadata_props(); + 
prop->set_key(std::string("ep_compatibility_info.") + ep_type); + prop->set_value(expected_compat_info); + + // Serialize the model + std::string model_data; + ASSERT_TRUE(model_proto.SerializeToString(&model_data)); + + // Extract compatibility info + char* compat_info = nullptr; + OrtStatus* st = api->GetCompatibilityInfoFromModelBytes( + model_data.data(), model_data.size(), ep_type.c_str(), allocator, &compat_info); + ASSERT_EQ(st, nullptr) << (st ? api->GetErrorMessage(st) : ""); + ASSERT_NE(compat_info, nullptr); + EXPECT_STREQ(compat_info, expected_compat_info.c_str()); + ASSERT_EQ(api->AllocatorFree(allocator, compat_info), nullptr); +} + +// Test when compatibility info is not found for the EP +TEST(EpCompatibilityCapiTest, GetCompatibilityInfoFromModelBytes_NotFound) { + const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); + ASSERT_NE(api, nullptr); + + OrtAllocator* allocator = nullptr; + ASSERT_EQ(api->GetAllocatorWithDefaultOptions(&allocator), nullptr); + ASSERT_NE(allocator, nullptr); + + // Create a minimal ModelProto without compatibility metadata for our EP + ONNX_NAMESPACE::ModelProto model_proto; + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + model_proto.mutable_graph()->set_name("test_graph"); + + auto* opset = model_proto.add_opset_import(); + opset->set_domain(""); + opset->set_version(13); + + // Add metadata for a different EP + auto* prop = model_proto.add_metadata_props(); + prop->set_key("ep_compatibility_info.DifferentEP"); + prop->set_value("some_value"); + + std::string model_data; + ASSERT_TRUE(model_proto.SerializeToString(&model_data)); + + // Try to get compatibility info for an EP that doesn't have it + char* compat_info = nullptr; + OrtStatus* st = api->GetCompatibilityInfoFromModelBytes( + model_data.data(), model_data.size(), "NonExistentEP", allocator, &compat_info); + ASSERT_EQ(st, nullptr); // Not an error - just not found + EXPECT_EQ(compat_info, nullptr); // Should be nullptr when not found 
+} + +// C++ API test +TEST(EpCompatibilityCxxApiTest, GetCompatibilityInfoFromModelBytes) { + // Create a minimal ModelProto with compatibility metadata + ONNX_NAMESPACE::ModelProto model_proto; + model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + model_proto.mutable_graph()->set_name("test_graph"); + + auto* opset = model_proto.add_opset_import(); + opset->set_domain(""); + opset->set_version(13); + + const std::string ep_type = "CxxTestEP"; + const std::string expected_compat_info = "cxx_compat_v2.0"; + auto* prop = model_proto.add_metadata_props(); + prop->set_key(std::string("ep_compatibility_info.") + ep_type); + prop->set_value(expected_compat_info); + + std::string model_data; + ASSERT_TRUE(model_proto.SerializeToString(&model_data)); + + // Get allocator + Ort::AllocatorWithDefaultOptions allocator; + + // Test C++ API - found case + Ort::AllocatedStringPtr result = Ort::GetCompatibilityInfoFromModelBytesAllocated( + model_data.data(), model_data.size(), ep_type.c_str(), allocator); + ASSERT_NE(result.get(), nullptr); + EXPECT_STREQ(result.get(), expected_compat_info.c_str()); + + // Test when not found - should return nullptr + Ort::AllocatedStringPtr not_found = Ort::GetCompatibilityInfoFromModelBytesAllocated( + model_data.data(), model_data.size(), "NonExistentEP", allocator); + EXPECT_EQ(not_found.get(), nullptr); +} diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 07f2cc8581ed5..1c4e7800b7d2e 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -3089,5 +3089,70 @@ TEST(InferenceSessionTests, GraphResolveHandlesNodeWithSubgraphBeingRemoved) { ASSERT_STATUS_OK(session.Load(model_uri)); } +#ifdef ORT_ENABLE_STREAM +namespace { + +struct TestNotification : public synchronize::Notification { + explicit TestNotification(Stream& s) : Notification(s) {} + void Activate() override {} +}; + 
+struct TestOverrideStream : Stream { + TestOverrideStream(StreamHandle h, const OrtDevice& d) : Stream(h, d) {} + std::unique_ptr CreateNotification(size_t /*num_consumers*/) override { + return std::make_unique(*this); + } +}; +} // namespace + +TEST(DeviceStreamCollection, TestOverride) { + // We need an allocator map for the constructor, but it's not used in this test scenario. + AllocatorMap allocators; + DeviceStreamCollection collection(2, allocators, false); + + OrtDevice cpu_device(OrtDevice::CPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); + OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NVIDIA, 0); + + auto cpu_stream = std::make_unique(nullptr, cpu_device); + auto* cpu_stream_ptr = cpu_stream.get(); + collection.AddDeviceStream(0, std::move(cpu_stream)); + + auto gpu_stream = std::make_unique(nullptr, gpu_device); + auto* gpu_stream_ptr = gpu_stream.get(); + collection.AddDeviceStream(1, std::move(gpu_stream)); + + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 1. Override CPU stream + TestOverrideStream cpu_override_stream(nullptr, cpu_device); + ASSERT_STATUS_OK(collection.SetStreamOverride(&cpu_override_stream)); + + // Verify override took effect for correct device match + ASSERT_EQ(collection.GetStream(0), &cpu_override_stream); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 2. Reset Override + collection.ResetStreamOverride(); + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), gpu_stream_ptr); + + // 3. Override GPU stream + TestOverrideStream gpu_override_stream(nullptr, gpu_device); + ASSERT_STATUS_OK(collection.SetStreamOverride(&gpu_override_stream)); + + ASSERT_EQ(collection.GetStream(0), cpu_stream_ptr); + ASSERT_EQ(collection.GetStream(1), &gpu_override_stream); + + collection.ResetStreamOverride(); + + // 4. 
Override with non-matching device + OrtDevice other_device(OrtDevice::FPGA, OrtDevice::MemType::DEFAULT, OrtDevice::VendorIds::NONE, 0); + TestOverrideStream other_stream(nullptr, other_device); + ASSERT_FALSE(collection.SetStreamOverride(&other_stream).IsOK()); +} + +#endif // ORT_ENABLE_STREAM + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/int2_test.cc b/onnxruntime/test/framework/int2_test.cc new file mode 100644 index 0000000000000..cc7c4c1b54f97 --- /dev/null +++ b/onnxruntime/test/framework/int2_test.cc @@ -0,0 +1,322 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include "core/framework/int2.h" +#include "core/framework/data_types.h" +#include "core/framework/tensorprotoutils.h" +#include "core/platform/env.h" +#include "test/test_environment.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// ============================================== +// Int2x4 Tests (signed 2-bit integer, 4 per byte) +// ============================================== + +TEST(Int2_Tests, Int2x4_DefaultConstructor) { + Int2x4 int2; + EXPECT_EQ(static_cast(int2.ToBits()), 0); +} + +TEST(Int2_Tests, Int2x4_BitsConstructor) { + // Pack 4 signed 2-bit values: val0=1, val1=-1 (0b11), val2=-2 (0b10), val3=0 + // Binary: 0b00'10'11'01 = 0x2D + Int2x4 int2(std::byte{0x2D}); + EXPECT_EQ(int2.GetElem(0), 1); + EXPECT_EQ(int2.GetElem(1), -1); // 0b11 sign-extended is -1 + EXPECT_EQ(int2.GetElem(2), -2); // 0b10 sign-extended is -2 + EXPECT_EQ(int2.GetElem(3), 0); +} + +TEST(Int2_Tests, Int2x4_FourValueConstructor) { + Int2x4 int2(1, -1, -2, 0); + EXPECT_EQ(int2.GetElem(0), 1); + EXPECT_EQ(int2.GetElem(1), -1); + EXPECT_EQ(int2.GetElem(2), -2); + EXPECT_EQ(int2.GetElem(3), 0); +} + +TEST(Int2_Tests, Int2x4_GetSetElem) { + Int2x4 int2; + + // Set and get each element + int2.SetElem(0, 1); + int2.SetElem(1, -1); + int2.SetElem(2, 
-2); + int2.SetElem(3, 0); + + EXPECT_EQ(int2.GetElem(0), 1); + EXPECT_EQ(int2.GetElem(1), -1); + EXPECT_EQ(int2.GetElem(2), -2); + EXPECT_EQ(int2.GetElem(3), 0); +} + +TEST(Int2_Tests, Int2x4_ValueRange) { + // Verify min/max values + EXPECT_EQ(Int2x4::min_val, -2); + EXPECT_EQ(Int2x4::max_val, 1); + + // Test all valid signed 2-bit values: -2, -1, 0, 1 + Int2x4 int2(-2, -1, 0, 1); + EXPECT_EQ(int2.GetElem(0), -2); + EXPECT_EQ(int2.GetElem(1), -1); + EXPECT_EQ(int2.GetElem(2), 0); + EXPECT_EQ(int2.GetElem(3), 1); +} + +TEST(Int2_Tests, Int2x4_CalcNumInt2Quads) { + // 0 elements -> 0 bytes + EXPECT_EQ(Int2x4::CalcNumInt2Quads(0), 0u); + // 1 element -> 1 byte + EXPECT_EQ(Int2x4::CalcNumInt2Quads(1), 1u); + // 4 elements -> 1 byte + EXPECT_EQ(Int2x4::CalcNumInt2Quads(4), 1u); + // 5 elements -> 2 bytes + EXPECT_EQ(Int2x4::CalcNumInt2Quads(5), 2u); + // 8 elements -> 2 bytes + EXPECT_EQ(Int2x4::CalcNumInt2Quads(8), 2u); +} + +TEST(Int2_Tests, Int2x4_PackUnpack) { + std::vector src_values = {1, -1, -2, 0, 1, -1, -2, 0}; + std::vector packed(Int2x4::CalcNumInt2Quads(src_values.size())); + + // Pack + bool pack_result = Int2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values)); + EXPECT_TRUE(pack_result); + + // Unpack + std::vector unpacked(src_values.size()); + bool unpack_result = Int2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed)); + EXPECT_TRUE(unpack_result); + + // Verify + for (size_t i = 0; i < src_values.size(); i++) { + EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i; + } +} + +TEST(Int2_Tests, Int2x4_PackUnpackOddElements) { + // Test with non-multiple-of-4 element count + std::vector src_values = {1, -1, -2}; + std::vector packed(Int2x4::CalcNumInt2Quads(src_values.size())); + + // Pack + bool pack_result = Int2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values)); + EXPECT_TRUE(pack_result); + + // Unpack + std::vector unpacked(src_values.size()); + bool unpack_result = Int2x4::Unpack(gsl::make_span(unpacked), 
gsl::make_span(packed)); + EXPECT_TRUE(unpack_result); + + // Verify + for (size_t i = 0; i < src_values.size(); i++) { + EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i; + } +} + +// ============================================== +// UInt2x4 Tests (unsigned 2-bit integer, 4 per byte) +// ============================================== + +TEST(Int2_Tests, UInt2x4_DefaultConstructor) { + UInt2x4 uint2; + EXPECT_EQ(static_cast(uint2.ToBits()), 0); +} + +TEST(Int2_Tests, UInt2x4_BitsConstructor) { + // Pack 4 unsigned 2-bit values: val0=0, val1=1, val2=2, val3=3 + // Binary: 0b11'10'01'00 = 0xE4 + UInt2x4 uint2(std::byte{0xE4}); + EXPECT_EQ(uint2.GetElem(0), 0); + EXPECT_EQ(uint2.GetElem(1), 1); + EXPECT_EQ(uint2.GetElem(2), 2); + EXPECT_EQ(uint2.GetElem(3), 3); +} + +TEST(Int2_Tests, UInt2x4_FourValueConstructor) { + UInt2x4 uint2(0, 1, 2, 3); + EXPECT_EQ(uint2.GetElem(0), 0); + EXPECT_EQ(uint2.GetElem(1), 1); + EXPECT_EQ(uint2.GetElem(2), 2); + EXPECT_EQ(uint2.GetElem(3), 3); +} + +TEST(Int2_Tests, UInt2x4_GetSetElem) { + UInt2x4 uint2; + + // Set and get each element + uint2.SetElem(0, 0); + uint2.SetElem(1, 1); + uint2.SetElem(2, 2); + uint2.SetElem(3, 3); + + EXPECT_EQ(uint2.GetElem(0), 0); + EXPECT_EQ(uint2.GetElem(1), 1); + EXPECT_EQ(uint2.GetElem(2), 2); + EXPECT_EQ(uint2.GetElem(3), 3); +} + +TEST(Int2_Tests, UInt2x4_ValueRange) { + // Verify min/max values + EXPECT_EQ(UInt2x4::min_val, 0); + EXPECT_EQ(UInt2x4::max_val, 3); + + // Test all valid unsigned 2-bit values: 0, 1, 2, 3 + UInt2x4 uint2(0, 1, 2, 3); + EXPECT_EQ(uint2.GetElem(0), 0); + EXPECT_EQ(uint2.GetElem(1), 1); + EXPECT_EQ(uint2.GetElem(2), 2); + EXPECT_EQ(uint2.GetElem(3), 3); +} + +TEST(Int2_Tests, UInt2x4_CalcNumInt2Quads) { + // Same as Int2x4 + EXPECT_EQ(UInt2x4::CalcNumInt2Quads(0), 0u); + EXPECT_EQ(UInt2x4::CalcNumInt2Quads(1), 1u); + EXPECT_EQ(UInt2x4::CalcNumInt2Quads(4), 1u); + EXPECT_EQ(UInt2x4::CalcNumInt2Quads(5), 2u); +} + +TEST(Int2_Tests, UInt2x4_PackUnpack) { + 
std::vector src_values = {0, 1, 2, 3, 3, 2, 1, 0}; + std::vector packed(UInt2x4::CalcNumInt2Quads(src_values.size())); + + // Pack + bool pack_result = UInt2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values)); + EXPECT_TRUE(pack_result); + + // Unpack + std::vector unpacked(src_values.size()); + bool unpack_result = UInt2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed)); + EXPECT_TRUE(unpack_result); + + // Verify + for (size_t i = 0; i < src_values.size(); i++) { + EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i; + } +} + +TEST(Int2_Tests, UInt2x4_PackUnpackOddElements) { + // Test with non-multiple-of-4 element count + std::vector src_values = {3, 2, 1}; + std::vector packed(UInt2x4::CalcNumInt2Quads(src_values.size())); + + // Pack + bool pack_result = UInt2x4::Pack(gsl::make_span(packed), gsl::make_span(src_values)); + EXPECT_TRUE(pack_result); + + // Unpack + std::vector unpacked(src_values.size()); + bool unpack_result = UInt2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(packed)); + EXPECT_TRUE(unpack_result); + + // Verify + for (size_t i = 0; i < src_values.size(); i++) { + EXPECT_EQ(unpacked[i], src_values[i]) << "Mismatch at index " << i; + } +} + +// ============================================== +// Additional edge case tests +// ============================================== + +TEST(Int2_Tests, Int2x4_AllSameValue) { + // All values are -2 (minimum signed value) + Int2x4 int2_min(-2, -2, -2, -2); + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(int2_min.GetElem(i), -2); + } + + // All values are 1 (maximum signed value) + Int2x4 int2_max(1, 1, 1, 1); + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(int2_max.GetElem(i), 1); + } +} + +TEST(Int2_Tests, UInt2x4_AllSameValue) { + // All values are 0 (minimum unsigned value) + UInt2x4 uint2_min(0, 0, 0, 0); + for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(uint2_min.GetElem(i), 0); + } + + // All values are 3 (maximum unsigned value) + UInt2x4 uint2_max(3, 3, 3, 3); + 
for (size_t i = 0; i < 4; i++) { + EXPECT_EQ(uint2_max.GetElem(i), 3); + } +} + +TEST(Int2_Tests, Int2x4_BitManipulation) { + // Test that ToBits returns correct packed representation + Int2x4 int2(0, 1, -1, -2); // 0b00, 0b01, 0b11, 0b10 + // Expected: 0b10'11'01'00 = 0xB4 + EXPECT_EQ(static_cast(int2.ToBits()), 0xB4); +} + +// ============================================== +// TypeProto / TypeFromProto smoke checks +// ============================================== + +TEST(Int2_Tests, TensorTypeFromONNXEnum_Int2UInt2) { + auto* int2_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_INT2); + auto* uint2_type = DataTypeImpl::TensorTypeFromONNXEnum(ONNX_NAMESPACE::TensorProto_DataType_UINT2); + + ASSERT_NE(int2_type, nullptr); + ASSERT_NE(uint2_type, nullptr); + EXPECT_EQ(int2_type->GetElementType(), DataTypeImpl::GetType()); + EXPECT_EQ(uint2_type->GetElementType(), DataTypeImpl::GetType()); +} + +TEST(Int2_Tests, TypeFromProto_TensorProto_Int2) { + ONNX_NAMESPACE::TypeProto tp; + tp.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TensorProto_DataType_INT2); + auto mltype = DataTypeImpl::TypeFromProto(tp); + ASSERT_NE(mltype, nullptr); + const auto* tensor_type = mltype->AsTensorType(); + ASSERT_NE(tensor_type, nullptr); + EXPECT_EQ(tensor_type->GetElementType(), DataTypeImpl::GetType()); +} + +TEST(Int2_Tests, TensorProtoRoundTrip_Int2) { + // Build a tiny TensorProto with raw_data containing 2 bytes (8 int2 elements packed) + ONNX_NAMESPACE::TensorProto proto; + proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT2); + proto.add_dims(8); + // pack values [1, -1, -2, 0, 1, -1, -2, 0] + std::array values = {1, -1, -2, 0, 1, -1, -2, 0}; + std::vector packed(Int2x4::CalcNumInt2Quads(values.size())); + ASSERT_TRUE(Int2x4::Pack(gsl::make_span(packed), gsl::make_span(values))); + proto.set_raw_data(packed.data(), packed.size() * sizeof(Int2x4)); + + Tensor result; + // Use CreateTensorFromTensorProto which pre-allocates the 
tensor with proper shape + ORT_THROW_IF_ERROR(utils::CreateTensorFromTensorProto(onnxruntime::Env::Default(), std::filesystem::path{}, proto, result)); + ASSERT_TRUE(result.IsDataType()); + const auto* data = result.Data(); + std::vector unpacked(values.size()); + ASSERT_TRUE(Int2x4::Unpack(gsl::make_span(unpacked), gsl::make_span(data, packed.size()))); + for (size_t i = 0; i < values.size(); ++i) { + EXPECT_EQ(unpacked[i], values[i]) << "Mismatch at index " << i; + } +} + +TEST(Int2_Tests, UInt2x4_BitManipulation) { + // Test that ToBits returns correct packed representation + UInt2x4 uint2(3, 2, 1, 0); // 0b11, 0b10, 0b01, 0b00 + // Expected: 0b00'01'10'11 = 0x1B + EXPECT_EQ(static_cast(uint2.ToBits()), 0x1B); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/sparse_kernels_test.cc b/onnxruntime/test/framework/sparse_kernels_test.cc index 89e928af10b8b..d2842d203dccd 100644 --- a/onnxruntime/test/framework/sparse_kernels_test.cc +++ b/onnxruntime/test/framework/sparse_kernels_test.cc @@ -756,8 +756,7 @@ static NodeProto CreateConstantNodeAllZeros(bool indices_1D, std::vector& exp constant_node.set_op_type("Constant"); constant_node.add_output("dense_tensor_output"); - std::vector indices; - std::vector shape{2, 3, 2}; + const std::array shape{2, 3, 2}; AttributeProto& attrib = *constant_node.mutable_attribute()->Add(); attrib.set_name("sparse_value_all_zeros"); @@ -772,7 +771,7 @@ static NodeProto CreateConstantNodeAllZeros(bool indices_1D, std::vector& exp } else { // indices are shape {NNZ, rank} so convert flattened values of 2, 5, 6 and 10 to rank 3 values indices_tp.add_dims(0); - indices_tp.add_dims(0); + indices_tp.add_dims(3); } indices_tp.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); @@ -1870,7 +1869,387 @@ TEST(SparseTensorConversionTests, BlockSparse) { indices_span.begin(), indices_span.end())); } } -#endif // !defined(DISABLE_SPARSE_TENSORS) +template +void 
TestSparseToDenseConversion(gsl::span dense_shape, + const std::vector& values, + gsl::span indices, + gsl::span indices_shape, + bool raw_data_indices, + const std::vector& expected_dense_data) { + ONNX_NAMESPACE::SparseTensorProto sparse_proto; + for (auto dim : dense_shape) { + sparse_proto.add_dims(dim); + } + + // Create values tensor + auto* values_tensor = sparse_proto.mutable_values(); + values_tensor->set_name("values"); + // Simplification: assuming float/int32 for now based on T + if constexpr (std::is_same_v) { + values_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + for (float v : values) values_tensor->add_float_data(v); + } else if constexpr (std::is_same_v) { + values_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + for (int32_t v : values) values_tensor->add_int32_data(v); + } + // Set values shape [NNZ] + values_tensor->add_dims(values.size()); + + // Create indices tensor + auto* indices_tensor = sparse_proto.mutable_indices(); + indices_tensor->set_name("indices"); + indices_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + if constexpr (std::is_same_v) { + indices_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT8); + if (raw_data_indices) { + indices_tensor->set_raw_data(indices.data(), indices.size() * sizeof(I)); + } else { + for (auto idx : indices) { + indices_tensor->add_int32_data(static_cast(idx)); // indices are stored in int32_data for types < int32 + } + } + } else if constexpr (std::is_same_v) { + indices_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT16); + if (raw_data_indices) { + indices_tensor->set_raw_data(indices.data(), indices.size() * sizeof(I)); + } else { + for (auto idx : indices) { + indices_tensor->add_int32_data(static_cast(idx)); + } + } + } else if constexpr (std::is_same_v) { + indices_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT32); + if (raw_data_indices) { + indices_tensor->set_raw_data(indices.data(), 
indices.size() * sizeof(I)); + } else { + for (auto idx : indices) { + indices_tensor->add_int32_data(idx); + } + } + } else if constexpr (std::is_same_v) { + indices_tensor->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + if (raw_data_indices) { + indices_tensor->set_raw_data(indices.data(), indices.size() * sizeof(I)); + } else { + for (auto idx : indices) { + indices_tensor->add_int64_data(idx); + } + } + } + for (auto dim : indices_shape) { + indices_tensor->add_dims(dim); + } + + ONNX_NAMESPACE::TensorProto dense_proto; + std::filesystem::path model_path; // empty path + ASSERT_STATUS_OK(utils::SparseTensorProtoToDenseTensorProto(sparse_proto, model_path, dense_proto)); + + // Verify dense proto + ASSERT_EQ(dense_proto.dims_size(), dense_shape.size()); + for (size_t i = 0; i < (size_t)dense_shape.size(); ++i) { + ASSERT_EQ(dense_proto.dims(static_cast(i)), dense_shape[i]); + } + + std::vector unpacked_data(expected_dense_data.size()); + ASSERT_STATUS_OK(utils::UnpackTensor(dense_proto, model_path, unpacked_data.data(), unpacked_data.size())); + + EXPECT_EQ(unpacked_data, expected_dense_data); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_Rank1Indices64) { + // Dense Shape: [2, 2] -> 4 elements + // Indices: [0, 3] (linear) + // Values: [1.0, 2.0] + // Expected: [1.0, 0.0, 0.0, 2.0] + std::vector dense_shape = {2, 2}; + std::vector values = {1.0f, 2.0f}; + std::vector indices = {0, 3}; + std::vector indices_shape = {2}; // [NNZ] + std::vector expected = {1.0f, 0.0f, 0.0f, 2.0f}; + + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, false, expected); + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, true, expected); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_Rank1Indices32) { + // Dense Shape: [2, 2] -> 4 elements + // Indices: [0, 3] (linear) + // Values: [1.0, 2.0] + // Expected: [1.0, 0.0, 0.0, 2.0] + std::vector dense_shape = {2, 2}; + std::vector values = 
{1.0f, 2.0f}; + std::vector indices = {0, 3}; + std::vector indices_shape = {2}; // [NNZ] + std::vector expected = {1.0f, 0.0f, 0.0f, 2.0f}; + + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, false, expected); + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, true, expected); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_Rank1Indices16) { + // Dense Shape: [2, 2] -> 4 elements + // Indices: [0, 3] (linear) + // Values: [1.0, 2.0] + // Expected: [1.0, 0.0, 0.0, 2.0] + std::vector dense_shape = {2, 2}; + std::vector values = {1.0f, 2.0f}; + std::vector indices = {0, 3}; + std::vector indices_shape = {2}; // [NNZ] + std::vector expected = {1.0f, 0.0f, 0.0f, 2.0f}; + + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, false, expected); + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, true, expected); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_Rank1Indices8) { + // Dense Shape: [2, 2] -> 4 elements + // Indices: [0, 3] (linear) + // Values: [1.0, 2.0] + // Expected: [1.0, 0.0, 0.0, 2.0] + std::vector dense_shape = {2, 2}; + std::vector values = {1.0f, 2.0f}; + std::vector indices = {0, 3}; + std::vector indices_shape = {2}; // [NNZ] + std::vector expected = {1.0f, 0.0f, 0.0f, 2.0f}; + + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, false, expected); + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, true, expected); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_Rank2Indices_COO) { + // Dense Shape: [3, 3] -> 9 elements + // Indices: [[0, 0], [1, 1], [2, 2]] -> flattened: 0,0, 1,1, 2,2 + // Shape: [3, 2] (NNZ=3, Rank=2) + // Values: [10, 20, 30] + // Expected: [10, 0, 0, 0, 20, 0, 0, 0, 30] + std::vector dense_shape = {3, 3}; + std::vector values = {10, 20, 30}; + std::vector indices = {0, 0, 1, 1, 2, 2}; + std::vector indices_shape = {3, 2}; + std::vector 
expected = { + 10, 0, 0, + 0, 20, 0, + 0, 0, 30}; + + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, false, expected); + TestSparseToDenseConversion(dense_shape, values, indices, indices_shape, true, expected); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_OutOfBounds_Rank1) { + // Dense size 4 + // Index 5 -> Out of bounds + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor"); + sparse.add_dims(4); + + auto* val = sparse.mutable_values(); + val->add_dims(1); + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(5); // Out of bounds + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor index is out of bounds")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_OutOfBounds_Rank2) { + // Dense Shape [2, 2] -> linear 0..3 + // Index [2, 0] -> linear 4 -> Out of bounds + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor"); + sparse.add_dims(2); + sparse.add_dims(2); + + auto* val = sparse.mutable_values(); + val->add_dims(1); + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); // NNZ=1 + ind->add_dims(2); // Rank=2 + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(2); + ind->add_int64_data(0); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor 
index is out of bounds")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_OutOfBounds_Rank2_Dim1) { + // Dense Shape [2, 2] + // Index [0, 2] -> 2 is out of bounds for the 2nd dimension (size 2) + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor_dim1_oob"); + sparse.add_dims(2); + sparse.add_dims(2); + + auto* val = sparse.mutable_values(); + val->add_dims(1); + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); // NNZ=1 + ind->add_dims(2); // Rank=2 + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + ind->add_int64_data(2); // Out of bounds for dim 1 + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor_dim1_oob index is out of bounds")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InvalidValuesRank) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor"); + sparse.add_dims(10); + + auto* val = sparse.mutable_values(); + // Set values rank to 2 (invalid) + val->add_dims(1); + val->add_dims(1); + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor values should be rank 1")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_NegativeValuesShape) { + ONNX_NAMESPACE::SparseTensorProto sparse; + 
sparse.mutable_values()->set_name("test_tensor"); + sparse.add_dims(10); // Dense shape + + auto* val = sparse.mutable_values(); + val->add_dims(-5); // Negative dimension in values + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor tensor dims expected to be non-negative")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_NegativeDenseShape) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor"); + sparse.add_dims(10); + sparse.add_dims(-2); // Negative dimension in dense shape + + auto* val = sparse.mutable_values(); + val->add_dims(1); + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(1); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor dense dims expected to be non-negative")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InvalidValuesRank_Zero) { + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor_val_rank_0"); + sparse.add_dims(10); + + auto* val = sparse.mutable_values(); + // No dims added -> Rank 0 (Scalar) + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + + auto* ind = sparse.mutable_indices(); + 
ind->add_dims(1); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor_val_rank_0 values should be rank 1")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ValuesSizeMismatch) { + // Case where the actual data in 'values' doesn't match the dimension specified in 'values' + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor_val_size_mismatch"); + sparse.add_dims(10); + + auto* val = sparse.mutable_values(); + val->add_dims(2); // Claiming 2 elements + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + val->add_float_data(1.0f); + // Only added 1 element, this should fail during UnpackInitializerData or subsequent checks depending on where it's caught + // Note: UnpackTensor checks if size matches. + + auto* ind = sparse.mutable_indices(); + ind->add_dims(2); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + ind->add_int64_data(1); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + // The error comes from UnpackTensor usually + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("data size")); +} + +TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ValuesSizeMismatch_RawData) { + // Case where raw data size doesn't match the shape size * element size + ONNX_NAMESPACE::SparseTensorProto sparse; + sparse.mutable_values()->set_name("test_tensor_val_size_mismatch_raw"); + sparse.add_dims(10); + + auto* val = sparse.mutable_values(); + val->add_dims(2); // Claiming 2 elements + val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + + // 1 float is 4 bytes. 
We provide 4 bytes, but claim 2 elements (8 bytes needed). + float raw_val = 1.0f; + val->set_raw_data(&raw_val, sizeof(float)); + + auto* ind = sparse.mutable_indices(); + ind->add_dims(2); + ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64); + ind->add_int64_data(0); + ind->add_int64_data(1); + + ONNX_NAMESPACE::TensorProto dense; + auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense); + EXPECT_FALSE(status.IsOK()); + EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("values data size does not match expected")); +} + +#endif // !defined(DISABLE_SPARSE_TENSORS) } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/framework/tensorutils_test.cc b/onnxruntime/test/framework/tensorutils_test.cc index 0d7b583faf27b..c9b61a7a39632 100644 --- a/onnxruntime/test/framework/tensorutils_test.cc +++ b/onnxruntime/test/framework/tensorutils_test.cc @@ -530,9 +530,6 @@ TEST_F(PathValidationTest, ValidateExternalDataPath) { // Valid relative path. ASSERT_STATUS_OK(utils::ValidateExternalDataPath(base_dir_, "data.bin")); - // Empty base directory. - ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "data.bin")); - // Empty location. // Only validate it is not an absolute path. ASSERT_TRUE(utils::ValidateExternalDataPath(base_dir_, "").IsOK()); @@ -555,6 +552,29 @@ TEST_F(PathValidationTest, ValidateExternalDataPath) { // Base directory does not exist. ASSERT_STATUS_OK(utils::ValidateExternalDataPath("non_existent_dir", "data.bin")); + + // + // Tests for an empty base directory. + // The base directory would be empty when 1) the session loads a model from bytes and 2) the application does not + // set an external file folder path via the session config option + // kOrtSessionOptionsModelExternalInitializersFileFolderPath. + // + + // A simple filename is ok (would not escape current working directory). 
+ ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "data.bin")); + ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "./data.bin")); + + // A ".." that is not a path component (part of the filename) is ok + ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "data..bin")); + + // A path that would escape the current working directory is invalid. + ASSERT_FALSE(utils::ValidateExternalDataPath("", "../data.bin").IsOK()); + + // A path that uses ".." but would not escape the current working directory should be fine. + ASSERT_STATUS_OK(utils::ValidateExternalDataPath("", "a/../data.bin")); + + // A path with multiple internal ".." that would escape current working direction should fail. + ASSERT_FALSE(utils::ValidateExternalDataPath("", "a/../../data.bin").IsOK()); } TEST_F(PathValidationTest, ValidateExternalDataPathWithSymlinkInside) { diff --git a/onnxruntime/test/mlas/bench/bench_sconv.cpp b/onnxruntime/test/mlas/bench/bench_sconv.cpp index dc37980002978..849911e322214 100644 --- a/onnxruntime/test/mlas/bench/bench_sconv.cpp +++ b/onnxruntime/test/mlas/bench/bench_sconv.cpp @@ -326,6 +326,10 @@ static void TeamsModel(benchmark::internal::Benchmark* b) { b->Args({2, 1, 1, 12, 12, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // fused Conv_376 => 48x80 b->Args({2, 1, 1, 12, 72, 48, 80, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1}); // Conv_59 => 24x40 + + b->Args({2, 1, 256, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // External customer model + b->Args({2, 1, 512, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // External customer model + b->Args({2, 1, 960, 1, 1, 378, 378, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1}); // External customer model } BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp index e5a536eb9e4f0..d8b76407edf08 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp +++ 
b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.cpp @@ -12,6 +12,14 @@ static size_t Conv2dNchwcRegistLongExecute() { if (GetMlasThreadPool() != nullptr) { count += MlasLongExecuteTests>::RegisterLongExecute(); } +#if defined(__aarch64__) && defined(__linux__) + if (MlasBf16AccelerationSupported()) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } +#endif } return count; @@ -25,6 +33,14 @@ static size_t Conv2dNchwcRegistShortExecute() { if (GetMlasThreadPool() != nullptr) { count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); } +#if defined(__aarch64__) && defined(__linux__) + if (MlasBf16AccelerationSupported()) { + count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); + if (GetMlasThreadPool() != nullptr) { + count += Conv2dShortExecuteTest>::RegisterShortExecuteTests(); + } + } +#endif } return count; diff --git a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h index c125720668381..c1162c8d150c4 100644 --- a/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h +++ b/onnxruntime/test/mlas/unittest/test_conv2d_nchwc.h @@ -8,6 +8,8 @@ template class MlasNchwcConv2DTest : public MlasConv2DTest { protected: + bool UseBf16_ = false; + void MlasConv2D( size_t BatchCount, size_t GroupCount, @@ -131,7 +133,8 @@ class MlasNchwcConv2DTest : public MlasConv2DTest { NchwcOutput, &Activation, true, - MlasConv2DTest::threadpool_); + MlasConv2DTest::threadpool_, + UseBf16_); // // Reorder the output buffer. @@ -224,3 +227,51 @@ class MlasNchwcConv2DTest : public MlasConv2DTest { } } }; + +#if defined(__aarch64__) && defined(__linux__) +template +class MlasNchwcConv2DBf16Test : public MlasNchwcConv2DTest { + public: + MlasNchwcConv2DBf16Test() { this->UseBf16_ = true; } + + static const char* GetTestSuiteName() { + static const std::string suite_name(Threaded ? 
"Conv2dNchwcBf16_Threaded" : "Conv2dNchwcBf16_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong() override { + // BF16 pointwise tests (1x1 kernel, no padding, InputChannels >= BlockSize) + for (unsigned ic : {32u, 64u, 128u}) { + for (unsigned fc : {32u, 64u, 128u}) { + for (unsigned sz : {28u, 14u, 7u}) { + TestBf16(1, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); + TestBf16(1, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 2, 2); + TestBf16(4, 1, ic, sz, sz, fc, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1); + } + } + } + } + + private: + void TestBf16(size_t B, size_t G, size_t IC, size_t IH, size_t IW, size_t FC, + size_t KH, size_t KW, size_t p0, size_t p1, size_t p2, size_t p3, + size_t DH, size_t DW, size_t SH, size_t SW) { + size_t OH = (IH + p0 + p2 - DH * (KH - 1) - 1) / SH + 1; + size_t OW = (IW + p1 + p3 - DW * (KW - 1) - 1) / SW + 1; + size_t OutputElements = B * G * FC * OH * OW; + + const float* Input = MlasConv2DTest::BufferInput.GetBuffer(B * G * IC * IH * IW); + const float* Filter = MlasConv2DTest::BufferFilter.GetBuffer(G * FC * IC * KH * KW); + const float* Bias = MlasConv2DTest::BufferBias.GetBuffer(G * FC); + float* Output = MlasConv2DTest::BufferOutput.GetBuffer(OutputElements); + float* OutputRef = MlasConv2DTest::BufferOutputReference.GetBuffer(OutputElements); + + this->MlasConv2D(B, G, IC, IH, IW, FC, KH, KW, p0, p1, p2, p3, DH, DW, SH, SW, OH, OW, Input, Filter, Bias, Output); + MlasConv2DTest::ReferenceConv2D(B, G, IC, IH, IW, FC, KH, KW, p0, p1, DH, DW, SH, SW, OH, OW, Input, Filter, Bias, OutputRef); + + for (size_t i = 0; i < OutputElements; i++) { + ASSERT_TRUE(CloseEnough(Output[i], OutputRef[i])) << " @" << i << " got " << Output[i] << " expected " << OutputRef[i]; + } + } +}; +#endif diff --git a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp index 83f5b7f106d3e..4cedef6588125 100644 --- a/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp +++ 
b/onnxruntime/test/mlas/unittest/test_dynamic_qgemm.cpp @@ -6,8 +6,12 @@ #include "mlas.h" #include "test_util.h" +#include "core/mlas/inc/mlas.h" -class MlasDynamicQgemmTest { +#include +#include + +class MlasDynamicQgemmTestBase { private: MatrixGuardBuffer buffer_a; MatrixGuardBuffer buffer_bf; @@ -15,10 +19,19 @@ class MlasDynamicQgemmTest { MatrixGuardBuffer buffer_c; MatrixGuardBuffer buffer_c_ref; - public: - void Test(size_t M, size_t N, size_t K, size_t BatchSize) { - // Setup buffers for holding various data + protected: + void Run(size_t M, size_t N, size_t K, size_t BatchSize, + MLAS_THREADPOOL* threadpool, bool require_threadpool, const char* run_tag) { + if (require_threadpool && threadpool == nullptr) + GTEST_SKIP() << "Dynamic QGEMM threading path requested but no MLAS thread pool is available."; + + // The test harness assumes K>0 for generating/quantizing B (computes per-column min/max across K). + // When K==0, the buffers are size 0 and the min/max logic dereferences invalid memory. + if (K == 0) { + GTEST_SKIP() << "Skipping DynamicQGEMM test with K==0: test harness assumes K>0 for quantization setup."; + } + // Setup buffers for holding various data float* A = buffer_a.GetBuffer(M * K * BatchSize); // Buffer for holding floating point version of weight matrix float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); @@ -36,6 +49,9 @@ class MlasDynamicQgemmTest { // Quantize Bf → Bq and compute per-column scale and bias per batch std::vector> b_scale_batches(BatchSize, std::vector(N)); std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + std::vector> a_quant_batches(BatchSize, std::vector(M * K)); + std::vector> a_scale_batches(BatchSize, std::vector(M)); + std::vector> a_zero_point_batches(BatchSize, std::vector(M)); for (size_t b = 0; b < BatchSize; ++b) { for (size_t n = 0; n < N; ++n) { @@ -58,9 +74,47 @@ class MlasDynamicQgemmTest { } } + // Quantize A rows to match the dynamic quantization performed by the kernel. 
+ for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + float min_val = std::numeric_limits::max(); + float max_val = std::numeric_limits::lowest(); + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float rmin = std::min(0.0f, min_val); + float rmax = std::max(0.0f, max_val); + float inv_scale = (rmax == rmin) ? 1.0f : 255.0f / (rmax - rmin); + float scale = inv_scale ? 1.0f / inv_scale : 0.0f; + float descaled_min = rmin * inv_scale; + float descaled_max = rmax * inv_scale; + float zero_point_from_min_error = -128.0f + descaled_min; + float zero_point_from_max_error = 127.0f + descaled_max; + float zero_point = (zero_point_from_min_error + zero_point_from_max_error > 0.0f) + ? (-128.0f - descaled_min) + : (127.0f - descaled_max); + zero_point = std::clamp(zero_point, -128.0f, 127.0f); + int32_t zp = static_cast(std::nearbyint(zero_point)); + + a_scale_batches[b][m] = scale; + a_zero_point_batches[b][m] = zp; + + for (size_t k = 0; k < K; ++k) { + float v = A[b * M * K + m * K + k]; + int32_t q = static_cast(std::round(v * inv_scale)) + zp; + q = std::clamp(q, -128, 127); + a_quant_batches[b][m * K + k] = static_cast(q); + } + } + } + // Prepare kernel parameters MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; - std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); + + const size_t packed_b_stride = MlasDynamicQgemmPackBSize(N, K); + std::vector packed_b_storage(BatchSize * packed_b_stride); std::vector params(BatchSize); for (size_t b = 0; b < BatchSize; ++b) { @@ -68,26 +122,33 @@ class MlasDynamicQgemmTest { params[b].lda = K; params[b].C = C + b * M * N; params[b].ldc = N; - // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB - void* packed_b = packed_b_storage.data() + b * MlasDynamicQgemmPackBSize(N, K); - MlasDynamicQgemmPackB(N, K, - Bq + b * K * N, - b_scale_batches[b].data(), - 
b_bias_batches[b].data(), - packed_b); + + // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB. + // When packed_b_stride is 0 (e.g., degenerate shapes like K==0), avoid taking data() + // from a zero-sized vector as that may be null/invalid on some platforms. + void* packed_b = packed_b_stride == 0 ? nullptr : (packed_b_storage.data() + b * packed_b_stride); + + if (packed_b_stride != 0) { + MlasDynamicQgemmPackB(N, K, + Bq + b * K * N, + b_scale_batches[b].data(), + b_bias_batches[b].data(), + packed_b); + } + params[b].PackedB = packed_b; } - // call MlasDynamicQGemmBatch Function - MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); - // Compute reference result for (size_t b = 0; b < BatchSize; ++b) { for (size_t m = 0; m < M; ++m) { for (size_t n = 0; n < N; ++n) { float sum = 0.0f; + const float a_scale = a_scale_batches[b][m]; + const int32_t a_zero_point = a_zero_point_batches[b][m]; for (size_t k = 0; k < K; ++k) { - float a = A[b * M * K + m * K + k]; + int32_t a_q = static_cast(a_quant_batches[b][m * K + k]); + float a = static_cast(a_q - a_zero_point) * a_scale; float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; sum += a * bval; } @@ -96,45 +157,70 @@ class MlasDynamicQgemmTest { } } + std::fill(C, C + M * N * BatchSize, 0.0f); + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, threadpool); + // Validate results - for (size_t i = 0; i < M * N * BatchSize; ++i) { - float abs_c_ref = std::abs(CRef[i]); - float dynamic_rel_tol = (K <= 4) ? 
0.05f : 0.03f; - float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); - float abs_tol = 3.0f; - float allowed = std::max(rel_tol, abs_tol); - float diff = std::abs(C[i] - CRef[i]); - ASSERT_LE(diff, allowed); - } + auto validate = [&](const char* tag) { + SCOPED_TRACE(tag); + + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_tol = 0.001f; + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, abs_tol); + } + }; + + validate(run_tag); + } +}; + +class MlasDynamicQgemmSingleThreadTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MlasIsDynamicQGemmAvailable()) + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME or SME2 but it was not detected. Skipping test."; + Run(M, N, K, BatchSize, /*threadpool*/ nullptr, /*require_threadpool*/ false, "SingleThread"); } + static const char* GetTestSuiteName() { return "DynamicQgemmSingleThread"; } +}; - static const char* GetTestSuiteName() { - return "DynamicQgemm"; +class MlasDynamicQgemmThreadPoolTest : public MlasDynamicQgemmTestBase { + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MlasIsDynamicQGemmAvailable()) + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME or SME2 but it was not detected. 
Skipping test."; + MLAS_THREADPOOL* tp = GetMlasThreadPool(); + if (!tp) GTEST_SKIP() << "Mlas thread pool not available"; + Run(M, N, K, BatchSize, tp, /*require_threadpool*/ true, "ThreadPool"); } + static const char* GetTestSuiteName() { return "DynamicQgemmThreaded"; } }; -class DynamicQgemmExecuteTest : public MlasTestFixture { +template +class DynamicQgemmExecuteTest : public MlasTestFixture { public: DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} void TestBody() override { - this->mlas_tester->Test(M_, N_, K_, BatchSize_); + MlasTestFixture::mlas_tester->Test(M_, N_, K_, BatchSize_); } static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { std::stringstream ss; ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; - std::string test_name = ss.str(); testing::RegisterTest( - MlasDynamicQgemmTest::GetTestSuiteName(), + TMlasTester::GetTestSuiteName(), test_name.c_str(), nullptr, test_name.c_str(), __FILE__, __LINE__, - [=]() -> MlasTestFixture* { + [=]() -> MlasTestFixture* { return new DynamicQgemmExecuteTest(M, N, K, BatchSize); }); @@ -151,6 +237,15 @@ class DynamicQgemmExecuteTest : public MlasTestFixture { for (size_t K : sizes) for (size_t B : batch_size) count += RegisterSingleTest(M, N, K, B); + + // Zero-dimension probes: these should exercise the early-return behavior in implementations + // that treat M==0 or N==0 as a no-op. + count += RegisterSingleTest(0, 16, 16, 1); // M==0 + count += RegisterSingleTest(16, 0, 16, 1); // N==0 + + // K==0 probe: included to observe behavior when the reduction dimension is zero. 
+ count += RegisterSingleTest(16, 16, 0, 1); // K==0 + return count; } @@ -158,11 +253,10 @@ class DynamicQgemmExecuteTest : public MlasTestFixture { size_t M_, N_, K_, BatchSize_; }; -static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { - // Only register tests if MlasDynamicQGemmBatch() has an implementation available. - if (!MlasIsDynamicQGemmAvailable()) { - return size_t{0}; - } +static UNUSED_VARIABLE bool added_single = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); - return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +static UNUSED_VARIABLE bool added_threaded = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); }); diff --git a/onnxruntime/test/mlas/unittest/test_eltwise.cpp b/onnxruntime/test/mlas/unittest/test_eltwise.cpp index c4d4b9c0eb317..136d3a9a756b4 100644 --- a/onnxruntime/test/mlas/unittest/test_eltwise.cpp +++ b/onnxruntime/test/mlas/unittest/test_eltwise.cpp @@ -97,10 +97,62 @@ class MlasEltwiseAddTest : public MlasTestBase { } }; +class MlasEltwiseMulTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputLeft; + MatrixGuardBuffer BufferInputRight; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + float* InputLeft = BufferInputLeft.GetBuffer(N); + float* InputRight = BufferInputRight.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = distribution(generator); + InputRight[n] = ScalarValue.value_or(distribution(generator)); + } + + for (size_t n = 0; n < N; n++) { + 
OutputReference[n] = InputLeft[n] * InputRight[n]; + } + + MlasEltwiseMul(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Eltwise_Mul"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); + Test(n, -10.f, 10.f, -5000.f); + } + } +}; + static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); } return count; }); diff --git a/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp new file mode 100644 index 0000000000000..3ff4fee69eac9 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_rope_neon_fp16.cpp @@ -0,0 +1,104 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope_neon_fp16.cpp + +Abstract: + + Tests for MLAS fp16 RoPE on NEON. 
+ +--*/ + +#include +#include + +#include "core/mlas/inc/mlas.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/rotary_embedding.h" +#include "core/mlas/lib/rotary_embedding_kernel_neon.h" + +class MlasNeonFp16RoPETest : public MlasTestBase { + private: + const float Pi = 2 * std::acos(0.0f); + + void Test(size_t rotary_emb_dim, bool interleaved) { + // Per kernel logic (both fallback and optimized), the sin/cos tables + // are always half the rotary embedding dimension. + const size_t table_len = rotary_emb_dim / 2; + + std::vector input(rotary_emb_dim); + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim); + std::vector output_impl(rotary_emb_dim); + + // Initialize input data + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = MLAS_FP16(static_cast(i + 1)); + } + + // Initialize sin/cos tables + for (size_t i = 0; i < table_len; ++i) { + float theta = static_cast(i) / 1000.0f * Pi; + sin_data[i] = MLAS_FP16(std::sin(theta)); + cos_data[i] = MLAS_FP16(std::cos(theta)); + } + + // Call fallback implementation + MlasRotaryEmbedOneRow_FallBack(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_ref.data()); + + // Call dispatched implementation (which should pick up the NEON kernel) + MlasRotaryEmbedOneRow(input.data(), sin_data.data(), cos_data.data(), rotary_emb_dim, interleaved, output_impl.data()); + + // Compare results + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_TRUE(CloseEnough(output_impl[i].ToFloat(), output_ref[i].ToFloat())) + << "Expected bits: " << output_ref[i].val << " (" << output_ref[i].ToFloat() << ")" + << " Actual bits: " << output_impl[i].val << " (" << output_impl[i].ToFloat() << ")" + << " @[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } + + public: + static const 
char* GetTestSuiteName() { + return "NeonFp16RoPE"; + } + + void ExecuteShort(void) override { + // Test dimensions that cover main loops and various remainders + Test(6, false); + Test(6, true); + Test(16, false); + Test(16, true); + Test(24, false); + Test(24, true); + Test(32, false); + Test(32, true); + Test(42, false); + Test(42, true); + Test(64, false); + Test(64, true); + Test(70, false); + Test(70, true); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp new file mode 100644 index 0000000000000..12ec5ec78f599 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sqlutgemm.cpp @@ -0,0 +1,240 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sqlutgemm.cpp + +Abstract: + + Tests for MLAS LUT-based n-bit GEMM (TMAC/LUT path) for 2-bit.. + +--*/ + +#include "test_util.h" +#include "mlas_qnbit.h" +#include "mlas_q4.h" + +// Generic template to future-proof for different bit widths; instantiate with 2 for now. 
+template +class MlasSQLutGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + + MatrixGuardBuffer BufferQuantBData; + MatrixGuardBuffer BufferQuantBScale; + MatrixGuardBuffer BufferQuantBZeroPoint; + MatrixGuardBuffer BufferDequantizedB; + MatrixGuardBuffer BufferPackedB; // Single buffer for packed weights and scales + + void CallReferenceGemm(size_t M, + size_t N, + size_t K, + const float* A, + const uint8_t* QuantBData, + const float* QuantBScale, + const uint8_t* QuantBZeroPoint, + float* C) { + float* DequantizedBData = BufferDequantizedB.GetBuffer(K * N); + MlasDequantizeBlockwise( + DequantizedBData, QuantBData, QuantBScale, QuantBZeroPoint, BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), GetMlasThreadPool()); + + // Note: DequantizedBData is in column major layout. + + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const float* a = A + m * K; + const float* b = DequantizedBData + n * K; + float* c = C + (m * N) + n; + float sum = 0.0f; + for (size_t k = 0; k < K; k++) { + sum += (*a) * (*b); + b += 1; + a += 1; + } + *c = sum; + } + } + } + + public: + void Test(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) { + MLAS_THREADPOOL* tp = WithThreadpool ? 
GetMlasThreadPool() : nullptr; + + // Clear config cache to ensure fresh config for each test case + MlasClearLutGemmKernelConfig(); + + const float* A = BufferA.GetBuffer(K * M); + const float* B = BufferB.GetBuffer(N * K); + float* C = BufferC.GetBuffer(N * M, true); + float* CReference = BufferCReference.GetBuffer(N * M, true); + + // quantize B + uint8_t* QuantBData = nullptr; + float* QuantBScale = nullptr; + uint8_t* QuantBZeroPoint = nullptr; + + { + size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; + MlasBlockwiseQuantizedBufferSizes(BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + + QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); + QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); + if (!Symmetric) { + QuantBZeroPoint = BufferQuantBZeroPoint.GetBuffer(QuantBZeroPointSizeInBytes); + } + + MlasQuantizeBlockwise(QuantBData, QuantBScale, QuantBZeroPoint, + B, BlkLen, + /* columnwise */ true, + static_cast(K), static_cast(N), + static_cast(N), + GetMlasThreadPool()); + } + + MlasInitLutGemmKernelConfig(N, K, BlkBitWidth, BlkLen, !Symmetric); + + // Use unified packing - single buffer for weights and scales/zp + size_t PackedBufSize = MlasLutGemmPackedSize(N, K, BlkBitWidth, BlkLen, !Symmetric); + std::byte* PackedBuf = BufferPackedB.GetBuffer(PackedBufSize); + + MlasLutGemmPack( + N, + K, + BlkBitWidth, + BlkLen, + !Symmetric, + reinterpret_cast(QuantBData), + QuantBScale, + QuantBZeroPoint, + PackedBuf, + tp); + + MlasLutGemm( + A, + BlkLen, + PackedBuf, + C, + static_cast(K), + static_cast(M), + static_cast(N), + !Symmetric, + tp); + + // Reference computation + CallReferenceGemm(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, CReference); + + size_t f = 0; + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + ASSERT_TRUE(CloseEnough(C[f], CReference[f])) + << "Expected: " << CReference[f] << 
" Actual: " << C[f] << "@[" << m << "x" << n << "], " + << "M=" << M << ", N=" << N << ", K=" << K; + } + } + } + + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SQLutGemm") + + "BlkBitWidth" + std::to_string(BlkBitWidth) + + "BlkLen" + std::to_string(BlkLen); + return suite_name.c_str(); + } +}; + +// Fixture to register parameterized tests quickly +template +class SQLutGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SQLutGemmShortExecuteTest(size_t M, size_t N, size_t K, + bool WithThreadpool, bool Symmetric) + : M_(M), + N_(N), + K_(K), + WithThreadpool_(WithThreadpool), + Symmetric_(Symmetric) { + } + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test( + M_, N_, K_, WithThreadpool_, Symmetric_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, bool WithThreadpool, bool Symmetric) { + if (!MlasIsLutGemmAvailable(N, K, BlkBitWidth, BlkLen)) { + return 0; + } + + if (M < BlkLen || N < BlkLen) { + return 0; + } + + std::stringstream ss; + ss << (WithThreadpool ? "Threaded" : "SingleThread") + << "/isSymmetric" << Symmetric + << "/M" << M << "xN" << N << "xK" << K; + + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSQLutGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. 
+ [=]() -> MlasTestFixture>* { + return new SQLutGemmShortExecuteTest( + M, N, K, WithThreadpool, Symmetric); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t count = 0; + for (bool with_threadpool : {true}) { + for (bool symmetric : {true, false}) { // Test both symmetric and asymmetric + for (size_t b = 256; b < 320; b += 32) { + count += RegisterSingleTest(b, b, b, with_threadpool, symmetric); + } + + count += RegisterSingleTest(64, 128, 128, with_threadpool, symmetric); + count += RegisterSingleTest(128, 256, 256, with_threadpool, symmetric); + } + } + return count; + } + + private: + size_t M_, N_, K_; + bool WithThreadpool_, Symmetric_; +}; + +static size_t SQLutGemmRegisterAllShortExecuteTests() { + size_t count = 0; + count += SQLutGemmShortExecuteTest<2, 32>::RegisterShortExecuteTests(); + count += SQLutGemmShortExecuteTest<2, 64>::RegisterShortExecuteTests(); + count += SQLutGemmShortExecuteTest<2, 128>::RegisterShortExecuteTests(); + count += SQLutGemmShortExecuteTest<2, 256>::RegisterShortExecuteTests(); + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return SQLutGemmRegisterAllShortExecuteTests(); + } + return 0; + }); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 8446f88639436..f4e15c49d92f0 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -90,6 +90,8 @@ void usage() { "\t Otherwise, it will be fp32 precision. Works for float32 model for HTP backend. Defaults to '1' (with FP16 precision.). \n" "\t [QNN only] [offload_graph_io_quantization]: Offload graph input quantization and graph output dequantization to another EP (typically CPU EP). \n" "\t Defaults to '0' (QNN EP handles the graph I/O quantization and dequantization). 
\n" + "\t [QNN only] [extended_udma]: Enable HTP extended UDMA mode for better performance on supported hardware, options: \n" + "\t '0' (disabled), '1' (enabled). Default: '0'. \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_type|cpu\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -612,7 +614,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_arch. select from: " + str); } - } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization") { + } else if (key == "enable_htp_fp16_precision" || key == "offload_graph_io_quantization" || key == "extended_udma") { std::unordered_set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -626,7 +628,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { "Wrong key type entered. 
Choose from options: ['backend_type', 'backend_path', " "'profiling_level', 'profiling_file_path', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', " "'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'op_packages', 'qnn_context_priority', " - "'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization']"); + "'soc_model', 'htp_arch', 'device_id', 'enable_htp_fp16_precision', 'offload_graph_io_quantization', 'extended_udma']"); } qnn_options[key] = value; diff --git a/onnxruntime/test/onnx/tensorprotoutils.cc b/onnxruntime/test/onnx/tensorprotoutils.cc index bf2e19aa37371..d6b10beef11e0 100644 --- a/onnxruntime/test/onnx/tensorprotoutils.cc +++ b/onnxruntime/test/onnx/tensorprotoutils.cc @@ -14,6 +14,7 @@ #include "core/common/status.h" #include "core/framework/allocator.h" #include "core/framework/data_types.h" +#include "core/framework/int2.h" #include "core/common/endian.h" #include "core/framework/endian_utils.h" #include "core/graph/onnx_protobuf.h" @@ -122,6 +123,48 @@ void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, std::memcpy(dst_span.data(), src_span.data(), num_packed_pairs); } +template <> +void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ Int2x4* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_quads = (expected_num_elements + 3) / 4; + + if (num_packed_quads != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed int2 quads", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_quads); + gsl::span dst_span = gsl::make_span(p_data, num_packed_quads); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_quads); +} + +template <> +void UnpackTensorWithRawData(const void* 
raw_data, size_t raw_data_len, size_t expected_num_elements, + /*out*/ UInt2x4* p_data) { + static_assert(std::is_trivially_copyable::value, "T must be trivially copyable"); + + if (p_data == nullptr) { + ORT_CXX_API_THROW("nullptr == p_data", OrtErrorCode::ORT_FAIL); + } + + size_t num_packed_quads = (expected_num_elements + 3) / 4; + + if (num_packed_quads != raw_data_len) { + ORT_CXX_API_THROW("Unexpected number of packed uint2 quads", OrtErrorCode::ORT_FAIL); + } + + gsl::span src_span = gsl::make_span(reinterpret_cast(raw_data), num_packed_quads); + gsl::span dst_span = gsl::make_span(p_data, num_packed_quads); + + std::memcpy(dst_span.data(), src_span.data(), num_packed_quads); +} + #if !defined(DISABLE_FLOAT4_TYPES) template <> void UnpackTensorWithRawData(const void* raw_data, size_t raw_data_len, size_t expected_num_elements, @@ -373,6 +416,41 @@ DEFINE_UNPACK_TENSOR_FLOAT8(Float8E5M2FNUZ, TensorProto_DataType_FLOAT8E5M2FNUZ) DEFINE_UNPACK_TENSOR_INT4(Int4x2, TensorProto_DataType_INT4) DEFINE_UNPACK_TENSOR_INT4(UInt4x2, TensorProto_DataType_UINT4) +#define DEFINE_UNPACK_TENSOR_INT2(INT2_TYPE, ONNX_TYPE) \ + template <> \ + void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, \ + /*out*/ INT2_TYPE* p_data, size_t expected_num_elems) { \ + if (nullptr == p_data) { \ + const size_t size = raw_data != nullptr ? 
raw_data_len : tensor.int32_data_size(); \ + if (size == 0) { \ + return; \ + } \ + ORT_CXX_API_THROW("p_data == nullptr, but size != 0", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + if (ONNX_NAMESPACE::ONNX_TYPE != tensor.data_type()) { \ + ORT_CXX_API_THROW("TensorProto data type is not INT2", OrtErrorCode::ORT_INVALID_ARGUMENT); \ + } \ + \ + size_t expected_int2_quads = (expected_num_elems + 3) / 4; \ + \ + if (raw_data != nullptr) { \ + UnpackTensorWithRawData(raw_data, raw_data_len, expected_num_elems, p_data); \ + return; \ + } \ + \ + if (static_cast(tensor.int32_data_size()) != expected_int2_quads) { \ + ORT_CXX_API_THROW("UnpackTensor: the pre-allocated size does not match the size in proto", \ + OrtErrorCode::ORT_FAIL); \ + } \ + \ + for (int i = 0; i < static_cast(tensor.int32_data_size()); i++) { \ + p_data[i] = INT2_TYPE(static_cast(tensor.int32_data()[i])); \ + } \ + } + +DEFINE_UNPACK_TENSOR_INT2(Int2x4, TensorProto_DataType_INT2) +DEFINE_UNPACK_TENSOR_INT2(UInt2x4, TensorProto_DataType_UINT2) + #if !defined(DISABLE_FLOAT4_TYPES) template <> void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, @@ -426,6 +504,13 @@ void UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_dat } \ break; +#define CASE_PROTO_TRACE_INT2(X) \ + case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ + if (!CalcMemSizeForArrayWithAlignment((size + 3) / 4, 1, alignment, out)) { \ + ORT_CXX_API_THROW("Invalid TensorProto", OrtErrorCode::ORT_FAIL); \ + } \ + break; + #if !defined(DISABLE_FLOAT4_TYPES) #define CASE_PROTO_TRACE_FLOAT4(X) \ case ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_##X: \ @@ -474,6 +559,8 @@ Status GetSizeInBytesFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_p CASE_PROTO_TRACE(STRING, std::string); CASE_PROTO_TRACE_INT4(UINT4); CASE_PROTO_TRACE_INT4(INT4); + CASE_PROTO_TRACE_INT2(UINT2); + CASE_PROTO_TRACE_INT2(INT2); default: return 
Status(common::ONNXRUNTIME, common::NOT_IMPLEMENTED); } @@ -561,6 +648,8 @@ ONNXTensorElementDataType CApiElementTypeFromProtoType(int type) { CASE_TYPE(FLOAT4E2M1) CASE_TYPE(UINT4) CASE_TYPE(INT4) + CASE_TYPE(UINT2) + CASE_TYPE(INT2) default: return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; } @@ -630,6 +719,8 @@ Status TensorProtoToMLValue(const onnx::TensorProto& tensor_proto, const MemBuff #endif CASE_PROTO(INT4, Int4x2); CASE_PROTO(UINT4, UInt4x2); + CASE_PROTO(INT2, Int2x4); + CASE_PROTO(UINT2, UInt2x4); case onnx::TensorProto_DataType::TensorProto_DataType_STRING: if (preallocated != nullptr) { OrtStatus* status = OrtInitializeBufferForTensor(preallocated, preallocated_size, ele_type); @@ -767,6 +858,14 @@ Status MLValueToTensorProto(Ort::Value& value, onnx::TensorProto& tensor_proto) tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT4; tensor_elem_bytes = 1; break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT2: + tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT2; + tensor_elem_bytes = 1; + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT2: + tensor_proto_dtype = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT2; + tensor_elem_bytes = 1; + break; default: { // In this function, we do not support // ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING and ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 0afb836192b0a..4615b6a57b558 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -638,6 +638,37 @@ TEST_F(GraphTransformationTests, SkipLayerNormFusionTest) { TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format3_graph_output.onnx", 1, 1, 0, 0, logger_.get()); } +// SkipLayerNorm fusion should not be applied when gamma/beta have more than 1 dimension, +// because 
the SkipLayerNormalization kernel requires 1D gamma/beta. +TEST_F(GraphTransformationTests, SkipLayerNormFusion_3DGamma_NoFusion) { + auto build_test_case = [](ModelTestBuilder& builder) { + // Inputs: A and B are 3D [16, 32, 4] + auto* input_a = builder.MakeInput({16, 32, 4}, -1.0f, 1.0f); + auto* input_b = builder.MakeInput({16, 32, 4}, -1.0f, 1.0f); + // gamma and beta have 3D shape [1, 1, 4] (not 1D) + auto* gamma = builder.MakeInitializer({1, 1, 4}, {1.0f, 2.0f, 3.0f, 4.0f}); + auto* beta = builder.MakeInitializer({1, 1, 4}, {0.1f, 0.2f, 0.3f, 0.4f}); + auto* add_out = builder.MakeIntermediate(); + auto* ln_out = builder.MakeOutput(); + + builder.AddNode("Add", {input_a, input_b}, {add_out}); + builder.AddNode("LayerNormalization", {add_out, gamma, beta}, {ln_out}) + .AddAttribute("axis", static_cast(-1)); + }; + + auto post_graph_checker = [](Graph& graph) { + // SkipLayerNormalization should NOT have been created because gamma/beta are 3D. + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["Add"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(CountOpsInGraph(graph)["com.microsoft.SkipLayerNormalization"] == 0); + return Status::OK(); + }; + + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 17, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) { TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_simple.onnx", 1, 0, logger_.get()); TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_different_head_sizes.onnx", 0, 1, logger_.get()); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index 1de9e87170943..7fe68a38d23a0 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3221,6 +3221,37 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) { 
test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128 } +// Test skip removing node when min/max come from DequantizeLinear nodes instead of initializers. +TEST(QDQTransformerTests, ClipQuantFusion_MultipleInputEdges) { + auto build_test_case = [&](ModelTestBuilder& builder) { + // Clip's min coming from another DQ node (creating 2 input edges to Clip) + auto* input_arg = builder.MakeInput({1, 2, 2, 2}, std::numeric_limits::min(), + std::numeric_limits::max()); + auto* data_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input_arg, 0.04f, static_cast(0), data_dq); + auto* min_q = builder.MakeScalarInitializer(0); + auto* min_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(min_q, 0.04f, static_cast(0), min_dq); + auto* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {data_dq, min_dq}, {clip_output}); + auto* output_q = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(clip_output, 0.04f, static_cast(0), output_q); + auto* output_arg = builder.MakeOutput(); + builder.AddDequantizeLinearNode(output_q, 0.04f, static_cast(0), output_arg); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + // ClipQuantFusion should skip it due to CanRemoveNode check + EXPECT_EQ(op_to_count["Clip"], 1); + }; + + TransformerTester(build_test_case, check_graph, + TransformerLevel::Default, + TransformerLevel::Level2, + 18); // opset +} + template void TestWhereWithDqInput(bool is_dq_1, bool is_dq_2, diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 83d533d5185ca..38e4d52d9a2d2 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -116,6 +116,8 @@ ABSL_FLAG(std::string, i, "", " [QNN only] [enable_htp_spill_fill_buffer]: Enable HTP spill fill buffer, used while 
generating QNN context binary.\n" " [QNN only] [enable_htp_shared_memory_allocator]: Enable the QNN HTP shared memory allocator and use it for inputs and outputs. Requires libcdsprpc.so/dll to be available.\n" " Defaults to '0' (disabled).\n" + " [QNN only] [extended_udma]: Enable HTP extended UDMA mode for better performance on supported hardware, options: \n" + " '0' (disabled), '1' (enabled). Default: '0'. \n" " [Example] [For QNN EP] -e qnn -i \"backend_type|cpu\" \n" "\n" " [TensorRT only] [trt_max_partition_iterations]: Maximum iterations for TensorRT parser to get capability.\n" @@ -197,10 +199,13 @@ ABSL_FLAG(std::string, select_ep_devices, "", "Specifies a semicolon-separated l ABSL_FLAG(std::string, filter_ep_devices, "", "Specifies EP or Device metadata entries as key-value pairs to filter ep devices passed to AppendExecutionProvider_V2.\n" "[Usage]: --filter_ep_devices \"| |\" \n" - "Devices that match any of the key-value pair will be appended to the session. --select_ep_devices will take precedence over this option.\n"); + "Devices that match any of the key-value pair will be appended to the session. --select_ep_devices will take precedence over this option.\n" + "[Example] --filter_ep_devices \"ov_device|NPU ov_device|CPU\" \n" + "Above example will append npu device first if available, followed by cpu device."); ABSL_FLAG(bool, compile_ep_context, DefaultPerformanceTestConfig().run_config.compile_ep_context, "Generate an EP context model"); ABSL_FLAG(std::string, compile_model_path, "model_ctx.onnx", "The compiled model path for saving EP context model. 
Overwrites if already exists"); ABSL_FLAG(bool, compile_binary_embed, DefaultPerformanceTestConfig().run_config.compile_binary_embed, "Embed binary blob within EP context node"); +ABSL_FLAG(bool, compile_only, DefaultPerformanceTestConfig().run_config.compile_only, "Only compile EP context model without running it"); ABSL_FLAG(bool, h, false, "Print program usage."); namespace onnxruntime { @@ -581,6 +586,9 @@ bool CommandLineParser::ParseArguments(PerformanceTestConfig& test_config, int a // --compile_binary_embed test_config.run_config.compile_binary_embed = absl::GetFlag(FLAGS_compile_binary_embed); + // --compile_only + test_config.run_config.compile_only = absl::GetFlag(FLAGS_compile_only); + if (positional.size() == 2) { test_config.model_info.model_file_path = ToPathString(positional[1]); test_config.run_config.f_dump_statistics = true; diff --git a/onnxruntime/test/perftest/common_utils.cc b/onnxruntime/test/perftest/common_utils.cc index 5cc6c240e25f0..63882c64c2dfc 100644 --- a/onnxruntime/test/perftest/common_utils.cc +++ b/onnxruntime/test/perftest/common_utils.cc @@ -90,6 +90,116 @@ std::vector CStringsFromStrings(std::vector& utf8_args) { return utf8_argv; } +void AppendPluginExecutionProviders(Ort::Env& env, + Ort::SessionOptions& session_options, + const PerformanceTestConfig& test_config) { + if (test_config.registered_plugin_eps.empty()) { + return; + } + + std::vector ep_devices = env.GetEpDevices(); + // EP -> associated EP devices (All OrtEpDevice instances must be from the same execution provider) + std::unordered_map> added_ep_devices; + std::unordered_set added_ep_device_index_set; + + auto& ep_list = test_config.machine_config.plugin_provider_type_list; + std::unordered_set ep_set(ep_list.begin(), ep_list.end()); + + // Select EP devices by provided device index + if (!test_config.selected_ep_device_indices.empty()) { + std::vector device_list; + device_list.reserve(test_config.selected_ep_device_indices.size()); + 
ParseEpDeviceIndexList(test_config.selected_ep_device_indices, device_list); + for (auto index : device_list) { + if (static_cast(index) > (ep_devices.size() - 1)) { + fprintf(stderr, "%s", "The device index provided is not correct. Will skip this device id."); + continue; + } + + Ort::ConstEpDevice& device = ep_devices[index]; + if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { + if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) { + added_ep_devices[device.EpName()].push_back(device); + added_ep_device_index_set.insert(index); + fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast(index), device.EpName(), device.Device().Type()); + } + } else { + std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " + + test_config.machine_config.provider_type_name + ". Will skip adding this device.\n"; + fprintf(stderr, "%s", err_msg.c_str()); + } + } + } else if (!test_config.filter_ep_device_kv_pairs.empty()) { + // Find and select the OrtEpDevice associated with the EP in "--filter_ep_devices". 
+ for (const auto& kv : test_config.filter_ep_device_kv_pairs) { + for (size_t index = 0; index < ep_devices.size(); ++index) { + auto device = ep_devices[index]; + if (ep_set.find(std::string(device.EpName())) == ep_set.end()) + continue; + + // Skip if deviceid was already added + if (added_ep_devices.find(device.EpName()) != added_ep_devices.end() && + std::find(added_ep_devices[device.EpName()].begin(), added_ep_devices[device.EpName()].end(), device) != added_ep_devices[device.EpName()].end()) + continue; + + // Check both EP metadata and device metadata for a match + auto ep_metadata_kv_pairs = device.EpMetadata().GetKeyValuePairs(); + auto device_metadata_kv_pairs = device.Device().Metadata().GetKeyValuePairs(); + auto ep_metadata_itr = ep_metadata_kv_pairs.find(kv.first); + auto device_metadata_itr = device_metadata_kv_pairs.find(kv.first); + + if ((ep_metadata_itr != ep_metadata_kv_pairs.end() && kv.second == ep_metadata_itr->second) || + (device_metadata_itr != device_metadata_kv_pairs.end() && kv.second == device_metadata_itr->second)) { + added_ep_devices[device.EpName()].push_back(device); + fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast(index), device.EpName(), device.Device().Type()); + break; + } + } + } + } else { + // Find and select the OrtEpDevice associated with the EP in "--plugin_eps". 
+ for (size_t index = 0; index < ep_devices.size(); ++index) { + Ort::ConstEpDevice& device = ep_devices[index]; + if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { + added_ep_devices[device.EpName()].push_back(device); + fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s] has been added to session.\n", static_cast(index), device.EpName()); + } + } + } + + if (added_ep_devices.empty()) { + ORT_THROW("[ERROR] [Plugin EP]: No matching EP devices found."); + } + + std::string ep_option_string = ToUTF8String(test_config.run_config.ep_runtime_config_string); + + // EP's associated provider option lists + std::vector> ep_options_list; + ParseEpOptions(ep_option_string, ep_options_list); + + // If user only provide the EPs' provider option lists for the first several EPs, + // add empty provider option lists for the rest EPs. + if (ep_options_list.size() < ep_list.size()) { + for (size_t i = ep_options_list.size(); i < ep_list.size(); ++i) { + ep_options_list.emplace_back(); // Adds a new empty map + } + } else if (ep_options_list.size() > ep_list.size()) { + ORT_THROW("[ERROR] [Plugin EP]: Too many EP provider option lists provided."); + } + + // EP -> associated provider options + std::unordered_map> ep_options_map; + for (size_t i = 0; i < ep_list.size(); ++i) { + ep_options_map.emplace(ep_list[i], ep_options_list[i]); + } + + for (auto& ep_and_devices : added_ep_devices) { + auto& ep = ep_and_devices.first; + auto& devices = ep_and_devices.second; + session_options.AppendExecutionProvider_V2(env, devices, ep_options_map[ep]); + } +} + } // namespace utils } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/perftest/main.cc b/onnxruntime/test/perftest/main.cc index 6806d851958ea..512f217a77151 100644 --- a/onnxruntime/test/perftest/main.cc +++ b/onnxruntime/test/perftest/main.cc @@ -14,7 +14,7 @@ using namespace onnxruntime; const OrtApi* g_ort = NULL; int RunPerfTest(Ort::Env& env, const 
perftest::PerformanceTestConfig& test_config); -Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::PerformanceTestConfig& test_config); +Ort::Status CompileEpContextModel(Ort::Env& env, const perftest::PerformanceTestConfig& test_config); #ifdef _WIN32 int real_main(int argc, wchar_t* argv[]) { @@ -82,6 +82,12 @@ int real_main(int argc, char* argv[]) { return -1; } + std::cout << "Model compiled successfully to " << ToUTF8String(test_config.run_config.compile_model_path) << "\n"; + if (test_config.run_config.compile_only) { + return 0; + } + + std::cout << "\n> Running EP context model...\n"; { test_config.model_info.model_file_path = test_config.run_config.compile_model_path; status = RunPerfTest(env, test_config); @@ -134,14 +140,20 @@ int RunPerfTest(Ort::Env& env, const perftest::PerformanceTestConfig& test_confi return 0; } -Ort::Status CompileEpContextModel(const Ort::Env& env, const perftest::PerformanceTestConfig& test_config) { +Ort::Status CompileEpContextModel(Ort::Env& env, const perftest::PerformanceTestConfig& test_config) { auto output_ctx_model_path = test_config.run_config.compile_model_path; const auto provider_name = test_config.machine_config.provider_type_name; Ort::SessionOptions session_options; - std::unordered_map provider_options; - session_options.AppendExecutionProvider(provider_name, provider_options); + // Add EP devices if any (created by plugin EP) + if (!test_config.registered_plugin_eps.empty()) { + perftest::utils::AppendPluginExecutionProviders(env, session_options, test_config); + } else { + // Regular non-plugin EP + std::unordered_map provider_options; + session_options.AppendExecutionProvider(provider_name, provider_options); + } // free dim override if (!test_config.run_config.free_dim_name_overrides.empty()) { diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 3468e2e55c7b6..91f0581af0633 100644 --- 
a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -21,6 +21,7 @@ #include "providers.h" #include "TestCase.h" #include "strings_helper.h" +#include "utils.h" #ifdef USE_OPENVINO #include "nlohmann/json.hpp" @@ -90,102 +91,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device // Add EP devices if any (created by plugin EP) if (!performance_test_config.registered_plugin_eps.empty()) { - std::vector ep_devices = env.GetEpDevices(); - // EP -> associated EP devices (All OrtEpDevice instances must be from the same execution provider) - std::unordered_map> added_ep_devices; - std::unordered_set added_ep_device_index_set; - - auto& ep_list = performance_test_config.machine_config.plugin_provider_type_list; - std::unordered_set ep_set(ep_list.begin(), ep_list.end()); - - // Select EP devices by provided device index - if (!performance_test_config.selected_ep_device_indices.empty()) { - std::vector device_list; - device_list.reserve(performance_test_config.selected_ep_device_indices.size()); - ParseEpDeviceIndexList(performance_test_config.selected_ep_device_indices, device_list); - for (auto index : device_list) { - if (static_cast(index) > (ep_devices.size() - 1)) { - fprintf(stderr, "%s", "The device index provided is not correct. 
Will skip this device id."); - continue; - } - - Ort::ConstEpDevice& device = ep_devices[index]; - if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { - if (added_ep_device_index_set.find(index) == added_ep_device_index_set.end()) { - added_ep_devices[device.EpName()].push_back(device); - added_ep_device_index_set.insert(index); - fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast(index), device.EpName(), device.Device().Type()); - } - } else { - std::string err_msg = "[Plugin EP] [WARNING] : The EP device index and its corresponding OrtEpDevice is not created from " + - performance_test_config.machine_config.provider_type_name + ". Will skip adding this device.\n"; - fprintf(stderr, "%s", err_msg.c_str()); - } - } - } else if (!performance_test_config.filter_ep_device_kv_pairs.empty()) { - // Find and select the OrtEpDevice associated with the EP in "--filter_ep_devices". - for (size_t index = 0; index < ep_devices.size(); ++index) { - auto device = ep_devices[index]; - if (ep_set.find(std::string(device.EpName())) == ep_set.end()) - continue; - - // Check both EP metadata and device metadata for a match - auto ep_metadata_kv_pairs = device.EpMetadata().GetKeyValuePairs(); - auto device_metadata_kv_pairs = device.Device().Metadata().GetKeyValuePairs(); - for (const auto& kv : performance_test_config.filter_ep_device_kv_pairs) { - auto ep_metadata_itr = ep_metadata_kv_pairs.find(kv.first); - auto device_metadata_itr = device_metadata_kv_pairs.find(kv.first); - - if ((ep_metadata_itr != ep_metadata_kv_pairs.end() && kv.second == ep_metadata_itr->second) || - (device_metadata_itr != device_metadata_kv_pairs.end() && kv.second == device_metadata_itr->second)) { - added_ep_devices[device.EpName()].push_back(device); - fprintf(stdout, "[Plugin EP] EP Device [Index: %d, Name: %s, Type: %d] has been added to session.\n", static_cast(index), device.EpName(), device.Device().Type()); - break; - } - 
} - } - } else { - // Find and select the OrtEpDevice associated with the EP in "--plugin_eps". - for (size_t index = 0; index < ep_devices.size(); ++index) { - Ort::ConstEpDevice& device = ep_devices[index]; - if (ep_set.find(std::string(device.EpName())) != ep_set.end()) { - added_ep_devices[device.EpName()].push_back(device); - fprintf(stdout, "EP Device [Index: %d, Name: %s] has been added to session.\n", static_cast(index), device.EpName()); - } - } - } - - if (added_ep_devices.empty()) { - ORT_THROW("[ERROR] [Plugin EP]: No matching EP devices found."); - } - - std::string ep_option_string = ToUTF8String(performance_test_config.run_config.ep_runtime_config_string); - - // EP's associated provider option lists - std::vector> ep_options_list; - ParseEpOptions(ep_option_string, ep_options_list); - - // If user only provide the EPs' provider option lists for the first several EPs, - // add empty provider option lists for the rest EPs. - if (ep_options_list.size() < ep_list.size()) { - for (size_t i = ep_options_list.size(); i < ep_list.size(); ++i) { - ep_options_list.emplace_back(); // Adds a new empty map - } - } else if (ep_options_list.size() > ep_list.size()) { - ORT_THROW("[ERROR] [Plugin EP]: Too many EP provider option lists provided."); - } - - // EP -> associated provider options - std::unordered_map> ep_options_map; - for (size_t i = 0; i < ep_list.size(); ++i) { - ep_options_map.emplace(ep_list[i], ep_options_list[i]); - } - - for (auto& ep_and_devices : added_ep_devices) { - auto& ep = ep_and_devices.first; - auto& devices = ep_and_devices.second; - session_options.AppendExecutionProvider_V2(env, devices, ep_options_map[ep]); - } + perftest::utils::AppendPluginExecutionProviders(env, session_options, performance_test_config); } provider_name_ = performance_test_config.machine_config.provider_type_name; @@ -352,7 +258,8 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device "qnn_saver_path", 
"htp_graph_finalization_optimization_mode", "qnn_context_priority", "htp_arch", "enable_htp_fp16_precision", "offload_graph_io_quantization", "enable_htp_spill_fill_buffer", "enable_htp_shared_memory_allocator", "dump_json_qnn_graph", - "json_qnn_graph_dir"}); + "json_qnn_graph_dir", "disable_file_mapped_weights", "htp_bf16_enable", "enable_vtcm_backup_buffer_sharing", "extended_udma"}); + for (const auto& provider_option : provider_options) { const std::string& key = provider_option.first; const std::string& value = provider_option.second; @@ -404,7 +311,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } } else if (key == "htp_arch") { - std::set supported_htp_archs = {"0", "68", "69", "73", "75"}; + std::set supported_htp_archs = {"0", "68", "69", "73", "75", "81"}; if (supported_htp_archs.find(value) == supported_htp_archs.end()) { std::ostringstream str_stream; std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), @@ -416,7 +323,10 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device key == "offload_graph_io_quantization" || key == "enable_htp_spill_fill_buffer" || key == "enable_htp_shared_memory_allocator" || - key == "dump_json_qnn_graph") { + key == "dump_json_qnn_graph" || + key == "extended_udma" || + key == "disable_file_mapped_weights" || + key == "enable_vtcm_backup_buffer_sharing") { std::set supported_options = {"0", "1"}; if (supported_options.find(value) == supported_options.end()) { std::ostringstream str_stream; @@ -1016,14 +926,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); input_names_[i] = input_names_str_[i].c_str(); } - auto transform_fcn = std::function(); - auto new_value = std::function&, Ort::ConstTensorTypeAndShapeInfo&)>(); - if (device_memory_name_.empty()) { - transform_fcn = [](int64_t input) { return input; }; - new_value = [](OrtAllocator*, const std::vector&, Ort::ConstTensorTypeAndShapeInfo&) { - return Ort::Value(nullptr); - }; - } else { + if (!device_memory_name_.empty()) { Ort::MemoryInfo memory_info(nullptr); // Default initialize, will be overwritten if (device_memory_name_ == CUDA) { memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeDefault); @@ -1031,22 +934,20 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); memory_info = Ort::MemoryInfo(device_memory_name_.data(), OrtArenaAllocator, 0, OrtMemTypeCPUOutput); } custom_allocator_ = Ort::Allocator(session_, memory_info); - // Switch to custom + // Switch to custom allocator allocator_ = Ort::UnownedAllocator(custom_allocator_); - - // free dimensions are treated as 1 if not overridden - transform_fcn = [](int64_t input) { return (input == -1) ? 
-input : input; }; - new_value = [](OrtAllocator* allocator, const std::vector& output_shape, Ort::ConstTensorTypeAndShapeInfo& tensor_info) { - return Ort::Value::CreateTensor(allocator, output_shape.data(), output_shape.size(), tensor_info.GetElementType()); - }; } - for (size_t i = 0; i < output_names_raw_ptr.size(); i++) { Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i); auto tensor_info = type_info.GetTensorTypeAndShapeInfo(); std::vector output_shape = tensor_info.GetShape(); - std::transform(output_shape.begin(), output_shape.end(), output_shape.begin(), transform_fcn); - outputs_.emplace_back(new_value(allocator_, output_shape, tensor_info)); + auto is_dynamic = std::find(output_shape.begin(), output_shape.end(), -1) != output_shape.end(); + if (is_dynamic || device_memory_name_.empty()) { + outputs_.emplace_back(Ort::Value(nullptr)); + } else { + auto new_value = Ort::Value::CreateTensor(allocator_, output_shape.data(), output_shape.size(), tensor_info.GetElementType()); + outputs_.emplace_back(std::move(new_value)); + } } } diff --git a/onnxruntime/test/perftest/test_configuration.h b/onnxruntime/test/perftest/test_configuration.h index 1d8ad77096ef3..14b48aef9607f 100644 --- a/onnxruntime/test/perftest/test_configuration.h +++ b/onnxruntime/test/perftest/test_configuration.h @@ -77,6 +77,7 @@ struct RunConfig { bool compile_ep_context{false}; std::basic_string compile_model_path; bool compile_binary_embed{false}; + bool compile_only{false}; struct CudaMempoolArenaConfig { std::string release_threshold; std::string bytes_to_keep; diff --git a/onnxruntime/test/perftest/utils.h b/onnxruntime/test/perftest/utils.h index 9f180e2c8d942..c008273a3dc52 100644 --- a/onnxruntime/test/perftest/utils.h +++ b/onnxruntime/test/perftest/utils.h @@ -33,6 +33,10 @@ void UnregisterExecutionProviderLibrary(Ort::Env& env, PerformanceTestConfig& te void ListEpDevices(const Ort::Env& env); +void AppendPluginExecutionProviders(Ort::Env& env, + Ort::SessionOptions& 
session_options, + const PerformanceTestConfig& test_config); + } // namespace utils } // namespace perftest } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc index 10affa538dfad..f968fc6fc2f2e 100644 --- a/onnxruntime/test/providers/cpu/controlflow/loop_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/loop_test.cc @@ -1037,7 +1037,8 @@ TEST(Loop, IterationCountAsOutput) { test.AddOutput("loop_var_0_final", {3, 1}, {0, 1, 2}); // Disable TensorRT on unsupported data type BOOL - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + // Disable OV EP due to ONNX partition create new domain and OV FE can't handle it + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); } #if defined(USE_CUDA) diff --git a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc index c7fc73456dcba..671ada7d36383 100644 --- a/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc +++ b/onnxruntime/test/providers/cpu/ml/array_feature_extractor_test.cc @@ -109,5 +109,13 @@ TEST_F(ArrayFeatureExtractorTest, InvalidInputOutOfBoundsY) { test_.Run(OpTester::ExpectResult::kExpectFailure); } +TEST_F(ArrayFeatureExtractorTest, InvalidInputNegativeY) { + test_.AddInput("X", {10, 1}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + test_.AddInput("Y", {1}, {-10}); + // Should fail due to negative index -10 + test_.AddOutput("Z", {0}, {}); + test_.Run(OpTester::ExpectResult::kExpectFailure); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc index 289e94397fb39..67bb5d780ad2d 100644 --- a/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/cast_op_test.cc @@ 
-1,4 +1,4 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include @@ -75,6 +75,11 @@ void TestCastOp(gsl::span input, excluded_provider_types.insert(kCudaExecutionProvider); } + if (input.size() == 0) { + // The OpenVINO doesn't support 0 size input + excluded_provider_types.insert(kOpenVINOExecutionProvider); + } + if (cuda_only && (excluded_provider_types.count(kCudaExecutionProvider) > 0)) { return; } @@ -89,6 +94,20 @@ void TestCastOp(gsl::span input, test.Run(expect_result, expected_failure_string, excluded_provider_types); } +// INT2 types were introduced in opset 25 (IR13) +constexpr int kInt2Opset = 25; + +// Helper for INT2 cast tests that uses opset 25 by default +template +void TestCastOpInt2(gsl::span input, + gsl::span output, + const BaseTester::DimsVariant& dimensions, + OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess, + const std::string& expected_failure_string = "", + Saturate saturate = Saturate::None) { + TestCastOp(input, output, dimensions, expect_result, expected_failure_string, kInt2Opset, saturate); +} + template using RequiresCastThroughFloat = boost::mp11::mp_any< @@ -1249,215 +1268,1231 @@ TEST(CastOpTest, FloatStringToInt4x2) { TestCastOp(gsl::span(string_input), gsl::span(expected_int4x2_output), shape); } -#if !defined(DISABLE_FLOAT8_TYPES) +TEST(CastOpTest, Int2x4ToInt8) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), // boundary and zero values + Int2x4(1, -2, -1, 0) // mixed values + }; -template -void CastOpTestFloat8(Saturate saturate) { - ASSERT_NE(saturate, Saturate::None); + const std::vector expected_int8_output = {-2, 1, 0, -1, 1, -2, -1, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, Int2x4ToUInt8) { + // GIVEN const std::vector 
shape{2, 2, 2}; - const std::vector float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, -0.120196559f, 5.0f, - -std::numeric_limits::infinity(), - std::numeric_limits::infinity()}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; - // float output precision is 8, so the expected output differs slightly from the input due to that - std::vector output; - output.reserve(float_input.size()); - for (size_t i = 0; i < float_input.size(); ++i) { - output.emplace_back(F8(float_input[i], saturate == Saturate::True)); - } - TestCastOp(gsl::make_span(float_input), gsl::make_span(output), shape, OpTester::ExpectResult::kExpectSuccess, "", 19, saturate); + // Negative values will be cast to their unsigned representation + const std::vector expected_uint8_output = {254, 1, 0, 255, 1, 254, 255, 0}; - const std::vector float16_input = - CastedValues(gsl::make_span(float_input)); + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint8_output), shape); +} - TestCastOp(gsl::make_span(float16_input), gsl::make_span(output), shape, OpTester::ExpectResult::kExpectSuccess, "", 19, saturate); +TEST(CastOpTest, Int2x4ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + const std::vector expected_int16_output = {-2, 1, 0, -1, 1, -2, -1, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_int16_output), shape); } -TEST(CastOpTest, ToFloat8E4M3FN) { - constexpr int min_cuda_architecture = 11080; - bool enable_cuda = (nullptr != DefaultCudaExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); - bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); +TEST(CastOpTest, Int2x4ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; - if (enable_cpu || enable_cuda) { - 
CastOpTestFloat8(Saturate::True); - CastOpTestFloat8(Saturate::False); - } + const std::vector expected_int32_output = {-2, 1, 0, -1, 1, -2, -1, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_int32_output), shape); } -TEST(CastOpTest, ToFloat8E4M3FNUZ) { - bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); - if (enable_cpu) { - CastOpTestFloat8(Saturate::True); - CastOpTestFloat8(Saturate::False); - } +TEST(CastOpTest, Int2x4ToInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + const std::vector expected_int64_output = {-2, 1, 0, -1, 1, -2, -1, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_int64_output), shape); } -TEST(CastOpTest, ToFloat8E5M2) { - constexpr int min_cuda_architecture = 11080; - bool enable_cuda = (nullptr != DefaultCudaExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); - bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); +TEST(CastOpTest, Int2x4ToFloat) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; - if (enable_cpu || enable_cuda) { - CastOpTestFloat8(Saturate::True); - CastOpTestFloat8(Saturate::False); - } + const std::vector expected_float_output = {-2.0f, 1.0f, 0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_float_output), shape); } -TEST(CastOpTest, ToFloat8E5M2FNUZ) { - bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); - if (enable_cpu) { - CastOpTestFloat8(Saturate::True); - CastOpTestFloat8(Saturate::False); - } +TEST(CastOpTest, Int2x4ToDouble) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + const std::vector 
expected_double_output = {-2.0, 1.0, 0.0, -1.0, 1.0, -2.0, -1.0, 0.0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_double_output), shape); } -TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { +TEST(CastOpTest, Int2x4ToBool) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int4x2_input = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + const std::vector int2x4_input = { + Int2x4(0, -1, 1, 0), + Int2x4(-2, 0, 1, -1)}; - std::vector expected_float8_output; - expected_float8_output.reserve(8); - const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; - for (float val : float_values) { - expected_float8_output.emplace_back(Float8E4M3FN(val, true)); - } + const bool bool_output[] = {false, true, true, false, true, false, true, true}; + const gsl::span expected_bool_output_span(bool_output); // WHEN, THEN - // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 - TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape); - // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 - TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape, - OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); + TestCastOpInt2(gsl::make_span(int2x4_input), expected_bool_output_span, shape); } -TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { +TEST(CastOpTest, Int2x4ToMLFloat16) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint4x2_input = { - UInt4x2(0, 15), - UInt4x2(1, 14), - UInt4x2(7, 8), - UInt4x2(3, 12)}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; - std::vector expected_uint_float8_output; - expected_uint_float8_output.reserve(8); - const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; - for (float val : uint_float_values) { - 
expected_uint_float8_output.emplace_back(Float8E4M3FN(val, true)); - } + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{-2.0f, 1.0f, 0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f})); // WHEN, THEN - // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 - TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape); - // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 - TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape, - OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_float16_output), shape); } -TEST(CastOpTest, Int4x2ToFloat8E5M2) { +TEST(CastOpTest, Int2x4ToString) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector int4x2_input = { - Int4x2(-8, 7), - Int4x2(0, -1), - Int4x2(3, -5), - Int4x2(6, 2)}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; - std::vector expected_float8e5m2_output; - expected_float8e5m2_output.reserve(8); - const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; - for (float val : float_values) { - expected_float8e5m2_output.emplace_back(Float8E5M2(val, true)); - } + const std::vector expected_output = { + "-2", "1", "0", "-1", + "1", "-2", "-1", "0"}; // WHEN, THEN - // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 - TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape); - // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 - TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape, - OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); + TestCastOpInt2(gsl::span(int2x4_input), 
gsl::span(expected_output), shape); } -TEST(CastOpTest, UInt4x2ToFloat8E5M2) { +TEST(CastOpTest, Int2x4ToInt32OddNumberOfElements) { + // GIVEN - Test with 5 elements (not a multiple of 4) + const std::vector odd_shape{5}; + const std::vector odd_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, 0, 0, 0), // last 3 values are padding + }; + + const std::vector expected_odd_output = {-2, 1, 0, -1, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + +TEST(CastOpTest, UInt2x4ToUInt8) { // GIVEN const std::vector shape{2, 2, 2}; - const std::vector uint4x2_input = { - UInt4x2(0, 15), - UInt4x2(1, 14), - UInt4x2(7, 8), - UInt4x2(3, 12)}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), // boundary and mid values + UInt2x4(3, 0, 2, 1) // reversed order + }; - std::vector expected_uint_float8e5m2_output; - expected_uint_float8e5m2_output.reserve(8); - const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; - for (float val : uint_float_values) { - expected_uint_float8e5m2_output.emplace_back(Float8E5M2(val, true)); - } + const std::vector expected_uint8_output = {0, 3, 1, 2, 3, 0, 2, 1}; // WHEN, THEN - // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 - TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); - // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 - TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape, - OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_uint8_output), shape); } -TEST(CastOpTest, Float8E4M3FNToInt4x2) { +TEST(CastOpTest, UInt2x4ToInt8) { // GIVEN const std::vector shape{2, 2, 2}; - std::vector float8_input; - const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 
3.0f, -5.0f, 6.0f, 2.0f}; - for (float val : input_values) { - float8_input.emplace_back(Float8E4M3FN(val, true)); - } + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; - const std::vector expected_int4x2_output = { + const std::vector expected_int8_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int8_output), shape); +} + +TEST(CastOpTest, UInt2x4ToInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_int32_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int32_output), shape); +} + +TEST(CastOpTest, UInt2x4ToFloat) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_float_output = {0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_float_output), shape); +} + +TEST(CastOpTest, UInt2x4ToBool) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 1, 2, 0), + UInt2x4(3, 0, 0, 1)}; + + const bool bool_output[] = {false, true, true, false, true, false, false, true}; + const gsl::span expected_bool_output_span(bool_output); + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), expected_bool_output_span, shape); +} + +TEST(CastOpTest, UInt2x4ToString) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_output = { + "0", "3", "1", "2", + "3", "0", "2", "1"}; + + // WHEN, THEN + TestCastOpInt2(gsl::span(uint2x4_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, Int2x4ToUInt2x4) { + // GIVEN 
+ const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Reinterpret: -2 becomes 2, -1 becomes 3 (mask to 2 bits) + const std::vector expected_uint2x4_output = { + UInt2x4(2, 1, 0, 3), + UInt2x4(1, 2, 3, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, UInt2x4ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + // Sign-extend: 2 becomes -2, 3 becomes -1 (values >= 2 are negative in 2-bit signed) + const std::vector expected_int2x4_output = { + Int2x4(0, -1, 1, -2), + Int2x4(-1, 0, -2, 1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, Int4x2ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { Int4x2(-8, 7), Int4x2(0, -1), Int4x2(3, -5), Int4x2(6, 2)}; + // Truncate to 2 bits and sign-extend: -8 -> 0, 7 -> -1, 0 -> 0, -1 -> -1, 3 -> -1, -5 -> -1, 6 -> -2, 2 -> -2 + const std::vector expected_int2x4_output = { + Int2x4(0, -1, 0, -1), + Int2x4(-1, -1, -2, -2)}; + // WHEN, THEN - // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, - // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. 
- TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); + TestCastOpInt2(gsl::make_span(int4x2_input), gsl::make_span(expected_int2x4_output), shape); } -TEST(CastOpTest, Float8E4M3FNToInt4x2_OddShape) { +TEST(CastOpTest, Int4x2ToUInt2x4) { // GIVEN - const std::vector shape{1, 2, 3}; - std::vector float8_input; - const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f}; - for (float val : input_values) { - float8_input.emplace_back(Float8E4M3FN(val, true)); - } - - const std::vector expected_int4x2_output = { + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { Int4x2(-8, 7), Int4x2(0, -1), - Int4x2(3, -5)}; + Int4x2(3, -5), + Int4x2(6, 2)}; + + // Truncate to 2 bits: -8 -> 0, 7 -> 3, 0 -> 0, -1 -> 3, 3 -> 3, -5 -> 3, 6 -> 2, 2 -> 2 + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 0, 3), + UInt2x4(3, 3, 2, 2)}; // WHEN, THEN - // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, - // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. 
- TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); + TestCastOpInt2(gsl::make_span(int4x2_input), gsl::make_span(expected_uint2x4_output), shape); } -TEST(CastOpTest, Float8E4M3FNToUInt4x2) { +TEST(CastOpTest, UInt4x2ToInt2x4) { // GIVEN const std::vector shape{2, 2, 2}; - std::vector uint_float8_input; - const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; - for (float val : uint_input_values) { - uint_float8_input.emplace_back(Float8E4M3FN(val, true)); - } - - const std::vector expected_uint4x2_output = { + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + // Truncate to 2 bits and sign-extend: 0 -> 0, 15 -> -1, 1 -> 1, 14 -> -2, 7 -> -1, 8 -> 0, 3 -> -1, 12 -> 0 + const std::vector expected_int2x4_output = { + Int2x4(0, -1, 1, -2), + Int2x4(-1, 0, -1, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint4x2_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, UInt4x2ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + // Truncate to 2 bits: 0 -> 0, 15 -> 3, 1 -> 1, 14 -> 2, 7 -> 3, 8 -> 0, 3 -> 3, 12 -> 0 + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 3, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, Int2x4ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Values fit directly: -2 -> -2, 1 -> 1, 0 -> 0, -1 -> -1 + const std::vector expected_int4x2_output = { + Int4x2(-2, 1), + Int4x2(0, -1), + Int4x2(1, -2), + Int4x2(-1, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_int4x2_output), shape); +} + 
+TEST(CastOpTest, Int2x4ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Mask to 4 bits: -2 -> 14, 1 -> 1, 0 -> 0, -1 -> 15 + const std::vector expected_uint4x2_output = { + UInt4x2(14, 1), + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(15, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, UInt2x4ToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + // Values fit directly: 0-3 all fit in int4 + const std::vector expected_int4x2_output = { + Int4x2(0, 3), + Int4x2(1, 2), + Int4x2(3, 0), + Int4x2(2, 1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, UInt2x4ToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + // Values fit directly: 0-3 all fit in uint4 + const std::vector expected_uint4x2_output = { + UInt4x2(0, 3), + UInt4x2(1, 2), + UInt4x2(3, 0), + UInt4x2(2, 1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_uint4x2_output), shape); +} + +TEST(CastOpTest, Int8ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int8_input = {-10, 15, 0, -1, 7, -8, -128, 127}; + + // Truncate to 2 bits and sign-extend + // -10 = 0xF6, truncate to 0x02 = 2, sign-extend to -2 + // 15 = 0x0F, truncate to 0x03 = 3, sign-extend to -1 + // 0 = 0x00, truncate to 0x00 = 0 + // -1 = 0xFF, truncate to 0x03 = 3, sign-extend to -1 + // 7 = 0x07, truncate to 0x03 = 3, sign-extend to -1 + // -8 = 0xF8, truncate to 0x00 = 0 + // -128 = 0x80, truncate to 0x00 = 0 + // 127 = 0x7F, truncate to 0x03 = 3, sign-extend to -1 + const std::vector expected_int2x4_output 
= { + Int2x4(-2, -1, 0, -1), + Int2x4(-1, 0, 0, -1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int8_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, UInt8ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint8_input = {20, 255, 0, 17, 7, 240, 15, 31}; + + // values get truncated to lower 2 bits + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 0, 1), // 20 (0x14) truncate to 0, 255 (0xFF) truncate to 3, etc. + UInt2x4(3, 0, 3, 3)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint8_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, Int32ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int32_input = {-10, INT32_MAX, 0, INT32_MIN, 3, -5, 4080, 287}; + + // Truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, -1, 0, 0), // -10 -> -2, INT32_MAX -> -1, 0 -> 0, INT32_MIN -> 0 + Int2x4(-1, -1, 0, -1) // 3 -> -1, -5 -> -1, 4080 -> 0, 287 -> -1 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int32_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, Int32ToInt2x4OddNumberOfElements) { + // GIVEN + const std::vector odd_shape{5}; + const std::vector odd_input = {-10, INT32_MAX, 0, INT32_MIN, 3}; + + // Truncate to 2 bits and sign-extend; INT2 packs 4 per byte + const std::vector expected_odd_output = { + Int2x4(-2, -1, 0, 0), // -10 -> -2, INT32_MAX -> -1, 0 -> 0, INT32_MIN -> 0 + Int2x4(-1, 0, 0, 0) // 3 -> -1, padded with 0 + }; + + // WHEN, THEN + TestCastOp(gsl::make_span(odd_input), gsl::make_span(expected_odd_output), odd_shape); +} + +TEST(CastOpTest, UInt32ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint32_input = {20, UINT32_MAX, 0, 256, 7, 240, 15, 4095}; + + // Truncate to 2 bits (no sign extension for unsigned) + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 0, 0), // 20 -> 0, UINT32_MAX -> 3, 
0 -> 0, 256 -> 0 + UInt2x4(3, 0, 3, 3) // 7 -> 3, 240 -> 0, 15 -> 3, 4095 -> 3 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint32_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, FloatToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float_input = {-2.3f, 1.7f, 0.4f, -1.6f, 3.0f, -5.2f, 240.1f, 31.9f}; + + // Round then truncate to 2 bits and sign-extend + // -2.3 rounds to -2 -> -2 + // 1.7 rounds to 2 -> -2 (truncate and sign-extend) + // 0.4 rounds to 0 -> 0 + // -1.6 rounds to -2 -> -2 + // 3.0 -> 3, truncate to -1 + // -5.2 rounds to -5 -> -1 (truncate 0x03) + // 240.1 rounds to 240 -> 0 (truncate) + // 31.9 rounds to 32 -> 0 (truncate) + const std::vector expected_int2x4_output = { + Int2x4(-2, -2, 0, -2), + Int2x4(-1, -1, 0, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, FloatToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float_input = {0.4f, 3.7f, 1.0f, 2.5f, 4.0f, -1.0f, 15.1f, 31.9f}; + + // Round then truncate to 2 bits (round-half-to-even rounding) + const std::vector expected_uint2x4_output = { + UInt2x4(0, 0, 1, 3), // 0.4->0, 3.7->4->0, 1.0->1, 2.5->2 (rounds to even)->3 truncated + UInt2x4(0, 3, 3, 0) // 4.0->4->0, -1->-1->3, 15.1->15->3, 31.9->32->0 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, BoolToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_int2x4_output = { + Int2x4(0, 1, 1, 0), + Int2x4(0, 1, 1, 1)}; + + // WHEN, THEN + TestCastOpInt2(bool_input_span, gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, BoolToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + 
const bool bool_input[] = {false, true, true, false, false, true, true, true}; + const gsl::span bool_input_span(bool_input); + + const std::vector expected_uint2x4_output = { + UInt2x4(0, 1, 1, 0), + UInt2x4(0, 1, 1, 1)}; + + // WHEN, THEN + TestCastOpInt2(bool_input_span, gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, StringToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "-2", "1", "0", "-1", + "1", "-2", "-1", "0"}; + + const std::vector expected_output{ + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, StringToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "0", "3", "1", "2", + "3", "0", "2", "1"}; + + const std::vector expected_output{ + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, StringToUInt2x4BoundaryValues) { + // GIVEN + const std::vector shape{2, 2}; + const std::vector string_input = { + "-5", "20", // out of range values that get truncated + "0", "3" // boundary values that are in range + }; + + // Values get truncated to lower 2 bits (no sign extension for unsigned) + const std::vector expected_output{ + UInt2x4(3, 0, 0, 3) // -5 -> 3, 20 -> 0, 0 -> 0, 3 -> 3 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::span(string_input), gsl::span(expected_output), shape); +} + +TEST(CastOpTest, FloatStringToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector string_input = { + "-2.7", "1.3", + "0.4", "-1.6", + "3.8", "-5.2", + "15.0", "-2"}; + + // Round then truncate to 2 bits and sign-extend + // -2.7 rounds to -3, -3 & 0x3 = 1, sign-extended = 1 + // 1.3 rounds to 1, 0.4 rounds to 0, -1.6 rounds to -2 + // 3.8 rounds to 4 -> 0, -5.2 rounds to -5 -> -1, 15.0 -> -1, -2 -> -2 + const 
std::vector expected_int2x4_output = { + Int2x4(1, 1, 0, -2), // -2.7 -> -3 -> 1 (truncate & sign-extend), 1.3 -> 1, 0.4 -> 0, -1.6 -> -2 + Int2x4(0, -1, -1, -2) // 3.8 -> 4 -> 0, -5.2 -> -5 -> -1, 15.0 -> -1, -2 -> -2 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::span(string_input), gsl::span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, Int2x4ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint16_output = {65534, 1, 0, UINT16_MAX, 1, 65534, UINT16_MAX, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, Int2x4ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint32_output = {4294967294, 1, 0, UINT32_MAX, 1, 4294967294, UINT32_MAX, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, Int2x4ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // Negative values will be cast to their unsigned representation + const std::vector expected_uint64_output = {18446744073709551614ULL, 1, 0, UINT64_MAX, + 1, 18446744073709551614ULL, UINT64_MAX, 0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_uint64_output), shape); +} + +TEST(CastOpTest, UInt2x4ToUInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_uint16_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN 
+ TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_uint16_output), shape); +} + +TEST(CastOpTest, UInt2x4ToInt16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_int16_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int16_output), shape); +} + +TEST(CastOpTest, UInt2x4ToUInt32) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_uint32_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_uint32_output), shape); +} + +TEST(CastOpTest, UInt2x4ToUInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_uint64_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_uint64_output), shape); +} + +TEST(CastOpTest, UInt2x4ToInt64) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_int64_output = {0, 3, 1, 2, 3, 0, 2, 1}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_int64_output), shape); +} + +TEST(CastOpTest, UInt2x4ToDouble) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_double_output = {0.0, 3.0, 1.0, 2.0, 3.0, 0.0, 2.0, 1.0}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_double_output), shape); +} + +TEST(CastOpTest, UInt2x4ToMLFloat16) { + // GIVEN + const std::vector shape{2, 2, 
2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_float16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f})); + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_float16_output), shape); +} + +TEST(CastOpTest, Int2x4ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{-2.0f, 1.0f, 0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f})); + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, UInt2x4ToBFloat16) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + const std::vector expected_bfloat16_output = + CastedValues( + gsl::make_span( + std::vector{0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f})); + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_bfloat16_output), shape); +} + +TEST(CastOpTest, Int16ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int16_input = {-10, INT16_MAX, 0, INT16_MIN, 3, -5, 4080, 287}; + + // Truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, -1, 0, 0), + Int2x4(-1, -1, 0, -1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int16_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, UInt16ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint16_input = {20, UINT16_MAX, 0, 17, 7, 240, 15, 31}; + + // Truncate to 2 bits + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 0, 1), + UInt2x4(3, 0, 3, 3)}; + + // WHEN, THEN + 
TestCastOpInt2(gsl::make_span(uint16_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, Int64ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int64_input = {-10, INT64_MAX, 0, INT64_MIN, 3, -5, 4080, 287}; + + // Truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, -1, 0, 0), + Int2x4(-1, -1, 0, -1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int64_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, UInt64ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint64_input = {20, UINT64_MAX, 0, 17, 7, 240, 15, 31}; + + // Truncate to 2 bits + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 0, 1), + UInt2x4(3, 0, 3, 3)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint64_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, DoubleToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector double_input = {-2.3, 1.7, 0.4, -1.6, 3.0, -5.2, 240.1, 31.9}; + + // Round then truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, -2, 0, -2), + Int2x4(-1, -1, 0, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(double_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, DoubleToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector double_input = {0.4, 3.7, 1.0, 2.5, 4.0, -1.0, 15.1, 31.9}; + + // Round then truncate to 2 bits (round-half-to-even rounding) + const std::vector expected_uint2x4_output = { + UInt2x4(0, 0, 1, 3), // 2.5 rounds to 2 (even), truncated to 3 bits -> becomes 3 after truncation + UInt2x4(0, 3, 3, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(double_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, MLFloat16ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float16_input = 
+ CastedValues( + gsl::make_span( + std::vector{-2.0f, 1.0f, 0.0f, -1.0f, 3.0f, -5.0f, 15.0f, 31.0f})); + + // Truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, 1, 0, -1), + Int2x4(-1, -1, -1, -1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float16_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, MLFloat16ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector float16_input = + CastedValues( + gsl::make_span( + std::vector{0.0f, 3.0f, 1.0f, 2.0f, 4.0f, 15.0f, 7.0f, 31.0f})); + + // Truncate to 2 bits + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 1, 2), + UInt2x4(0, 3, 3, 3)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float16_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, BFloat16ToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector bfloat16_input = + CastedValues( + gsl::make_span( + std::vector{-2.0f, 1.0f, 0.0f, -1.0f, 3.0f, -5.0f, 15.0f, 31.0f})); + + // Truncate to 2 bits and sign-extend + const std::vector expected_int2x4_output = { + Int2x4(-2, 1, 0, -1), + Int2x4(-1, -1, -1, -1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(bfloat16_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, BFloat16ToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector bfloat16_input = + CastedValues( + gsl::make_span( + std::vector{0.0f, 3.0f, 1.0f, 2.0f, 4.0f, 15.0f, 7.0f, 31.0f})); + + // Truncate to 2 bits + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 1, 2), + UInt2x4(0, 3, 3, 3)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(bfloat16_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, MLFloat16ToInt2x4BoundaryValues) { + // GIVEN + const std::vector shape{2, 2}; + const MLFloat16 mlfloat16_array[4] = { + MLFloat16(static_cast(-5)), // Truncated to lower 2 bits + 
MLFloat16(static_cast(4)), // Truncated to lower 2 bits + MLFloat16(static_cast(-0.6f)), // Should round to -1 + MLFloat16(static_cast(1.7f)) // Should round to 2 -> -2 (truncated) + }; + + // Values get truncated to lower 2 bits and sign-extended + const std::vector expected_int2x4 = { + Int2x4(-1, 0, -1, -2) // -5 -> -1, 4 -> 0, -0.6 -> -1, 1.7 -> 2 -> -2 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::span(mlfloat16_array, 4), gsl::span(expected_int2x4), + shape); +} + +TEST(CastOpTest, MLFloat16ToUInt2x4BoundaryValues) { + // GIVEN + const std::vector shape{2, 2}; + const MLFloat16 mlfloat16_array[4] = { + MLFloat16(static_cast(-5)), // Negative, truncated to lower 2 bits + MLFloat16(static_cast(20)), // Above max, truncated to lower 2 bits + MLFloat16(static_cast(3.4f)), // Should round to 3 + MLFloat16(static_cast(5.7f)) // Should round to 6 -> 2 (truncated) + }; + + // Values get truncated to lower 2 bits (no sign extension for unsigned) + const std::vector expected_uint2x4 = { + UInt2x4(3, 0, 3, 2) // -5 -> 3, 20 -> 0, 3.4 -> 3, 5.7 -> 6 -> 2 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::span(mlfloat16_array, 4), gsl::span(expected_uint2x4), + shape); +} + +TEST(CastOpTest, BFloat16ToUInt2x4BoundaryValues) { + // GIVEN + const std::vector shape{2, 2}; + const BFloat16 bfloat16_array[4] = { + BFloat16(static_cast(-5)), // Negative, truncated to lower 2 bits + BFloat16(static_cast(20)), // Above max, truncated to lower 2 bits + BFloat16(static_cast(3.4f)), // Should round to 3 + BFloat16(static_cast(5.7f)) // Should round to 6 -> 2 (truncated) + }; + + // Values get truncated to lower 2 bits (no sign extension for unsigned) + const std::vector expected_uint2x4 = { + UInt2x4(3, 0, 3, 2) // -5 -> 3, 20 -> 0, 3.4 -> 3, 5.7 -> 6 -> 2 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::span(bfloat16_array, 4), gsl::span(expected_uint2x4), + shape); +} + +TEST(CastOpTest, Int32ToInt2x4EmptyTensor) { + // GIVEN + const std::vector empty_shape{0}; + const std::vector 
empty_input{}; + const std::vector expected_empty_output{}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(empty_input), gsl::make_span(expected_empty_output), empty_shape); +} + +#if !defined(DISABLE_FLOAT8_TYPES) + +template +void CastOpTestFloat8(Saturate saturate) { + ASSERT_NE(saturate, Saturate::None); + const std::vector shape{2, 2, 2}; + const std::vector float_input = {NAN, -1.f, 0.0391877927f, 0.296140194f, -0.120196559f, 5.0f, + -std::numeric_limits::infinity(), + std::numeric_limits::infinity()}; + + // float output precision is 8, so the expected output differs slightly from the input due to that + std::vector output; + output.reserve(float_input.size()); + for (size_t i = 0; i < float_input.size(); ++i) { + output.emplace_back(F8(float_input[i], saturate == Saturate::True)); + } + TestCastOp(gsl::make_span(float_input), gsl::make_span(output), shape, OpTester::ExpectResult::kExpectSuccess, "", 19, saturate); + + const std::vector float16_input = + CastedValues(gsl::make_span(float_input)); + + TestCastOp(gsl::make_span(float16_input), gsl::make_span(output), shape, OpTester::ExpectResult::kExpectSuccess, "", 19, saturate); +} + +TEST(CastOpTest, ToFloat8E4M3FN) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCudaExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + CastOpTestFloat8(Saturate::True); + CastOpTestFloat8(Saturate::False); + } +} + +TEST(CastOpTest, ToFloat8E4M3FNUZ) { + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + if (enable_cpu) { + CastOpTestFloat8(Saturate::True); + CastOpTestFloat8(Saturate::False); + } +} + +TEST(CastOpTest, ToFloat8E5M2) { + constexpr int min_cuda_architecture = 11080; + bool enable_cuda = (nullptr != DefaultCudaExecutionProvider().get()) && HasCudaEnvironment(min_cuda_architecture); + bool enable_cpu = (nullptr != 
DefaultCpuExecutionProvider().get()); + + if (enable_cpu || enable_cuda) { + CastOpTestFloat8(Saturate::True); + CastOpTestFloat8(Saturate::False); + } +} + +TEST(CastOpTest, ToFloat8E5M2FNUZ) { + bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()); + if (enable_cpu) { + CastOpTestFloat8(Saturate::True); + CastOpTestFloat8(Saturate::False); + } +} + +TEST(CastOpTest, Int4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + std::vector expected_float8_output; + expected_float8_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, UInt4x2ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8_output; + expected_uint_float8_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + expected_uint_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(uint4x2_input), 
gsl::make_span(expected_uint_float8_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, Int4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int4x2_input = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + std::vector expected_float8e5m2_output; + expected_float8e5m2_output.reserve(8); + const std::vector float_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : float_values) { + expected_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape); + // Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(int4x2_input), gsl::make_span(expected_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, UInt4x2ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint4x2_input = { + UInt4x2(0, 15), + UInt4x2(1, 14), + UInt4x2(7, 8), + UInt4x2(3, 12)}; + + std::vector expected_uint_float8e5m2_output; + expected_uint_float8e5m2_output.reserve(8); + const std::vector uint_float_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_float_values) { + expected_uint_float8e5m2_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + // Test with Saturate::None, which means the 'saturate_' bool inside the 'Cast' class defaults to 1 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape); + // 
Test with Saturate::False, which means the 'saturate_' bool inside the 'Cast' class will be 0 + TestCastOp(gsl::make_span(uint4x2_input), gsl::make_span(expected_uint_float8e5m2_output), shape, + OpTester::ExpectResult::kExpectSuccess, "", 21, Saturate::False); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f, 6.0f, 2.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5), + Int4x2(6, 2)}; + + // WHEN, THEN + // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, + // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. + TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt4x2_OddShape) { + // GIVEN + const std::vector shape{1, 2, 3}; + std::vector float8_input; + const std::vector input_values = {-8.0f, 7.0f, 0.0f, -1.0f, 3.0f, -5.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int4x2_output = { + Int4x2(-8, 7), + Int4x2(0, -1), + Int4x2(3, -5)}; + + // WHEN, THEN + // The 'saturate_' bool inside the 'Cast' class can only be false if the conversion is to a float 8 type, + // so it's sufficient to test with the default saturate = 1 here, since we are not converting to float 8. 
+ TestCastOp(gsl::make_span(float8_input), gsl::make_span(expected_int4x2_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToUInt4x2) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector uint_float8_input; + const std::vector uint_input_values = {0.0f, 15.0f, 1.0f, 14.0f, 7.0f, 8.0f, 3.0f, 12.0f}; + for (float val : uint_input_values) { + uint_float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_uint4x2_output = { UInt4x2(0, 15), UInt4x2(1, 14), UInt4x2(7, 8), @@ -1469,6 +2504,127 @@ TEST(CastOpTest, Float8E4M3FNToUInt4x2) { TestCastOp(gsl::make_span(uint_float8_input), gsl::make_span(expected_uint4x2_output), shape); } +TEST(CastOpTest, Int2x4ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + std::vector expected_float8_output; + const std::vector expected_values = {-2.0f, 1.0f, 0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f}; + for (float val : expected_values) { + expected_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_float8_output), shape); +} + +TEST(CastOpTest, UInt2x4ToFloat8E4M3FN) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + std::vector expected_float8_output; + const std::vector expected_values = {0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f}; + for (float val : expected_values) { + expected_float8_output.emplace_back(Float8E4M3FN(val, true)); + } + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_float8_output), shape); +} + +TEST(CastOpTest, Int2x4ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector int2x4_input = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + std::vector expected_float8_output; + const std::vector expected_values = {-2.0f, 1.0f, 
0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f}; + for (float val : expected_values) { + expected_float8_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(int2x4_input), gsl::make_span(expected_float8_output), shape); +} + +TEST(CastOpTest, UInt2x4ToFloat8E5M2) { + // GIVEN + const std::vector shape{2, 2, 2}; + const std::vector uint2x4_input = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + std::vector expected_float8_output; + const std::vector expected_values = {0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f}; + for (float val : expected_values) { + expected_float8_output.emplace_back(Float8E5M2(val, true)); + } + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(uint2x4_input), gsl::make_span(expected_float8_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {-2.0f, 1.0f, 0.0f, -1.0f, 1.0f, -2.0f, -1.0f, 0.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_int2x4_output = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, -2, -1, 0)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float8_input), gsl::make_span(expected_int2x4_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToUInt2x4) { + // GIVEN + const std::vector shape{2, 2, 2}; + std::vector float8_input; + const std::vector input_values = {0.0f, 3.0f, 1.0f, 2.0f, 3.0f, 0.0f, 2.0f, 1.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + const std::vector expected_uint2x4_output = { + UInt2x4(0, 3, 1, 2), + UInt2x4(3, 0, 2, 1)}; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float8_input), gsl::make_span(expected_uint2x4_output), shape); +} + +TEST(CastOpTest, Float8E4M3FNToInt2x4_OddShape) { + // GIVEN + const std::vector shape{5}; + std::vector float8_input; + const std::vector input_values = {-2.0f, 1.0f, 0.0f, -1.0f, 
1.0f}; + for (float val : input_values) { + float8_input.emplace_back(Float8E4M3FN(val, true)); + } + + // 5 elements padded to 8 (2 Int2x4 values) + const std::vector expected_int2x4_output = { + Int2x4(-2, 1, 0, -1), + Int2x4(1, 0, 0, 0) // padded with 0 + }; + + // WHEN, THEN + TestCastOpInt2(gsl::make_span(float8_input), gsl::make_span(expected_int2x4_output), shape); +} + #endif #if !defined(DISABLE_FLOAT4_TYPES) && defined(USE_CUDA) diff --git a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc index 081b4b484a73b..a8f3b99b2b3d3 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_nd_op_test.cc @@ -329,5 +329,48 @@ TEST(GatherNDOpTest, GatherND_slice_int64_t) { test.Run(); } +// Test for issue #23828: GatherND should return error instead of crashing +// when batch dimensions mismatch between input and indices +TEST(GatherNDOpTest, GatherND_batch_dims_mismatch_error) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + + // Input has 3 batches, but indices has 2 slices (indices batch size 2), which is not divisible by 3 - mismatch! 
+ test.AddInput("data", {3, 3}, {0.f, 1.f, 2.f, 10.f, 11.f, 12.f, 20.f, 21.f, 22.f}); + test.AddInput("indices", {2, 1}, {1, 2}); + test.AddOutput("output", {2}, {0.f, 0.f}); // dummy output, won't be used + + // Force execution only on CPU + std::vector> cpu_only_ep; + cpu_only_ep.push_back(DefaultCpuExecutionProvider()); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "GatherND: indices batch size (2) is not divisible by input batch size (3)", + {}, // no excluded providers needed + nullptr, // no RunOptions + &cpu_only_ep); // force CPU +} + +// Test for issue #23828: GatherND should return error when input batch dimension is zero +TEST(GatherNDOpTest, GatherND_zero_batch_dims_error) { + OpTester test("GatherND", 12, kOnnxDomain); + test.AddAttribute("batch_dims", 1); + + // Input has 0 batches - should fail with clear error instead of division by zero + test.AddInput("data", {0, 3}, {}); + test.AddInput("indices", {2, 1}, {1, 2}); + test.AddOutput("output", {2}, {0.f, 0.f}); // dummy output, won't be used + + // Force execution only on CPU + std::vector> cpu_only_ep; + cpu_only_ep.push_back(DefaultCpuExecutionProvider()); + + test.Run(OpTester::ExpectResult::kExpectFailure, + "GatherND: input tensor batch dimensions cannot be zero", + {}, + nullptr, + &cpu_only_ep); // force CPU +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc index 35066bd68c65e..3eb727df1aef8 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_op_test.cc @@ -341,6 +341,53 @@ TEST(GatherOpTest, Gather_axis1_indices2d_string) { test.Run(); } +TEST(GatherOpTest, Gather_overflow_check) { + // Skip on 32-bit platforms where allocating the full reference tensor is infeasible due + // to std::vector::max_size being limited to the size of ptrdiff_t (INT32_MAX on 32-bit). 
+ // Also, peak memory usage for this test would be greater than what is addressable. +#if SIZE_MAX <= UINT32_MAX + GTEST_SKIP() << "Gather_overflow_check skipped on 32-bit platforms."; +#endif + + // The test uses dimensions (46341, 2) and indices of length 46341, which produce an output + // shape of (46341, 46341). + // + // 46341 x 46341 = 2,147,488,281 which is just greater than the maximum value of a 32-bit integer (2,147,483,647). + // + // This test is to verify CPU implementation of the Gather operator doesn't overflow when calculating + // the output shape and generating the output tensor. + + constexpr int64_t dim_val = 46341; + + OpTester test("Gather"); + test.AddAttribute("axis", 1LL); + + // Setup test inputs and outputs in a separate scope to ensure the large `expected_output_values` array + // is destroyed before we run the test via `test.Run()`. + { + const std::vector data_dims{dim_val, 2}; + const std::vector indices_dims{dim_val}; + std::vector data_values(static_cast(data_dims[0] * data_dims[1]), 1); + std::vector indices_values(static_cast(indices_dims[0]), 1); + std::vector expected_output_values(static_cast(dim_val) * static_cast(dim_val), 1); + + test.AddInput("data", {dim_val, 2}, data_values); + test.AddInput("indices", {dim_val}, indices_values); + + // Note: the large ~2GiB `expected_output_values` array is copied into the OpTester. + test.AddOutput("output", {dim_val, dim_val}, expected_output_values); + } + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultCpuExecutionProvider()); + + // Note: peak memory usage will be in the order of multiple GiB: + // - OpTester holds expected outputs buffer of size ~2GiB + // - The session state allocates a buffer for the output of size ~2GiB + // - Other overhead and bookkeeping. 
+ test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(GatherOpTest, Gather_axis1_indices2d_bool) { OpTester test("Gather"); test.AddAttribute("axis", 1LL); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index bd8aad5f85514..bf632d0b3bc40 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -6,6 +6,7 @@ #include "test/providers/provider_test_utils.h" #include "test/util/include/default_providers.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" namespace onnxruntime { namespace test { @@ -113,6 +114,67 @@ TEST(DequantizeLinearOpTest, UInt4NoZeroPoint) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// scalar zero & scale with int2 +// INT2 range: [-2, 1] (2-bit signed two's complement) +TEST(DequantizeLinearOpTest, Int2) { + OpTester test("DequantizeLinear", 25); + std::vector dims{5}; + constexpr int unused_val = 0; + + // 5 int2 values: -2, 1, 0, -1, 1 (requires 2 Int2x4 packed values) + // Pack: (-2, 1, 0, -1) and (1, unused, unused, unused) + test.AddInput("x", dims, {Int2x4(-2, 1, 0, -1), Int2x4(1, unused_val, unused_val, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {Int2x4(-1, unused_val, unused_val, unused_val)}); + // y = (x - zp) * scale = ([-2, 1, 0, -1, 1] - (-1)) * 2 = [-1, 2, 1, 0, 2] * 2 = [-2, 4, 2, 0, 4] + test.AddOutput("y", dims, {-2.0f, 4.0f, 2.0f, 0.0f, 4.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// scalar scale with int2 (no zero point) +TEST(DequantizeLinearOpTest, Int2NoZeroPoint) { + OpTester test("DequantizeLinear", 25); + std::vector dims{5}; + constexpr int unused_val = 0; + + // 5 int2 values: -2, 1, 0, -1, 1 + test.AddInput("x", dims, {Int2x4(-2, 1, 
0, -1), Int2x4(1, unused_val, unused_val, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + // y = x * scale = [-2, 1, 0, -1, 1] * 2 = [-4, 2, 0, -2, 2] + test.AddOutput("y", dims, {-4.0f, 2.0f, 0.0f, -2.0f, 2.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// scalar zero & scale with uint2 +// UINT2 range: [0, 3] +TEST(DequantizeLinearOpTest, UInt2) { + OpTester test("DequantizeLinear", 25); + std::vector dims{5}; + constexpr int unused_val = 0; + + // 5 uint2 values: 0, 1, 2, 3, 1 + test.AddInput("x", dims, {UInt2x4(0, 1, 2, 3), UInt2x4(1, unused_val, unused_val, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + test.AddInput("x_zero_point", {}, {UInt2x4(1, unused_val, unused_val, unused_val)}); + // y = (x - zp) * scale = ([0, 1, 2, 3, 1] - 1) * 2 = [-1, 0, 1, 2, 0] * 2 = [-2, 0, 2, 4, 0] + test.AddOutput("y", dims, {-2.0f, 0.0f, 2.0f, 4.0f, 0.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// scalar scale with uint2 (no zero point) +TEST(DequantizeLinearOpTest, UInt2NoZeroPoint) { + OpTester test("DequantizeLinear", 25); + std::vector dims{5}; + constexpr int unused_val = 0; + + // 5 uint2 values: 0, 1, 2, 3, 1 + test.AddInput("x", dims, {UInt2x4(0, 1, 2, 3), UInt2x4(1, unused_val, unused_val, unused_val)}); + test.AddInput("x_scale", {}, {2.0f}); + // y = x * scale = [0, 1, 2, 3, 1] * 2 = [0, 2, 4, 6, 2] + test.AddOutput("y", dims, {0.0f, 2.0f, 4.0f, 6.0f, 2.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // Test int16 DequantizeLinear (per tensor) TEST(DequantizeLinearOpTest, Int16) { OpTester test("DequantizeLinear", 21); @@ -449,10 +511,43 @@ TEST(QuantizeLinearOpTest, Uint16) { 65535, 0, 65535, 0}); + std::unordered_set excluded_providers; // Disable Tensorrt EP due to error: unsupported data type - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + 
excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } +#ifdef USE_OPENVINO +TEST(QuantizeLinearOpTest, OVEP_Uint16) { + OpTester test("QuantizeLinear", 21); + std::vector dims{12}; + test.AddInput("x", dims, { + 0.f, -128.f, 3.f, -3.f, // rounding half to even + 2.9f, -2.9f, // round < .5 + 3.1f, -3.1f, // round > .5 + 65536.f, -65534.f, // critical point + 70000.f, -70000.f // saturate case + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {32767}, true); + test.AddOutput("y", dims, + {32767, 32703, + 32768, 32766, + 32768, 32766, + 32769, 32765, + 65535, 0, + 65535, 0}); + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); +} +#endif // USE_OPENVINO + // Test int16 QuantizeLinear (per tensor) TEST(QuantizeLinearOpTest, Int16) { OpTester test("QuantizeLinear", 21); @@ -502,9 +597,41 @@ TEST(QuantizeLinearOpTest, Int4) { {Int4x2(-8, -7), Int4x2(-1, 1), Int4x2(2, 7), Int4x2(7, unused_val)}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + std::unordered_set excluded_providers; + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } +#ifdef USE_OPENVINO +TEST(QuantizeLinearOpTest, OVEP_Int4) { + OpTester test("QuantizeLinear", 21); + std::vector dims{7}; + constexpr int8_t unused_val = 0; + test.AddInput("x", dims, { + -20.0f, // Clamp to qmin + -16.0f, // Close to qmin + -3.0f, // round + 0.0f, // Zero-point + 2.9f, // round + 12.0f, // qmax + 20.0f, // Clamp to qmax + }); + 
test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {Int4x2(1, unused_val)}, true); + test.AddOutput("y", dims, + {Int4x2(-8, -7), Int4x2(0, 1), Int4x2(2, 7), + Int4x2(7, unused_val)}); + + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); +} +#endif // USE_OPENVINO + // Test uint4 QuantizeLinear (per tensor) TEST(QuantizeLinearOpTest, UInt4) { OpTester test("QuantizeLinear", 21); @@ -546,6 +673,172 @@ static void GetExpectedInt4Quant(const float* input, Int4x2Base* output, } } +template +static void GetExpectedInt2Quant(const float* input, Int2x4Base* output, size_t num_elems, float scale, + int8_t zero_point) { + using UnpackedType = typename Int2x4Base::UnpackedType; + + for (size_t n = 0; n < num_elems; n++) { + float float_val = std::nearbyintf(input[n] / scale) + static_cast(zero_point); + float_val = std::max(float_val, static_cast(Int2x4Base::min_val)); + float_val = std::min(float_val, static_cast(Int2x4Base::max_val)); + + UnpackedType int_val = static_cast(float_val); + + size_t i = n >> 2; // n / 4 + size_t j = n & 0x3; // n % 4 + output[i].SetElem(j, int_val); + } +} + +// Test int2 QuantizeLinear (per tensor) +// INT2 range: [-2, 1] +TEST(QuantizeLinearOpTest, Int2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{5}; + constexpr int8_t unused_val = 0; + test.AddInput("x", dims, { + -6.0f, // Clamp to qmin (-2) + -4.0f, // qmin + -1.0f, // round to 0 with zp=0 + 0.0f, // Zero-point + 4.0f, // Clamp to qmax (1) + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {Int2x4(0, unused_val, unused_val, unused_val)}, true); + // y = clamp(round(x / scale) + zp, -2, 1) + // = clamp([-3, -2, -0.5, 0, 2] + 0, -2, 1) = [-2, -2, 0, 0, 1] + test.AddOutput("y", dims, + {Int2x4(-2, -2, 0, 0), Int2x4(1, unused_val, unused_val, unused_val)}); + + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint2 QuantizeLinear (per tensor) +// UINT2 range: [0, 3] +TEST(QuantizeLinearOpTest, UInt2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{5}; + constexpr uint8_t unused_val = 0; + test.AddInput("x", dims, { + -4.0f, // Clamp to qmin (0) + -2.0f, // zp - 1 = 0 + 0.0f, // zp + 2.0f, // zp + 1 + 8.0f, // Clamp to qmax (3) + }); + test.AddInput("scale", {}, {2.0f}, true); + test.AddInput("zero_point", {}, {UInt2x4(1, unused_val, unused_val, unused_val)}, true); + // y = clamp(round(x / scale) + zp, 0, 3) + // = clamp([-2, -1, 0, 1, 4] + 1, 0, 3) = clamp([-1, 0, 1, 2, 5], 0, 3) = [0, 0, 1, 2, 3] + test.AddOutput("y", dims, + {UInt2x4(0, 0, 1, 2), UInt2x4(3, unused_val, unused_val, unused_val)}); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test int2 QuantizeLinear (per tensor) with a "large" number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks. 
+TEST(QuantizeLinearOpTest, Large_Int2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{1017}; + constexpr int8_t unused_val = 0; + constexpr std::array pattern = {-4.0f, -2.0f, 0.0f, 2.0f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(Int2x4::CalcNumInt2Quads(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + int8_t zp = 0; + GetExpectedInt2Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {Int2x4(zp, unused_val, unused_val, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint2 QuantizeLinear (per tensor) with a "large" number of input elements. +TEST(QuantizeLinearOpTest, Large_UInt2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{1017}; + constexpr uint8_t unused_val = 0; + constexpr std::array pattern = {-2.0f, 0.0f, 2.0f, 4.0f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(UInt2x4::CalcNumInt2Quads(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + uint8_t zp = 1; + GetExpectedInt2Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {UInt2x4(zp, unused_val, unused_val, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test int2 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +// This exercises the TryParallelFor call which splits the input into blocks. 
+TEST(QuantizeLinearOpTest, OddLarge_Int2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{1019}; // Odd number, not multiple of 4 + constexpr int8_t unused_val = 0; + constexpr std::array pattern = {-4.0f, -2.0f, 0.0f, 2.0f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(Int2x4::CalcNumInt2Quads(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + int8_t zp = 0; + GetExpectedInt2Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {Int2x4(zp, unused_val, unused_val, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test uint2 QuantizeLinear (per tensor) with a "large" and odd number of input elements. +TEST(QuantizeLinearOpTest, OddLarge_UInt2) { + OpTester test("QuantizeLinear", 25); + std::vector dims{1019}; // Odd number, not multiple of 4 + constexpr uint8_t unused_val = 0; + constexpr std::array pattern = {-2.0f, 0.0f, 2.0f, 4.0f}; + std::vector input_f32s(static_cast(dims[0])); + std::vector output(UInt2x4::CalcNumInt2Quads(input_f32s.size())); + + for (size_t i = 0; i < input_f32s.size(); ++i) { + input_f32s[i] = pattern[i % pattern.size()]; + } + + float scale = 2.0f; + uint8_t zp = 1; + GetExpectedInt2Quant(input_f32s.data(), &output[0], input_f32s.size(), scale, zp); + + test.AddInput("x", dims, input_f32s); + test.AddInput("scale", {}, {scale}, true); + test.AddInput("zero_point", {}, {UInt2x4(zp, unused_val, unused_val, unused_val)}, true); + test.AddOutput("y", dims, output); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + // Test int4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. 
// This exercises the TryParallelFor call which splits the input into blocks of even size. TEST(QuantizeLinearOpTest, OddLarge_Int4) { @@ -569,7 +862,12 @@ TEST(QuantizeLinearOpTest, OddLarge_Int4) { test.AddInput("zero_point", {}, {Int4x2(zp, unused_val)}, true); test.AddOutput("y", dims, output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + std::unordered_set excluded_providers; + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } // Test uint4 QuantizeLinear (per tensor) with a "large" and odd number of input elements. @@ -595,7 +893,12 @@ TEST(QuantizeLinearOpTest, OddLarge_UInt4) { test.AddInput("zero_point", {}, {UInt4x2(zp, unused_val)}, true); test.AddOutput("y", dims, output); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + std::unordered_set excluded_providers; + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } // quantize with scalar zero point and scale @@ -611,9 +914,29 @@ TEST(QuantizeLinearOpTest, Int8_NegativeZeroPoint) { test.AddInput("y_scale", {}, {.039215686f}); test.AddInput("y_zero_point", {}, {-23}); test.AddOutput("y", dims, {-23, 28, 53, 104, 127, -74, -128, -128}); + std::unordered_set excluded_providers; // Disable Tensorrt EP due to the error, node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. 
- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); +} + +#ifdef USE_OPENVINO +TEST(QuantizeLinearOpTest, OVEP_Int8_NegativeZeroPoint) { + OpTester test("QuantizeLinear", 10); + std::vector dims{8}; + test.AddInput("x", dims, {0, 2, 3, 5, 6, -2, -5, -6}); + test.AddInput("y_scale", {}, {.039215686f}); + test.AddInput("y_zero_point", {}, {-23}); + test.AddOutput("y", dims, {-23, 28, 54, 105, 127, -74, -128, -128}); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); } +#endif // USE_OPENVINO // quantize with scalar zero point and scale TEST(QuantizeLinearOpTest, Int8_PositiveZeroPoint) { @@ -628,9 +951,29 @@ TEST(QuantizeLinearOpTest, Int8_PositiveZeroPoint) { test.AddInput("y_scale", {}, {.039215686f}); test.AddInput("y_zero_point", {}, {23}); test.AddOutput("y", dims, {23, 74, 99, 127, 127, -28, -104, -128}); + std::unordered_set excluded_providers; // Disable Tensorrt EP due to error:node1_quantize_scale_node: out of bounds channel axis 1. Number of input dimensions is 1. 
- test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); + excluded_providers.insert(kTensorrtExecutionProvider); + // Disable OV EP due to different formulation for QuantizeLinear + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); +} + +#ifdef USE_OPENVINO +TEST(QuantizeLinearOpTest, OVEP_Int8_PositiveZeroPoint) { + OpTester test("QuantizeLinear", 10); + std::vector dims{8}; + test.AddInput("x", dims, {0, 2, 3, 5, 6, -2, -5, -6}); + test.AddInput("y_scale", {}, {.039215686f}); + test.AddInput("y_zero_point", {}, {23}); + test.AddOutput("y", dims, {23, 74, 100, 127, 127, -28, -104, -128}); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); } +#endif // USE_OPENVINO // quantize with 2D data TEST(QuantizeLinearOpTest, 2D) { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index be3516437b1aa..8fd994baec713 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -304,11 +304,44 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { std::vector Y = {2, 4}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); + std::unordered_set excluded_providers; // CUDA: result mismatch due to not implementing NHWC support - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); + // ROCm: results mismatch + excluded_providers.insert(kCudaExecutionProvider); + excluded_providers.insert(kCudaNHWCExecutionProvider); + // Disable OV EP due to round when converting from float to uint8 + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } 
+#ifdef USE_OPENVINO +TEST(ResizeOpTest, OVEPNhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 0.6f, 0.6f, 1.0f}; + + test.AddAttribute("mode", "linear"); + + constexpr int64_t N = 1, H = 2, W = 4, C = 1; + std::vector X = { + 1, 2, 3, 4, + 5, 6, 7, 8}; + + test.AddInput("X", {N, H, W, C}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {4}, scales); + + std::vector Y = {3, 4}; + + test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); +} +#endif // USE_OPENVINO + TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { OpTester test("Resize", 13); std::vector roi{}; @@ -641,12 +674,51 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {1, 7, 12}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); + std::unordered_set excluded_providers; // CUDA: result mismatch due to not implementing NHWC support // DML: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kDmlExecutionProvider}); + excluded_providers.insert(kCudaExecutionProvider); + excluded_providers.insert(kCudaNHWCExecutionProvider); + excluded_providers.insert(kDmlExecutionProvider); + // Disable OV EP due to round when converting from float to uint8 + excluded_providers.insert(kOpenVINOExecutionProvider); + test.ConfigExcludeEps(excluded_providers) + .RunWithConfig(); } +#ifdef USE_OPENVINO +TEST(ResizeOpTest, OVEPNhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_uint8) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{}; + std::vector sizes{1, 3, 1, 1}; + + test.AddAttribute("mode", "linear"); + 
test.AddAttribute("coordinate_transformation_mode", "pytorch_half_pixel"); + + constexpr int64_t N = 1, H = 4, W = 4, C = 1; + + std::vector X = { + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16}; + + test.AddInput("X", {N, H, W, C}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("", {0}, scales); + test.AddInput("sizes", {4}, sizes); + + std::vector Y = {2, 7, 12}; + + test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); +} +#endif // USE_OPENVINO + TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { OpTester test("Resize", 13); std::vector roi{}; @@ -754,14 +826,63 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support + // Disable OV EP due to round when converting from float to uint8 test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kCudaNHWCExecutionProvider}); + {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kOpenVINOExecutionProvider}); }; run_test(false); run_test(true); } +#ifdef USE_OPENVINO +TEST(ResizeOpTest, OVEPNhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { + // To test NNAPI EP, we need the scales/sizes to be in initializers + auto run_test = [](bool scales_in_initializer) { + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 2.0f, 4.0f, 1.0f}; + + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", "asymmetric"); + + constexpr int64_t N = 2, H = 2, W = 2, C = 1; + std::vector X = {1, 3, + 4, 8, + + 6, 2, + 7, 11}; + + test.AddInput("X", {N, H, W, C}, X); + test.AddInput("roi", {0}, roi); + 
test.AddInput("scales", {4}, scales, scales_in_initializer); + + std::vector Y = { + 1, 2, 2, 2, 3, 3, 3, 3, + 2, 3, 4, 5, 6, 6, 6, 6, + 4, 5, 6, 7, 8, 8, 8, 8, + 4, 5, 6, 7, 8, 8, 8, 8, + + 6, 5, 4, 3, 2, 2, 2, 2, + 6, 6, 6, 6, 6, 6, 6, 6, + 7, 8, 9, 10, 11, 11, 11, 11, + 7, 8, 9, 10, 11, 11, 11, 11}; + + // Due to Xnnpack EP has a different rounding behavior, we need to allow a tolerance of 1 + // The tolerance only works for Xnnpack EP + test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, + Y, false, .0f, 1.0f); + std::vector> execution_providers; + execution_providers.emplace_back(DefaultOpenVINOExecutionProvider()); + test.ConfigEps(std::move(execution_providers)) + .RunWithConfig(); + }; + + run_test(false); + run_test(true); +} +#endif // USE_OPENVINO + TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { @@ -894,6 +1015,38 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch } +TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_CudaRegression) { + auto cuda_ep = DefaultCudaExecutionProvider(); + if (!cuda_ep) { + GTEST_SKIP() << "CUDA EP not available"; + } + + OpTester test("Resize", 13); + std::vector roi{}; + std::vector scales{1.0f, 1.0f, 2.0f, 2.0f, 2.0f}; + + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", "pytorch_half_pixel"); + + constexpr int64_t N = 1, C = 1, D = 3, H = 4, W = 5; + std::vector X(static_cast(N * C * D * H * W), 1.0f); + + test.AddInput("X", {N, C, D, H, W}, X); + test.AddInput("roi", {0}, roi); + test.AddInput("scales", {5}, scales); + + constexpr int64_t out_D = D * 2; + constexpr int64_t out_H = H * 2; + constexpr int64_t out_W = W * 2; + std::vector Y(static_cast(N 
* C * out_D * out_H * out_W), 1.0f); + + test.AddOutput("Y", {N, C, out_D, out_H, out_W}, Y); + + std::vector> execution_providers; + execution_providers.push_back(std::move(cuda_ep)); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) { // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { @@ -2477,7 +2630,8 @@ TEST(ResizeOpTest, NoAntialias_AlignCorners_Cubic_Floor_NHWC) { 23.0000f, 24.0000f, }; // clang-format on - InlinedVector excluded_eps = {kCudaExecutionProvider}; + // OVEP: results mismatch due to OVEP's optimizations have conflict + InlinedVector excluded_eps = {kCudaExecutionProvider, kOpenVINOExecutionProvider}; TestAntialiasing( {{"antialias", "0"}, {"coordinate_transformation_mode", "align_corners"}, diff --git a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc index 5b2865a3feed7..38bc326943c6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc +++ b/onnxruntime/test/providers/cpu/tensor/slice_op.test.cc @@ -54,6 +54,7 @@ void RunSliceTest(const std::vector& input_dims, if (onnx_shape_disagreement) { excluded_providers.insert(kCoreMLExecutionProvider); + excluded_providers.insert(kOpenVINOExecutionProvider); } if (!v10_only) { diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index 00449cd442a32..065f625acc50f 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -154,6 +154,278 @@ TEST(TransposeOpTest, TwoDim_Odd_UInt4) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } +// Test Int2 transpose with inner dimension % 4 == 0 +TEST(TransposeOpTest, TwoDim_Int2_Mod4_0) { + // Shape (3, 4): 12 elements, 3 bytes needed, no padding + 
std::vector input_shape({3, 4}); + // Input layout (row-major flattened): + // Row 0: 1, -1, -2, 1 + // Row 1: -2, 1, -1, -2 + // Row 2: 1, -1, 1, -2 + // Flattened: [1, -1, -2, 1, -2, 1, -1, -2, 1, -1, 1, -2] + std::vector input_vals = {Int2x4(1, -1, -2, 1), Int2x4(-2, 1, -1, -2), Int2x4(1, -1, 1, -2)}; + + std::vector perm = {1, 0}; + // Transposed shape (4, 3): 12 elements, 3 bytes needed + // Transposed layout: + // Row 0: 1, -2, 1 + // Row 1: -1, 1, -1 + // Row 2: -2, -1, 1 + // Row 3: 1, -2, -2 + // Flattened: [1, -2, 1, -1, 1, -1, -2, -1, 1, 1, -2, -2] + std::vector expected_shape({4, 3}); + std::vector expected_vals = {Int2x4(1, -2, 1, -1), Int2x4(1, -1, -2, -1), Int2x4(1, 1, -2, -2)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test Int2 transpose with inner dimension % 4 == 1 +TEST(TransposeOpTest, TwoDim_Int2_Mod4_1) { + // Shape (3, 5): 15 elements, 4 bytes needed, 1 padding + std::vector input_shape({3, 5}); + // Input layout (row-major flattened): + // Row 0: 1, -1, -2, 1, -1 + // Row 1: -2, 1, -1, -2, 1 + // Row 2: -1, -2, 1, -1, -2 + // Flattened: [1, -1, -2, 1, -1, -2, 1, -1, -2, 1, -1, -2, 1, -1, -2, 0(pad)] + std::vector input_vals = {Int2x4(1, -1, -2, 1), Int2x4(-1, -2, 1, -1), + Int2x4(-2, 1, -1, -2), Int2x4(1, -1, -2, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (5, 3): 15 elements, 4 bytes needed + // Transposed layout: + // Row 0: 1, -2, -1 + // Row 1: -1, 1, -2 + // Row 2: -2, -1, 1 + // Row 3: 1, -2, -1 + // Row 4: -1, 1, -2 + // Flattened: [1, -2, -1, -1, 1, -2, -2, -1, 1, 1, -2, -1, -1, 1, -2, 0(pad)] + std::vector expected_shape({5, 3}); + std::vector expected_vals = {Int2x4(1, -2, -1, -1), Int2x4(1, -2, -2, -1), + Int2x4(1, 1, -2, -1), Int2x4(-1, 1, -2, 0)}; + + OpTester test("Transpose", 25); + 
test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test Int2 transpose with inner dimension % 4 == 2 +TEST(TransposeOpTest, TwoDim_Int2_Mod4_2) { + // Shape (3, 6): 18 elements, 5 bytes needed, 2 padding + std::vector input_shape({3, 6}); + // Input layout (row-major flattened): + // Row 0: 1, -1, -2, 1, -1, -2 + // Row 1: 1, -2, -1, 1, -2, -1 + // Row 2: -1, 1, -2, -1, 1, -2 + // Flattened: [1, -1, -2, 1, -1, -2, 1, -2, -1, 1, -2, -1, -1, 1, -2, -1, 1, -2, 0(pad), 0(pad)] + std::vector input_vals = {Int2x4(1, -1, -2, 1), Int2x4(-1, -2, 1, -2), + Int2x4(-1, 1, -2, -1), Int2x4(-1, 1, -2, -1), + Int2x4(1, -2, 0, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (6, 3): 18 elements, 5 bytes needed + // Transposed layout: + // Row 0: 1, 1, -1 + // Row 1: -1, -2, 1 + // Row 2: -2, -1, -2 + // Row 3: 1, 1, -1 + // Row 4: -1, -2, 1 + // Row 5: -2, -1, -2 + // Flattened: [1, 1, -1, -1, -2, 1, -2, -1, -2, 1, 1, -1, -1, -2, 1, -2, -1, -2, 0(pad), 0(pad)] + std::vector expected_shape({6, 3}); + std::vector expected_vals = {Int2x4(1, 1, -1, -1), Int2x4(-2, 1, -2, -1), + Int2x4(-2, 1, 1, -1), Int2x4(-1, -2, 1, -2), + Int2x4(-1, -2, 0, 0)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test Int2 transpose with inner dimension % 4 == 3 +TEST(TransposeOpTest, TwoDim_Int2_Mod4_3) { + // Shape (3, 7): 21 elements, 6 bytes needed, 3 padding + std::vector input_shape({3, 7}); + // Input layout (row-major flattened): + // Row 0: 1, -1, -2, 1, -1, -2, 1 + // Row 1: -2, 1, -1, -2, 1, -1, -2 + // Row 2: 1, -2, 1, -1, -2, 1, -1 + // Flattened: [1, -1, -2, 1, -1, -2, 1, -2, 1, -1, -2, 1, 
-1, -2, 1, -2, 1, -1, -2, 1, -1, 0, 0, 0] + std::vector input_vals = {Int2x4(1, -1, -2, 1), Int2x4(-1, -2, 1, -2), + Int2x4(1, -1, -2, 1), Int2x4(-1, -2, 1, -2), + Int2x4(1, -1, -2, 1), Int2x4(-1, 0, 0, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (7, 3): 21 elements, 6 bytes needed + // Transposed layout: + // Row 0: 1, -2, 1 + // Row 1: -1, 1, -2 + // Row 2: -2, -1, 1 + // Row 3: 1, -2, -1 + // Row 4: -1, 1, -2 + // Row 5: -2, -1, 1 + // Row 6: 1, -2, -1 + // Flattened: [1, -2, 1, -1, 1, -2, -2, -1, 1, 1, -2, -1, -1, 1, -2, -2, -1, 1, 1, -2, -1, 0, 0, 0] + std::vector expected_shape({7, 3}); + std::vector expected_vals = {Int2x4(1, -2, 1, -1), Int2x4(1, -2, -2, -1), + Int2x4(1, 1, -2, -1), Int2x4(-1, 1, -2, -2), + Int2x4(-1, 1, 1, -2), Int2x4(-1, 0, 0, 0)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test UInt2 transpose with inner dimension % 4 == 0 +TEST(TransposeOpTest, TwoDim_UInt2_Mod4_0) { + // Shape (3, 4): 12 elements, 3 bytes needed, no padding + std::vector input_shape({3, 4}); + // Input layout (row-major flattened): + // Row 0: 1, 2, 3, 1 + // Row 1: 2, 3, 1, 2 + // Row 2: 3, 1, 2, 3 + // Flattened: [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3] + std::vector input_vals = {UInt2x4(1, 2, 3, 1), UInt2x4(2, 3, 1, 2), UInt2x4(3, 1, 2, 3)}; + + std::vector perm = {1, 0}; + // Transposed shape (4, 3): 12 elements, 3 bytes needed + // Transposed layout: + // Row 0: 1, 2, 3 + // Row 1: 2, 3, 1 + // Row 2: 3, 1, 2 + // Row 3: 1, 2, 3 + // Flattened: [1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3] + std::vector expected_shape({4, 3}); + std::vector expected_vals = {UInt2x4(1, 2, 3, 2), UInt2x4(3, 1, 3, 1), UInt2x4(2, 1, 2, 3)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + 
test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test UInt2 transpose with inner dimension % 4 == 1 +TEST(TransposeOpTest, TwoDim_UInt2_Mod4_1) { + // Shape (3, 5): 15 elements, 4 bytes needed, 1 padding + std::vector input_shape({3, 5}); + // Input layout (row-major flattened): + // Row 0: 1, 2, 3, 1, 2 + // Row 1: 3, 1, 2, 3, 1 + // Row 2: 2, 3, 1, 2, 3 + // Flattened: [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 0(pad)] + std::vector input_vals = {UInt2x4(1, 2, 3, 1), UInt2x4(2, 3, 1, 2), + UInt2x4(3, 1, 2, 3), UInt2x4(1, 2, 3, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (5, 3): 15 elements, 4 bytes needed + // Transposed layout: + // Row 0: 1, 3, 2 + // Row 1: 2, 1, 3 + // Row 2: 3, 2, 1 + // Row 3: 1, 3, 2 + // Row 4: 2, 1, 3 + // Flattened: [1, 3, 2, 2, 1, 3, 3, 2, 1, 1, 3, 2, 2, 1, 3, 0(pad)] + std::vector expected_shape({5, 3}); + std::vector expected_vals = {UInt2x4(1, 3, 2, 2), UInt2x4(1, 3, 3, 2), + UInt2x4(1, 1, 3, 2), UInt2x4(2, 1, 3, 0)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test UInt2 transpose with inner dimension % 4 == 2 +TEST(TransposeOpTest, TwoDim_UInt2_Mod4_2) { + // Shape (3, 6): 18 elements, 5 bytes needed, 2 padding + std::vector input_shape({3, 6}); + // Input layout (row-major flattened): + // Row 0: 1, 2, 3, 1, 2, 3 + // Row 1: 2, 3, 1, 2, 3, 1 + // Row 2: 3, 1, 2, 3, 1, 2 + // Flattened: [1, 2, 3, 1, 2, 3, 2, 3, 1, 2, 3, 1, 3, 1, 2, 3, 1, 2, 0(pad), 0(pad)] + std::vector input_vals = {UInt2x4(1, 2, 3, 1), UInt2x4(2, 3, 2, 3), + UInt2x4(1, 2, 3, 1), UInt2x4(3, 1, 2, 3), + UInt2x4(1, 2, 0, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (6, 3): 18 elements, 5 bytes needed + // Transposed 
layout: + // Row 0: 1, 2, 3 + // Row 1: 2, 3, 1 + // Row 2: 3, 1, 2 + // Row 3: 1, 2, 3 + // Row 4: 2, 3, 1 + // Row 5: 3, 1, 2 + // Flattened: [1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3, 2, 3, 1, 3, 1, 2, 0(pad), 0(pad)] + std::vector expected_shape({6, 3}); + std::vector expected_vals = {UInt2x4(1, 2, 3, 2), UInt2x4(3, 1, 3, 1), + UInt2x4(2, 1, 2, 3), UInt2x4(2, 3, 1, 3), + UInt2x4(1, 2, 0, 0)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + +// Test UInt2 transpose with inner dimension % 4 == 3 +TEST(TransposeOpTest, TwoDim_UInt2_Mod4_3) { + // Shape (3, 7): 21 elements, 6 bytes needed, 3 padding + std::vector input_shape({3, 7}); + // Input layout (row-major flattened): + // Row 0: 1, 2, 3, 1, 2, 3, 1 + // Row 1: 2, 3, 1, 2, 3, 1, 2 + // Row 2: 3, 1, 2, 3, 1, 2, 3 + // Flattened: [1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3, 0, 0, 0] + std::vector input_vals = {UInt2x4(1, 2, 3, 1), UInt2x4(2, 3, 1, 2), + UInt2x4(3, 1, 2, 3), UInt2x4(1, 2, 3, 1), + UInt2x4(2, 3, 1, 2), UInt2x4(3, 0, 0, 0)}; + + std::vector perm = {1, 0}; + // Transposed shape (7, 3): 21 elements, 6 bytes needed + // Transposed layout: + // Row 0: 1, 2, 3 + // Row 1: 2, 3, 1 + // Row 2: 3, 1, 2 + // Row 3: 1, 2, 3 + // Row 4: 2, 3, 1 + // Row 5: 3, 1, 2 + // Row 6: 1, 2, 3 + // Flattened: [1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3, 2, 3, 1, 3, 1, 2, 1, 2, 3, 0, 0, 0] + std::vector expected_shape({7, 3}); + std::vector expected_vals = {UInt2x4(1, 2, 3, 2), UInt2x4(3, 1, 3, 1), + UInt2x4(2, 1, 2, 3), UInt2x4(2, 3, 1, 3), + UInt2x4(1, 2, 1, 2), UInt2x4(3, 0, 0, 0)}; + + OpTester test("Transpose", 25); + test.AddAttribute("perm", perm); + test.AddInput("X", input_shape, input_vals); + test.AddOutput("Y", expected_shape, expected_vals); + + 
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); +} + TEST(TransposeOpTest, TwoDim_double) { std::vector input_shape({2, 3}); std::vector input_vals = {1.0, 2.0, 3.0, diff --git a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc index ec7e98528504e..593255b9e9c23 100644 --- a/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc +++ b/onnxruntime/test/providers/cuda/test_cases/reduction_functions_test.cc @@ -177,6 +177,35 @@ void TestReduceColumnsToColumn(int m, int n, float relative_error_tolerance = 1e CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance); } + +void TestReduceColumnsToColumnRepeated(int m, int n, int iterations, float relative_error_tolerance = 1e-4f) { + SCOPED_TRACE(MakeString("m: ", m, ", n:", n, ", iterations: ", iterations)); + + const TensorShape shape{m, n}; + RandomValueGenerator random{}; + const auto values = random.Uniform(shape.GetDims(), 1.0f, 10.0f); + const auto expected_column = ExpectedReduceMatrixColumnsOutput(m, n, values); + + auto d_in = AllocateDeviceMemory(m * n); + auto d_out = AllocateDeviceMemory(m); + + cudaMemcpy(d_in.get(), values.data(), m * n * sizeof(float), cudaMemcpyHostToDevice); + + size_t buffer_size_in_bytes = + compute_reduce_matrix_columns_buffer_size(m, n); + auto d_buffer = AllocateDeviceMemory(buffer_size_in_bytes); + + for (int i = 0; i < iterations; ++i) { + ASSERT_STATUS_OK(reduce_matrix_columns( + 0, + d_in.get(), d_out.get(), + m, n, + d_buffer.get(), buffer_size_in_bytes)); + + ASSERT_TRUE(CUDA_CALL(cudaDeviceSynchronize()).IsOK()); + CheckDeviceValues(m, d_out.get(), expected_column.data(), relative_error_tolerance); + } +} } // namespace TEST(ReductionFunctionsTest, ReduceRowToScalar) { @@ -205,6 +234,10 @@ TEST(ReductionFunctionsTest, ReduceColumnsToColumn) { } } +TEST(ReductionFunctionsTest, 
ReduceColumnsToColumnRepeated) { + TestReduceColumnsToColumnRepeated(17, 8192, 100, 2e-4f); +} + TEST(ReductionFunctionsTest, BufferOffsets) { const int m = 2048; const int n = 1024; diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc index 1a987ab4f411a..f017c86824df6 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc @@ -275,7 +275,6 @@ INSTANTIATE_TEST_SUITE_P(NvExecutionProviderTest, TypeTests, ), [](const testing::TestParamInfo& info) { return getTypeAsName(info.param); }); -#ifdef _WIN32 static bool SessionHasEp(Ort::Session& session, const char* ep_name) { // Access the underlying InferenceSession. const OrtSession* ort_session = session; @@ -292,7 +291,6 @@ static bool SessionHasEp(Ort::Session& session, const char* ep_name) { } // Tests autoEP feature to automatically select an EP that supports the GPU. -// Currently only works on Windows. 
TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { PathString model_name = ORT_TSTR("nv_execution_provider_auto_ep.onnx"); std::string graph_name = "test"; @@ -302,7 +300,11 @@ TEST(NvExecutionProviderTest, AutoEp_PreferGpu) { CreateBaseModel(model_name, graph_name, dims); { +#if _WIN32 ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("onnxruntime_providers_nv_tensorrt_rtx.dll")); +#else + ort_env->RegisterExecutionProviderLibrary(kNvTensorRTRTXExecutionProvider, ORT_TSTR("libonnxruntime_providers_nv_tensorrt_rtx.so")); +#endif Ort::SessionOptions so; so.SetEpSelectionPolicy(OrtExecutionProviderDevicePolicy_PREFER_GPU); @@ -599,7 +601,5 @@ TEST(NvExecutionProviderTest, FP4CustomOpModel) { LOGS_DEFAULT(INFO) << "[NvExecutionProviderTest] TRT FP4 dynamic quantize model run completed successfully"; } -#endif - } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc index ac24dcb70c1dd..bcdfd18407ca8 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/nv_ep_context_test.cc @@ -14,7 +14,6 @@ namespace test { RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, std::unordered_map& option_map) { RegisteredEpDeviceUniquePtr nv_tensorrt_rtx_ep; -#ifdef _WIN32 /// Since this test runs after other tests that use registration interface this test has to use it as well /// windows as otherwise the kernel registry inside the EP will not be populated. The legacy APis ony call the initialize once. 
Utils::RegisterAndGetNvTensorRtRtxEp(*ort_env, nv_tensorrt_rtx_ep); @@ -26,9 +25,6 @@ RegisteredEpDeviceUniquePtr AppendTrtEtxEP(Ort::SessionOptions& session_options, } } session_options.AppendExecutionProvider_V2(*ort_env, {selected_device}, option_map); -#else - session_options.AppendExecutionProvider(onnxruntime::kNvTensorRTRTXExecutionProvider, option_map); -#endif return nv_tensorrt_rtx_ep; } diff --git a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc index 47127399b4646..de028bf613a27 100644 --- a/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc +++ b/onnxruntime/test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.cc @@ -24,7 +24,6 @@ namespace onnxruntime { namespace test { -#ifdef _WIN32 Utils::NvTensorRtRtxEpInfo Utils::nv_tensorrt_rtx_ep_info; @@ -61,7 +60,6 @@ void Utils::RegisterAndGetNvTensorRtRtxEp(Ort::Env& env, RegisteredEpDeviceUniqu c_api.UnregisterExecutionProviderLibrary(env, nv_tensorrt_rtx_ep_info.registration_name.c_str()); }); } -#endif // _WIN32 void CreateBaseModel(const PathString& model_name, std::string graph_name, diff --git a/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc index 74d4172cc234c..85088e65d2db3 100644 --- a/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc +++ b/onnxruntime/test/providers/openvino/openvino_ep_ext_init.cc @@ -4,6 +4,7 @@ #include #include #include +#include #include "core/session/onnxruntime_cxx_api.h" @@ -23,16 +24,16 @@ class OVEP_ExtInit_Tests : public ::testing::TestWithParam {}; namespace { -std::vector LoadFileToMemory(const std::string& path) { +std::optional> LoadFileToMemory(const std::string& path) { std::ifstream file(path, std::ios::binary | std::ios::ate); if (!file.is_open()) { - return std::vector(); + return std::nullopt; } std::streamsize size = file.tellg(); file.seekg(0, std::ios::beg); 
std::vector buffer(static_cast(size)); if (!file.read(reinterpret_cast(buffer.data()), size)) { - return std::vector(); + return std::nullopt; } return buffer; } @@ -57,8 +58,7 @@ auto ProbeDevice(const std::string& device) { namespace onnxruntime { namespace test { -// this test requires OV 2025.4+ to run -TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { +TEST_P(OVEP_ExtInit_Tests, ModelFromExtInit) { const auto& device = GetParam(); if (!ProbeDevice(device)) GTEST_SKIP() << device + " is not available on this machine"; @@ -161,14 +161,15 @@ TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { } // 4. Load model and weights into memory - std::vector model_data = LoadFileToMemory(model_path); - std::vector weights_data = LoadFileToMemory(weights_path); + auto model_data = LoadFileToMemory(model_path); + auto weights_data = LoadFileToMemory(weights_path); + ASSERT_TRUE(model_data.has_value() && weights_data.has_value()); // 5. Prepare external initializer info PathString weights_name_path(weights_path.begin(), weights_path.end()); std::vector names_path = {weights_name_path}; - std::vector buffers = {reinterpret_cast(weights_data.data())}; - std::vector buffer_sizes = {weights_data.size()}; + std::vector buffers = {reinterpret_cast(weights_data.value().data())}; + std::vector buffer_sizes = {weights_data.value().size()}; // 6. Set up session options with OpenVINO Ort::SessionOptions session_options; @@ -179,7 +180,7 @@ TEST_P(OVEP_ExtInit_Tests, DISABLED_ModelFromExtInit) { session_options.AddExternalInitializersFromFilesInMemory(names_path, buffers, buffer_sizes); // 7. Create session from memory - Ort::Session session(*ort_env, model_data.data(), model_data.size(), session_options); + Ort::Session session(*ort_env, model_data.value().data(), model_data.value().size(), session_options); // 8. 
Run inference to verify weights are loaded std::vector input_data(floats_per_initializer, 2.0f); diff --git a/onnxruntime/test/providers/qnn/batch_norm_test.cc b/onnxruntime/test/providers/qnn/batch_norm_test.cc index cb48506be9f62..79c16dcde07c3 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_test.cc @@ -3,8 +3,10 @@ #if !defined(ORT_MINIMAL_BUILD) +#include #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "core/common/float16.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -414,6 +416,104 @@ TEST_F(QnnHTPBackendTests, BatchNorm3D) { ExpectedEPNodeAssignment::None); } +// Tests BatchNorm with Q->DQ structure commonly seen in quantized models +template +GetTestQDQModelFn BuildBatchNormQdqParamsTestCase(const TestInputDef& input_def, + const TestInputDef& scale_def, + const TestInputDef& bias_def) { + ORT_ENFORCE(input_def.IsRawData()); + ORT_ENFORCE(scale_def.IsRawData()); + + return [input_def, scale_def, bias_def](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const auto& input_shape = input_def.GetShape(); + const auto& input_data = input_def.GetRawData(); + const int64_t num_channels = input_shape[1]; + + // Input: float -> Q -> DQ + bool symmetric = sizeof(InputQType) == sizeof(uint16_t); + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def, symmetric); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); + + NodeAttributes axis_0_attrs; + utils::SetNodeAttribute(utils::MakeAttribute("axis", static_cast(0)), axis_0_attrs); + + // Scale: float_init -> Q -> DQ (per-channel with axis=0, symmetric) + const auto& scale_data = scale_def.GetRawData(); + std::vector scale_scales(num_channels); + std::vector scale_zero_points(num_channels, static_cast(0)); + for (int64_t c = 0; c < num_channels; ++c) { + float abs_max = 
std::abs(scale_data[c]); + if (abs_max == 0.0f) abs_max = 1.0f; + scale_scales[c] = abs_max / static_cast(std::numeric_limits::max()); + } + std::vector param_shape = {num_channels}; + NodeArg* scale_float_init = builder.MakeInitializer(param_shape, scale_data); + NodeArg* scale_qdq = AddQDQNodePair(builder, scale_float_init, scale_scales, scale_zero_points, + &axis_0_attrs, &axis_0_attrs); + + NodeArg* bias = builder.MakeInitializer(bias_def.GetShape(), bias_def.GetRawData()); + + // Compute mean and var from input data + std::vector mean_vals(num_channels); + std::vector var_vals(num_channels); + ComputeChannelMeanAndVar(input_data, input_shape, mean_vals, var_vals); + + // Mean: float_init -> Q -> DQ (per-channel with axis=0, symmetric) + std::vector mean_scales(num_channels); + std::vector mean_zero_points(num_channels, static_cast(0)); + for (int64_t c = 0; c < num_channels; ++c) { + float abs_max = std::abs(mean_vals[c]); + if (abs_max == 0.0f) abs_max = 1.0f; + mean_scales[c] = abs_max / static_cast(std::numeric_limits::max()); + } + NodeArg* mean_float_init = builder.MakeInitializer(param_shape, mean_vals); + NodeArg* mean_qdq = AddQDQNodePair(builder, mean_float_init, mean_scales, mean_zero_points, + &axis_0_attrs, &axis_0_attrs); + + // Var: float_init -> Q -> DQ (per-channel with axis=0, symmetric) + std::vector var_scales(num_channels); + std::vector var_zero_points(num_channels, static_cast(0)); + for (int64_t c = 0; c < num_channels; ++c) { + float abs_max = std::abs(var_vals[c]); + if (abs_max == 0.0f) abs_max = 1.0f; + var_scales[c] = abs_max / static_cast(std::numeric_limits::max()); + } + NodeArg* var_float_init = builder.MakeInitializer(param_shape, var_vals); + NodeArg* var_qdq = AddQDQNodePair(builder, var_float_init, var_scales, var_zero_points, + &axis_0_attrs, &axis_0_attrs); + + auto* batchnorm_output = builder.MakeIntermediate(); + builder.AddNode("BatchNormalization", {input_qdq, scale_qdq, bias, mean_qdq, var_qdq}, + {batchnorm_output}); 
+ + AddQDQNodePairWithOutputAsGraphOutput(builder, batchnorm_output, + output_qparams[0].scale, output_qparams[0].zero_point); + }; +} + +// Test BatchNorm with Q->DQ on input/scale/mean/var, float bias +TEST_F(QnnHTPBackendTests, BatchNorm2dQdqParams) { + constexpr int64_t num_channels = 2; + std::vector input_data = {-8.0f, -6.0f, -4.0f, -2.0f, 0.0f, 1.1f, 3.3f, 8.0f, + -7.0f, -5.0f, -3.0f, -1.0f, 0.0f, 2.1f, 4.3f, 7.0f}; + + TestInputDef input_def({2, num_channels, 2, 2}, false, input_data); + TestInputDef scale_def({num_channels}, true, {1.0f, 2.0f}); + TestInputDef bias_def({num_channels}, true, {1.1f, 2.1f}); + + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + + TestQDQModelAccuracy(BuildBatchNormTestCase(input_def, scale_def, bias_def), + BuildBatchNormQdqParamsTestCase(input_def, scale_def, bias_def), + provider_options, + 21, + ExpectedEPNodeAssignment::All); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/providers/qnn/bf16_op_test.cc b/onnxruntime/test/providers/qnn/bf16_op_test.cc new file mode 100644 index 0000000000000..1c2a7fa2c0720 --- /dev/null +++ b/onnxruntime/test/providers/qnn/bf16_op_test.cc @@ -0,0 +1,350 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/onnx_protobuf.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Helper function to create a simple Add model for BF16 testing +[[maybe_unused]] static GetTestModelFn BuildBF16AddTestCase(const TestInputDef& input1_def, + const TestInputDef& input2_def) { + return [input1_def, input2_def](ModelTestBuilder& builder) { + NodeArg* input1 = MakeTestInput(builder, input1_def); + NodeArg* input2 = MakeTestInput(builder, input2_def); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Add", {input1, input2}, {output}); + }; +} + +// Helper function to create a simple MatMul model for BF16 testing +[[maybe_unused]] static GetTestModelFn BuildBF16MatMulTestCase(const TestInputDef& input1_def, + const TestInputDef& input2_def) { + return [input1_def, input2_def](ModelTestBuilder& builder) { + NodeArg* input1 = MakeTestInput(builder, input1_def); + NodeArg* input2 = MakeTestInput(builder, input2_def); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("MatMul", {input1, input2}, {output}); + }; +} + +// Helper function to create a Conv model for BF16 testing +[[maybe_unused]] static GetTestModelFn BuildBF16ConvTestCase(const TestInputDef& input_def, + const TestInputDef& weights_def) { + return [input_def, weights_def](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* weights = MakeTestInput(builder, weights_def); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Conv", {input, weights}, {output}); + }; +} + +// Helper function to run BF16 model test +[[maybe_unused]] static void RunBF16ModelTest(const GetTestModelFn& build_test_case, + const std::vector& input_shape, + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + int opset = 18, + float fp32_abs_err = 1e-2f) { + ORT_UNUSED_PARAMETER(input_shape); + + 
ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["htp_bf16_enable"] = "1"; // Enable BF16 mode + provider_options["soc_id"] = "88"; // Target SOC ID for BF16 support + provider_options["offload_graph_io_quantization"] = "0"; + + RunQnnModelTest(build_test_case, provider_options, opset, expected_ep_assignment, fp32_abs_err); +} + +#if defined(__aarch64__) || defined(_M_ARM64) + +// +// HTP BF16 tests: +// + +// Test BF16 handling with Add operator - both inputs dynamic +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Add_DynamicInputs) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest( + BuildBF16AddTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f)), + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.1f, 0.1f))), + shape); +} + +// Test BF16 handling with Add operator - one input static (initializer) +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Add_StaticInput) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest( + BuildBF16AddTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f)), + TestInputDef(shape, true, GetSequentialFloatData(shape, 0.1f, 0.1f))), + shape); +} + +// Test BF16 handling with Add operator - both inputs static +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Add_BothStatic) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest( + BuildBF16AddTestCase( + TestInputDef(shape, true, GetSequentialFloatData(shape, 0.0f, 0.1f)), + TestInputDef(shape, true, GetSequentialFloatData(shape, 0.1f, 0.1f))), + shape); +} + +// Test BF16 handling with MatMul operator - dynamic inputs +TEST_F(QnnHTPBackendTests, DISABLED_BF16_MatMul_DynamicInputs) { + RunBF16ModelTest( + BuildBF16MatMulTestCase( + TestInputDef({2, 3}, false, GetSequentialFloatData({2, 3}, 0.0f, 0.1f)), + TestInputDef({3, 4}, false, GetSequentialFloatData({3, 4}, 0.1f, 0.1f))), + {2, 3}); +} + +// Test BF16 handling with MatMul operator - static weight +TEST_F(QnnHTPBackendTests, 
DISABLED_BF16_MatMul_StaticWeight) { + RunBF16ModelTest( + BuildBF16MatMulTestCase( + TestInputDef({2, 3}, false, GetSequentialFloatData({2, 3}, 0.0f, 0.1f)), + TestInputDef({3, 4}, true, GetSequentialFloatData({3, 4}, 0.1f, 0.1f))), + {2, 3}); +} + +// Test BF16 handling with MatMul operator - batched inputs +TEST_F(QnnHTPBackendTests, DISABLED_BF16_MatMul_BatchedInputs) { + RunBF16ModelTest( + BuildBF16MatMulTestCase( + TestInputDef({2, 3, 4}, false, GetSequentialFloatData({2, 3, 4}, 0.0f, 0.1f)), + TestInputDef({4, 5}, false, GetSequentialFloatData({4, 5}, 0.1f, 0.1f))), + {2, 3, 4}); +} + +// Test BF16 handling with Conv operator - dynamic input +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Conv_DynamicInput) { + std::vector input_shape = {1, 3, 8, 8}; + std::vector weights_shape = {16, 3, 3, 3}; + + RunBF16ModelTest( + BuildBF16ConvTestCase( + TestInputDef(input_shape, false, GetSequentialFloatData(input_shape, 0.0f, 0.01f)), + TestInputDef(weights_shape, true, GetSequentialFloatData(weights_shape, -0.1f, 0.01f))), + input_shape); +} + +// Test BF16 handling with Conv operator - larger input +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Conv_LargerInput) { + std::vector input_shape = {1, 64, 32, 32}; + std::vector weights_shape = {128, 64, 3, 3}; + + RunBF16ModelTest( + BuildBF16ConvTestCase( + TestInputDef(input_shape, false, GetSequentialFloatData(input_shape, 0.0f, 0.001f)), + TestInputDef(weights_shape, true, GetSequentialFloatData(weights_shape, -0.05f, 0.001f))), + input_shape, + ExpectedEPNodeAssignment::All, + 18, + 1e-1f); // Larger tolerance for bigger models +} + +// Test BF16 handling with multiple operations in sequence +static GetTestModelFn BuildBF16MultiOpTestCase() { + return [](ModelTestBuilder& builder) { + std::vector shape = {2, 3, 4}; + + // Create inputs + NodeArg* input1 = MakeTestInput(builder, TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f))); + NodeArg* input2 = MakeTestInput(builder, TestInputDef(shape, false, 
GetSequentialFloatData(shape, 0.1f, 0.1f))); + NodeArg* input3 = MakeTestInput(builder, TestInputDef(shape, false, GetSequentialFloatData(shape, 0.2f, 0.1f))); + + // Add1: input1 + input2 + NodeArg* add1_output = builder.MakeIntermediate(); + builder.AddNode("Add", {input1, input2}, {add1_output}); + + // Add2: add1_output + input3 + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Add", {add1_output, input3}, {output}); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_MultipleOps) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest(BuildBF16MultiOpTestCase(), shape); +} + +// Test BF16 handling with graph that has multiple outputs +static GetTestModelFn BuildBF16MultiOutputTestCase() { + return [](ModelTestBuilder& builder) { + std::vector shape = {2, 3, 4}; + + // Create inputs + NodeArg* input1 = MakeTestInput(builder, TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f))); + NodeArg* input2 = MakeTestInput(builder, TestInputDef(shape, false, GetSequentialFloatData(shape, 0.1f, 0.1f))); + + // Add: input1 + input2 -> output1 + NodeArg* output1 = builder.MakeOutput(); + builder.AddNode("Add", {input1, input2}, {output1}); + + // Mul: input1 * input2 -> output2 + NodeArg* output2 = builder.MakeOutput(); + builder.AddNode("Mul", {input1, input2}, {output2}); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_MultipleOutputs) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest(BuildBF16MultiOutputTestCase(), shape); +} + +// Test BF16 handling with Relu activation +static GetTestModelFn BuildBF16ReluTestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Relu", {input}, {output}); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Relu) { + std::vector shape = {2, 3, 4, 5}; + RunBF16ModelTest( + BuildBF16ReluTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, -1.0f, 
0.1f))), + shape); +} + +// Test BF16 handling with Sigmoid activation +static GetTestModelFn BuildBF16SigmoidTestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Sigmoid", {input}, {output}); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Sigmoid) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest( + BuildBF16SigmoidTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, -2.0f, 0.2f))), + shape); +} + +// Test BF16 handling with Softmax +static GetTestModelFn BuildBF16SoftmaxTestCase(const TestInputDef& input_def, int64_t axis) { + return [input_def, axis](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* output = builder.MakeOutput(); + Node& node = builder.AddNode("Softmax", {input}, {output}); + node.AddAttribute("axis", axis); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Softmax) { + std::vector shape = {2, 3, 4}; + RunBF16ModelTest( + BuildBF16SoftmaxTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f)), + -1), + shape); +} + +// Test BF16 handling with Transpose +static GetTestModelFn BuildBF16TransposeTestCase(const TestInputDef& input_def, + const std::vector& perm) { + return [input_def, perm](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* output = builder.MakeOutput(); + Node& node = builder.AddNode("Transpose", {input}, {output}); + node.AddAttribute("perm", perm); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Transpose) { + std::vector shape = {2, 3, 4, 5}; + std::vector perm = {0, 2, 1, 3}; + RunBF16ModelTest( + BuildBF16TransposeTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f)), + perm), + shape); +} + +// Test BF16 handling with Reshape +static GetTestModelFn BuildBF16ReshapeTestCase(const TestInputDef& input_def, + 
const std::vector& new_shape) { + return [input_def, new_shape](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* shape_input = builder.MakeInitializer({static_cast(new_shape.size())}, new_shape); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Reshape", {input, shape_input}, {output}); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Reshape) { + std::vector input_shape = {2, 3, 4}; + std::vector output_shape = {6, 4}; + RunBF16ModelTest( + BuildBF16ReshapeTestCase( + TestInputDef(input_shape, false, GetSequentialFloatData(input_shape, 0.0f, 0.1f)), + output_shape), + input_shape); +} + +// Test BF16 handling with Concat +static GetTestModelFn BuildBF16ConcatTestCase(const std::vector>& input_defs, int64_t axis) { + return [input_defs, axis](ModelTestBuilder& builder) { + std::vector inputs; + for (const auto& input_def : input_defs) { + inputs.push_back(MakeTestInput(builder, input_def)); + } + NodeArg* output = builder.MakeOutput(); + Node& node = builder.AddNode("Concat", inputs, {output}); + node.AddAttribute("axis", axis); + }; +} + +TEST_F(QnnHTPBackendTests, DISABLED_BF16_Concat) { + std::vector shape1 = {2, 3, 4}; + std::vector shape2 = {2, 5, 4}; + std::vector> input_defs = { + TestInputDef(shape1, false, GetSequentialFloatData(shape1, 0.0f, 0.1f)), + TestInputDef(shape2, false, GetSequentialFloatData(shape2, 0.5f, 0.1f))}; + RunBF16ModelTest(BuildBF16ConcatTestCase(input_defs, 1), shape1); +} + +// Test BF16 handling with Split +static GetTestModelFn BuildBF16SplitTestCase(const TestInputDef& input_def, int64_t axis, int64_t num_outputs) { + return [input_def, axis, num_outputs](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + std::vector outputs; + for (int64_t i = 0; i < num_outputs; i++) { + outputs.push_back(builder.MakeOutput()); + } + Node& node = builder.AddNode("Split", {input}, outputs); + node.AddAttribute("axis", axis); + }; +} + 
+TEST_F(QnnHTPBackendTests, DISABLED_BF16_Split) { + std::vector shape = {2, 6, 4}; + RunBF16ModelTest( + BuildBF16SplitTestCase( + TestInputDef(shape, false, GetSequentialFloatData(shape, 0.0f, 0.1f)), + 1, + 2), + shape); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) + +} // namespace test +} // namespace onnxruntime + +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/fused_matmul_op_test.cc b/onnxruntime/test/providers/qnn/fused_matmul_op_test.cc new file mode 100644 index 0000000000000..839097cccd4f6 --- /dev/null +++ b/onnxruntime/test/providers/qnn/fused_matmul_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include "core/graph/constants.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a FusedMatMul operator on the QNN backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. 
+template +static void RunFusedMatMulTest(const TestInputDef& input_a_def, + const TestInputDef& input_b_def, + bool transA, + bool transB, + bool transBatchA = false, + bool transBatchB = false, + float alpha = 1.0f, + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + const std::string& backend_name = "cpu") { + ProviderOptions provider_options; + provider_options["backend_type"] = backend_name; + + if (backend_name == "htp") { + provider_options["enable_htp_fp16_precision"] = "1"; + } + + auto model_builder = [input_a_def, input_b_def, transA, transB, transBatchA, transBatchB, alpha](ModelTestBuilder& builder) { + NodeArg* input_a = MakeTestInput(builder, input_a_def); + NodeArg* input_b = MakeTestInput(builder, input_b_def); + std::vector inputs = {input_a, input_b}; + + auto* output = builder.MakeOutput(); + + Node& node = builder.AddNode("FusedMatMul", inputs, {output}, kMSDomain); + node.AddAttribute("transA", static_cast(transA)); + node.AddAttribute("transB", static_cast(transB)); + node.AddAttribute("transBatchA", static_cast(transBatchA)); + node.AddAttribute("transBatchB", static_cast(transBatchB)); + node.AddAttribute("alpha", alpha); + }; + + RunQnnModelTest(model_builder, + provider_options, + 13, // opset version for contrib ops + expected_ep_assignment, + 5e-3f); +} + +// Tests the accuracy of a QDQ FusedMatMul model on QNN EP by comparing to CPU EP. 
+template +static void RunQDQFusedMatMulTest(const TestInputDef& input_a_def, + const TestInputDef& input_b_def, + bool transA, + bool transB, + bool transBatchA = false, + bool transBatchB = false, + float alpha = 1.0f, + ExpectedEPNodeAssignment expected_ep_assignment = ExpectedEPNodeAssignment::All, + const std::string& backend_name = "htp", + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + provider_options["backend_type"] = backend_name; + provider_options["offload_graph_io_quantization"] = "0"; + + GetTestModelFn model_builder_fn = [input_a_def, input_b_def, transA, transB, transBatchA, transBatchB, alpha](ModelTestBuilder& builder) { + NodeArg* input_a = MakeTestInput(builder, input_a_def); + NodeArg* input_b = MakeTestInput(builder, input_b_def); + std::vector inputs = {input_a, input_b}; + + auto* output = builder.MakeOutput(); + + Node& node = builder.AddNode("FusedMatMul", inputs, {output}, kMSDomain); + node.AddAttribute("transA", static_cast(transA)); + node.AddAttribute("transB", static_cast(transB)); + node.AddAttribute("transBatchA", static_cast(transBatchA)); + node.AddAttribute("transBatchB", static_cast(transBatchB)); + node.AddAttribute("alpha", alpha); + }; + + GetTestQDQModelFn qdq_model_builder_fn = [input_a_def, input_b_def, transA, transB, transBatchA, transBatchB, alpha, use_contrib_qdq]( + ModelTestBuilder& builder, std::vector>& output_qparams) { + // Process input A with QDQ + NodeArg* input_a = MakeTestInput(builder, input_a_def); + QuantParams input_a_qparams = GetTestInputQuantParams(input_a_def); + NodeArg* input_a_qdq = AddQDQNodePair(builder, input_a, input_a_qparams.scale, + input_a_qparams.zero_point, use_contrib_qdq); + + // Process input B with QDQ + NodeArg* input_b = MakeTestInput(builder, input_b_def); + QuantParams input_b_qparams = GetTestInputQuantParams(input_b_def); + NodeArg* input_b_qdq = AddQDQNodePair(builder, input_b, input_b_qparams.scale, + input_b_qparams.zero_point, use_contrib_qdq); + + 
std::vector inputs = {input_a_qdq, input_b_qdq}; + + // FusedMatMul -> op_output + auto* op_output = builder.MakeIntermediate(); + Node& node = builder.AddNode("FusedMatMul", inputs, {op_output}, kMSDomain); + node.AddAttribute("transA", static_cast(transA)); + node.AddAttribute("transB", static_cast(transB)); + node.AddAttribute("transBatchA", static_cast(transBatchA)); + node.AddAttribute("transBatchB", static_cast(transBatchB)); + node.AddAttribute("alpha", alpha); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; + + TestQDQModelAccuracy(model_builder_fn, + qdq_model_builder_fn, + provider_options, + 13, // opset version for contrib ops + expected_ep_assignment, + QDQTolerance(5e-3f)); +} + +// +// CPU tests: +// + +// Test FusedMatMul with default attributes (no transpose, alpha=1.0, no activation) +TEST_F(QnnCPUBackendTests, FusedMatMul_Default) { + RunFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with transpose A +TEST_F(QnnCPUBackendTests, FusedMatMul_TransposeA) { + RunFusedMatMulTest( + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 12)), // input B + true, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with transpose B +TEST_F(QnnCPUBackendTests, FusedMatMul_TransposeB) { + RunFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input 
B + false, // transA + true, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with custom alpha +TEST_F(QnnCPUBackendTests, FusedMatMul_CustomAlpha) { + RunFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 0.5f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with all features combined +TEST_F(QnnCPUBackendTests, DISABLED_FusedMatMul_Combined) { + RunFusedMatMulTest( + TestInputDef({2, 4, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 24)), // input A + TestInputDef({3, 4, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 12)), // input B - adjusted shape for compatibility + true, // transA + true, // transB + true, // transBatchA + true, // transBatchB + 0.5f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with higher rank tensors +TEST_F(QnnCPUBackendTests, FusedMatMul_HigherRank) { + RunFusedMatMulTest( + TestInputDef({2, 3, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 24)), // input A + TestInputDef({2, 4, 5}, false, GetFloatDataInRange(-1.0f, 1.0f, 40)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test FusedMatMul with batch dimension transposition +TEST_F(QnnCPUBackendTests, FusedMatMul_BatchTranspose) { + RunFusedMatMulTest( + TestInputDef({2, 2, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 16)), // input A + TestInputDef({2, 4, 5}, false, GetFloatDataInRange(-1.0f, 1.0f, 40)), // input B + false, // transA + false, // transB + true, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// 
Test FusedMatMul with default attributes on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_Default) { + RunFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test FusedMatMul with float16 inputs and custom alpha on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_Float16_CustomAlpha) { + RunFusedMatMulTest( + ConvertToFP16InputDef(TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6))), // input A + ConvertToFP16InputDef(TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6))), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 0.5f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test FusedMatMul with float16 inputs, transpose, and custom alpha on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_Float16_TransposeA_CustomAlpha) { + RunFusedMatMulTest( + ConvertToFP16InputDef(TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6))), // input A + ConvertToFP16InputDef(TestInputDef({3, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 12))), // input B + true, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test FusedMatMul with batch dimension transposition on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_BatchTranspose) { + RunFusedMatMulTest( + TestInputDef({2, 2, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 16)), // input A + TestInputDef({2, 4, 5}, false, GetFloatDataInRange(-1.0f, 1.0f, 40)), // input B + false, // transA + false, // transB + true, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test 8-bit QDQ FusedMatMul with default attributes on HTP 
+TEST_F(QnnHTPBackendTests, FusedMatMul_QDQ_U8_Default) { + RunQDQFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ FusedMatMul with batch dimension transposition on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_QDQ_U8_BatchTranspose) { + RunQDQFusedMatMulTest( + TestInputDef({2, 2, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 16)), // input A + TestInputDef({2, 4, 5}, false, GetFloatDataInRange(-1.0f, 1.0f, 40)), // input B + false, // transA + false, // transB + true, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ FusedMatMul with default attributes on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_QDQ_U16_Default) { + RunQDQFusedMatMulTest( + TestInputDef({2, 3}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input A + TestInputDef({3, 2}, false, GetFloatDataInRange(-1.0f, 1.0f, 6)), // input B + false, // transA + false, // transB + false, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All, + "htp", + true); // Use com.microsoft Q/DQ ops +} + +// Test 16-bit QDQ FusedMatMul with batch dimension transposition on HTP +TEST_F(QnnHTPBackendTests, FusedMatMul_QDQ_U16_BatchTranspose) { + RunQDQFusedMatMulTest( + TestInputDef({2, 2, 4}, false, GetFloatDataInRange(-1.0f, 1.0f, 16)), // input A + TestInputDef({2, 4, 5}, false, GetFloatDataInRange(-1.0f, 1.0f, 40)), // input B + false, // transA + false, // transB + true, // transBatchA + false, // transBatchB + 1.0f, // alpha + ExpectedEPNodeAssignment::All, + "htp", + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime +#endif 
// !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index 418842ee0a81b..d1f43787c7717 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1314,6 +1314,27 @@ TEST_F(QnnHTPBackendTests, DumpJsonQNNGraph) { std::filesystem::remove_all(dump_dir); } +// Test extended UDMA mode on supported hardware (should run successfully) +TEST_F(QnnHTPBackendTests, ExtendedUdmaModeTest) { + // Create provider options with extended UDMA mode enabled + ProviderOptions options; + options["backend_type"] = "htp"; + options["offload_graph_io_quantization"] = "0"; + options["htp_arch"] = "81"; + options["extended_udma"] = "1"; + + // Define a simple model with Add operation + auto input_defs = {TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f)}; + + // Run the test - this should succeed because v81 supports extended UDMA + RunQnnModelTest(BuildOpTestCase("Add", input_defs, {}, {}, kOnnxDomain), + options, + 13, + ExpectedEPNodeAssignment::All, + 0.008f); +} + // Test option for offloading quantization of graph inputs and dequantization of graph outputs to the CPU EP. 
TEST_F(QnnHTPBackendTests, EPOffloadsGraphIOQuantDequant) { // Returns a function that checks that the Q/DQ ops at the graph IO boundary are offloaded to CPU diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index a2f1b9b56538b..813abf74828a2 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -77,26 +77,29 @@ void CleanUpCtxFile(std::string context_file_path) { ASSERT_EQ(std::remove(context_file_path.c_str()), 0); } -// Create a model with FusedMatMul + Add (quantized) +// Create a model with FusedGemm + Add (quantized) // input1 -> Add -> Q -> DQ ---- // | -// input2 -> Q -> DQ -> FusedMatMul -> Q -> DQ -> output +// input2 -> Q -> DQ -> FusedGemm -> Q -> DQ -> output static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { return [single_ep_node](ModelTestBuilder& builder) { - // Creat non-quantized FusedMatMul node1 - std::vector data(200 * 200, 1.0f); - NodeArg* input1 = MakeTestInput(builder, TestInputDef({200, 200}, false, data)); - NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef({200, 200}, true, data)); + // Create non-quantized FusedGemm node1 + std::vector gemm_input_data(12, 1.0f); + std::vector gemm_weight_data(16, 1.0f); // 4x4 = 16 elements + NodeArg* input1 = MakeTestInput(builder, TestInputDef({3, 4}, false, gemm_input_data)); + NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef({4, 4}, true, gemm_weight_data)); auto* add1_output = builder.MakeIntermediate(); - builder.AddNode("FusedMatMul", {input1, add1_ini_input2}, {add1_output}, kMSDomain); + Node& fused_gemm_node1 = builder.AddNode("FusedGemm", {input1, add1_ini_input2}, {add1_output}, kMSDomain); + fused_gemm_node1.AddAttribute("activation", "Relu"); // Create quantized Add node2 - gsl::span data_range = gsl::make_span(data); + std::vector add_data(12, 1.0f); + gsl::span data_range = 
gsl::make_span(add_data); QuantParams q_parameter = GetDataQuantParams(data_range); auto* add2_input1_qdq = AddQDQNodePair(builder, add1_output, q_parameter.scale, q_parameter.zero_point); - NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef({200, 200}, true, data)); + NodeArg* add2_input2 = MakeTestInput(builder, TestInputDef({3, 4}, true, add_data)); auto* add2_input2_qdq = AddQDQNodePair(builder, add2_input2, q_parameter.scale, q_parameter.zero_point); auto* add2_output = builder.MakeIntermediate(); @@ -108,15 +111,16 @@ static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { AddQDQNodePairWithOutputAsGraphOutput(builder, add2_output, q_parameter.scale, q_parameter.zero_point); } else { auto* add3_input1_qdq = AddQDQNodePair(builder, add2_output, q_parameter.scale, q_parameter.zero_point); - NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef({200, 200}, true, data)); + NodeArg* add3_ini_input2 = MakeTestInput(builder, TestInputDef({4, 4}, true, gemm_weight_data)); auto* add3_output = builder.MakeIntermediate(); - builder.AddNode("FusedMatMul", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain); + Node& fused_gemm_node2 = builder.AddNode("FusedGemm", {add3_input1_qdq, add3_ini_input2}, {add3_output}, kMSDomain); + fused_gemm_node2.AddAttribute("activation", "Relu"); // Create quantized Add node4 auto* add4_input1_qdq = AddQDQNodePair(builder, add3_output, q_parameter.scale, q_parameter.zero_point); - NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef({200, 200}, true, data)); + NodeArg* add4_input2 = MakeTestInput(builder, TestInputDef({3, 4}, true, add_data)); auto* add4_input2_qdq = AddQDQNodePair(builder, add4_input2, q_parameter.scale, q_parameter.zero_point); auto* add4_output = builder.MakeIntermediate(); @@ -752,15 +756,15 @@ TEST_F(QnnHTPBackendTests, QnnContextBinary_OriginalCompileApproach_IgnoreCompil } } -// Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can 
still generate the context binary -// The generated Onnx model has 1 FusedMatMul node and 1 EPContext node +// Test that models with 1 non-quantized FusedGemm node and 1 quantized Add node can still generate the context binary +// The generated Onnx model has 1 FusedGemm node and 1 EPContext node TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport1) { bool single_ep_node = true; QnnContextBinaryMultiPartitionTestBody(single_ep_node); } -// Test that models with 2 non-quantized FusedMatMul nodes and 2 quantized Add nodes can still generate the context binary -// The generated Onnx model has 2 FusedMatMul nodes and 1 EPContext nodes +// Test that models with 2 non-quantized FusedGemm nodes and 2 quantized Add nodes can still generate the context binary +// The generated Onnx model has 2 FusedGemm nodes and 2 EPContext nodes TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { bool single_ep_node = false; QnnContextBinaryMultiPartitionTestBody(single_ep_node); @@ -836,21 +840,21 @@ void EpCtxCpuNodeWithExternalIniFileTestBody(bool expect_external_ini_file, bool CleanUpCtxFile(ep_context_model_file); } -// Set the session option "ep.context_model_external_initializers_file_name" so FusedMatMul (which fallback on CPU) +// Set the session option "ep.context_model_external_initializers_file_name" so FusedGemm (which fallback on CPU) // will dump initializer data to external file TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(true); } // Without setting the session option "ep.context_model_external_initializers_file_name" -// so FusedMatMul (which fallback on CPU) will NOT dump initializer data to external file +// so FusedGemm (which fallback on CPU) will NOT dump initializer data to external file TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeights) { EpCtxCpuNodeWithExternalIniFileTestBody(false); } // Load model from memory // Without setting the session 
option "ep.context_model_external_initializers_file_name" -// so FusedMatMul (which fallback on CPU) will NOT dump initializer data to external file +// so FusedGemm (which fallback on CPU) will NOT dump initializer data to external file TEST_F(QnnHTPBackendTests, QnnContextBinaryCpuNodeWithoutExternalWeightsModelFromMemory) { EpCtxCpuNodeWithExternalIniFileTestBody(false, true); } @@ -1897,6 +1901,114 @@ TEST_F(QnnHTPBackendTests, VTCMBackupBufferSharing) { std::remove(qnn_ctx_binary_file_name1.c_str()); } +TEST_F(QnnHTPBackendTests, FileMapping_Off) { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + provider_options["disable_file_mapped_weights"] = "1"; + + // Create QDQ models + std::vector onnx_model_paths{"./weight_share1.onnx", "./weight_share2.onnx"}; + // cleanup in case some failure test doesn't remove them + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + + std::vector ctx_model_paths; + for (auto model_path : onnx_model_paths) { + CreateQdqModel(model_path, DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(std::filesystem::exists(model_path.c_str())); + auto pos = model_path.find_last_of("."); + if (pos != std::string::npos) { + model_path = model_path.substr(0, pos) + "_ctx.onnx"; + } else { + model_path = model_path + "_ctx.onnx"; + } + ctx_model_paths.push_back(model_path); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + + DumpModelWithSharedCtx(provider_options, onnx_model_paths[0], onnx_model_paths[1]); + + std::string qnn_ctx_binary_file_name1; + GetContextBinaryFileName(ctx_model_paths[0], qnn_ctx_binary_file_name1, + DefaultLoggingManager().DefaultLogger()); + EXPECT_TRUE(!qnn_ctx_binary_file_name1.empty()); + + std::string qnn_ctx_binary_file_name2; + GetContextBinaryFileName(ctx_model_paths[1], qnn_ctx_binary_file_name2, + DefaultLoggingManager().DefaultLogger()); 
+ EXPECT_TRUE(!qnn_ctx_binary_file_name2.empty()); + // 2 *_ctx.onn point to same .bin file + EXPECT_TRUE(qnn_ctx_binary_file_name1 == qnn_ctx_binary_file_name2); + auto file_size_1 = std::filesystem::file_size(qnn_ctx_binary_file_name1); + EXPECT_TRUE(file_size_1 > 0); + + // only load and run the session on real device +#if defined(__aarch64__) || defined(_M_ARM64) + Ort::SessionOptions so1; + so1.SetLogId("so1"); + so1.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so1.AppendExecutionProvider("QNN", provider_options); + Ort::SessionOptions so2; + + // Test CreateFromBinaryListAsync path + provider_options["enable_vtcm_backup_buffer_sharing"] = "1"; + so2.SetLogId("so2"); + so2.AddConfigEntry(kOrtSessionOptionShareEpContexts, "1"); + so2.AppendExecutionProvider("QNN", provider_options); + + EXPECT_TRUE(2 == ctx_model_paths.size()); +#ifdef _WIN32 + std::wstring ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::wstring ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#else + std::string ctx_model_file1(ctx_model_paths[0].begin(), ctx_model_paths[0].end()); + std::string ctx_model_file2(ctx_model_paths[1].begin(), ctx_model_paths[1].end()); +#endif + Ort::Session session1(*ort_env, ctx_model_file1.c_str(), so1); + Ort::Session session2(*ort_env, ctx_model_file2.c_str(), so2); + + std::vector input_names; + std::vector output_names; + GetModelInputNames(ctx_model_paths[1], input_names, output_names, + DefaultLoggingManager().DefaultLogger()); + + // Run sessions + // prepare input + std::vector input_dim{2, 3}; + std::vector input_value(2 * 3, 0.0f); + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + std::vector ort_inputs; + std::vector input_names_c; + for (size_t i = 0; i < input_names.size(); ++i) { + auto input_tensor = Ort::Value::CreateTensor(info, input_value.data(), input_value.size(), + input_dim.data(), input_dim.size()); + ort_inputs.push_back(std::move(input_tensor)); + 
input_names_c.push_back(input_names[i].c_str()); + } + std::vector output_names_c; + for (size_t i = 0; i < output_names.size(); ++i) { + output_names_c.push_back(output_names[i].c_str()); + } + + auto ort_outputs1 = session1.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); + auto ort_outputs2 = session2.Run(Ort::RunOptions{}, input_names_c.data(), ort_inputs.data(), ort_inputs.size(), + output_names_c.data(), 1); +#endif + + for (auto model_path : onnx_model_paths) { + std::remove(model_path.c_str()); + } + for (auto ctx_model_path : ctx_model_paths) { + std::remove(ctx_model_path.c_str()); + } + std::remove(qnn_ctx_binary_file_name1.c_str()); +} + // For Ort sessions to generate the context binary, with session option ep.share_ep_contexts enabled // Ort sessions will share the QnnBackendManager, so that all graphs from all models compile into the same Qnn context TEST_F(QnnHTPBackendTests, QnnContextGenWeightSharingSessionAPI) { diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 9ad34788444db..a6d43a3d3a9d9 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -408,12 +408,28 @@ static BackendSupport GetHTPSupport(const onnxruntime::logging::Logger& logger) // Create QNN EP and call GetCapability(). 
MockKernelLookup kernel_lookup; onnxruntime::GraphViewer graph_viewer(graph); - std::unique_ptr qnn_ep = QnnExecutionProviderWithOptions( - {{"backend_type", "htp"}, {"offload_graph_io_quantization", "0"}}); - GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability - qnn_ep->SetLogger(&logger); - auto result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + std::vector> result; + std::unique_ptr qnn_ep; + try { + qnn_ep = QnnExecutionProviderWithOptions( + {{"backend_type", "htp"}, {"offload_graph_io_quantization", "0"}, {"enable_htp_shared_memory_allocator", "1"}}); + GraphOptimizerRegistry graph_optimizer_registry(nullptr, nullptr, nullptr); // as a placeholder to feed into GetCapability + + qnn_ep->SetLogger(&logger); + result = qnn_ep->GetCapability(graph_viewer, kernel_lookup, graph_optimizer_registry, nullptr); + } catch (const std::exception& e) { + // handle exception that indicates that the libcdsprpc.so / dll can't be loaded + std::string_view error_message = e.what(); + std::string_view expected_error_message = "Failed to initialize RPCMEM dynamic library handle"; + + if (error_message.find(expected_error_message) != std::string_view::npos) { + return BackendSupport::UNSUPPORTED; + } + + // propagate other exceptions + throw; + } return result.empty() ? BackendSupport::UNSUPPORTED : BackendSupport::SUPPORTED; } diff --git a/onnxruntime/test/providers/qnn/quick_gelu_op_test.cc b/onnxruntime/test/providers/qnn/quick_gelu_op_test.cc new file mode 100644 index 0000000000000..38d26d6c8a4b1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/quick_gelu_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include "core/graph/constants.h" +#include "test/providers/qnn/qnn_test_utils.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a QuickGelu operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunQuickGeluTest(const TestInputDef& input_def, + float alpha, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "cpu", + float fp32_abs_err = 5e-3f) { + ProviderOptions provider_options; + provider_options["backend_type"] = backend_name; + + if (backend_name == "htp") { + provider_options["enable_htp_fp16_precision"] = "1"; + } + + auto model_builder = [input_def, alpha](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + auto* output = builder.MakeOutput(); + + Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain); + node.AddAttribute("alpha", alpha); + }; + + RunQnnModelTest(model_builder, + provider_options, + 13, // opset version for contrib ops + expected_ep_assignment, + fp32_abs_err); +} + +// Tests the accuracy of a QDQ QuickGelu model on QNN EP by comparing to CPU EP. 
+template +static void RunQDQQuickGeluTest(const TestInputDef& input_def, + float alpha, + ExpectedEPNodeAssignment expected_ep_assignment, + const std::string& backend_name = "htp", + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + provider_options["backend_type"] = backend_name; + provider_options["offload_graph_io_quantization"] = "0"; + + GetTestModelFn model_builder_fn = [input_def, alpha](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + auto* output = builder.MakeOutput(); + + Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain); + node.AddAttribute("alpha", alpha); + }; + + GetTestQDQModelFn qdq_model_builder_fn = [input_def, alpha, use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + + // QuickGelu -> op_output + auto* op_output = builder.MakeIntermediate(); + Node& node = builder.AddNode("QuickGelu", {input_after_qdq}, {op_output}, kMSDomain); + node.AddAttribute("alpha", alpha); + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; + + TestQDQModelAccuracy(model_builder_fn, + qdq_model_builder_fn, + provider_options, + 13, // opset version for contrib ops + expected_ep_assignment, + QDQTolerance(5e-3f)); +} + +// +// CPU tests: +// + +// Test QuickGelu with default alpha value (1.0) +TEST_F(QnnCPUBackendTests, QuickGelu_Default_Alpha) { + RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test QuickGelu with custom alpha value +TEST_F(QnnCPUBackendTests, QuickGelu_Custom_Alpha) { + 
RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.702f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test QuickGelu with negative alpha value +TEST_F(QnnCPUBackendTests, QuickGelu_Negative_Alpha) { + RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + -1.702f, // alpha + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +TEST_F(QnnHTPBackendTests, QuickGelu_Default_Alpha) { + RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.0f, + ExpectedEPNodeAssignment::All, + "htp", + 0.01f); +} + +// Test QuickGelu with custom alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_Custom_Alpha) { + RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test QuickGelu with negative alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_Negative_Alpha) { + RunQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + -1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Default_Alpha) { + RunQuickGeluTest(ConvertToFP16InputDef(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))), + 1.0f, + ExpectedEPNodeAssignment::All, + "htp", + 0.01f); +} + +// Test QuickGelu with float16 inputs and custom alpha on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Custom_Alpha) { + RunQuickGeluTest(ConvertToFP16InputDef(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))), + 1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test QuickGelu with float16 inputs and negative alpha on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Negative_Alpha) { + RunQuickGeluTest(ConvertToFP16InputDef(TestInputDef({1, 3, 4, 4}, 
false, GetFloatDataInRange(-10.0f, 10.0f, 48))), + -1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp"); +} + +// Test 8-bit QDQ QuickGelu with default alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Default_Alpha) { + RunQDQQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.0f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ QuickGelu with custom alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Custom_Alpha) { + RunQDQQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.702f, // alpha + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ QuickGelu with default alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Default_Alpha) { + RunQDQQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.0f, // alpha + ExpectedEPNodeAssignment::All, + "htp", + true); // Use com.microsoft Q/DQ ops +} + +// Test 16-bit QDQ QuickGelu with custom alpha value on HTP +TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Custom_Alpha) { + RunQDQQuickGeluTest(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + 1.702f, // alpha + ExpectedEPNodeAssignment::All, + "htp", + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index daae3d939660f..b84b361b61367 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -268,6 +268,26 @@ static void RunFP16OpTest(const std::string& op_type, tolerance); } +// Test Concat with empty input +TEST_F(QnnHTPBackendTests, Concat_EmptyInput) { + RunOpTest("Concat", + {TestInputDef({1, 3, 4, 4}, false, 
-10.0f, 10.0f), + TestInputDef({1, 0, 4, 4}, false, {})}, + {utils::MakeAttribute("axis", static_cast(1))}, + 13, + ExpectedEPNodeAssignment::All); +} + +// Test Concat with empty initializer +TEST_F(QnnHTPBackendTests, Concat_EmptyInitializer) { + RunOpTest("Concat", + {TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 0, 4, 4}, true, {})}, // true makes this an initializer + {utils::MakeAttribute("axis", static_cast(1))}, + 13, + ExpectedEPNodeAssignment::All); +} + // Test the accuracy of QDQ Sigmoid. TEST_F(QnnHTPBackendTests, UnaryOp_Sigmoid) { RunQDQOpTest("Sigmoid", diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 768a97d7ed2bc..fe98cc2ad561a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1967,6 +1967,52 @@ def test_run_base_model(self): self.assertEqual(len(outputs), 1) self.assertTrue(np.allclose(outputs[0], expected_output)) + def test_get_graph_provider_assignment_info(self): + """ + Tests querying for information about the nodes assigned to the CPU EP. + """ + + # Create session options that enables recording EP graph partitioning info. + session_options = onnxrt.SessionOptions() + session_options.add_session_config_entry("session.record_ep_graph_assignment_info", "1") + + session = onnxrt.InferenceSession(get_name("add_mul_add.onnx"), sess_options=session_options) + + # Query session for information on each subgraph assigned to an EP. 
+ ep_subgraphs = session.get_provider_graph_assignment_info() + + # Check that all 3 nodes are assigned to CPU EP (each in its own subgraph) + self.assertEqual(len(ep_subgraphs), 3) + for ep_subgraph in ep_subgraphs: + self.assertEqual(ep_subgraph.ep_name, "CPUExecutionProvider") + self.assertEqual(len(ep_subgraph.get_nodes()), 1) + + # Serialize each node to an identifier (concatenates domain, operator type, and node name) + node_ids: list[str] = [f"{n.domain}:{n.op_type}/{n.name}" for s in ep_subgraphs for n in s.get_nodes()] + + # Should have 1 Mul and 2 Adds. + self.assertEqual(len(node_ids), 3) + self.assertIn(":Add/add_0", node_ids) + self.assertIn(":Add/add_1", node_ids) + self.assertIn(":Mul/mul_0", node_ids) + + def test_get_graph_provider_assignment_info_not_enabled(self): + """ + Tests querying for information about the nodes assigned to the CPU EP when + the corresponding config entry is disabled. + """ + + # Do not enable "session.record_ep_graph_assignment_info" + session = onnxrt.InferenceSession(get_name("add_mul_add.onnx")) + + # Expect failure + with self.assertRaises(Fail) as context: + session.get_provider_graph_assignment_info() + self.assertIn( + "Session configuration entry 'session.record_ep_graph_assignment_info' must be set to \"1\"", + str(context.exception), + ) + if __name__ == "__main__": unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/onnxruntime_test_python_symlink_data.py b/onnxruntime/test/python/onnxruntime_test_python_symlink_data.py new file mode 100644 index 0000000000000..ea3c0f9ca9904 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_python_symlink_data.py @@ -0,0 +1,250 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+ +import os +import shutil +import struct +import tempfile +import unittest + +import numpy as np +from onnx import TensorProto, helper, save + +import onnxruntime as ort + + +class TestSymLinkOnnxModelExternalData(unittest.TestCase): + def test_symlink_model_and_data_under_same_directory(self): + # The following directory structure simulates huggingface hub local cache: + # temp_dir/ (This corresponds to .cache/huggingface/hub/model_id/) + # blobs/ + # guid1 + # guid2 + # snapshots/version/ + # model.onnx -> ../../blobs/guid1 + # data.bin -> ../../blobs/guid2 + + self.temp_dir = tempfile.mkdtemp() + try: + blobs_dir = os.path.join(self.temp_dir, "blobs") + os.makedirs(blobs_dir) + + snapshots_dir = os.path.join(self.temp_dir, "snapshots", "version") + os.makedirs(snapshots_dir) + + # Create real files in blobs + # We'll use the helper to create the model, but we need to control where files end up. + # Let's manually create the data file in blobs + data_blob_path = os.path.join(blobs_dir, "guid2") + vals = [float(i) for i in range(10)] + with open(data_blob_path, "wb") as f: + f.writelines(struct.pack("f", v) for v in vals) + + # Create model in blobs (referencing "data.bin" as external data) + # When loaded from snapshots/version/model.onnx, ORT looks for snapshots/version/data.bin + + input_ = helper.make_tensor_value_info("input", TensorProto.FLOAT, [10]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [10]) + tensor = helper.make_tensor("external_data", TensorProto.FLOAT, [10], vals) + tensor.data_location = TensorProto.EXTERNAL + tensor.ClearField("float_data") + tensor.ClearField("raw_data") + + k = tensor.external_data.add() + k.key = "location" + k.value = "data.bin" # Relative path + + offset = tensor.external_data.add() + offset.key = "offset" + offset.value = "0" + + length = tensor.external_data.add() + length.key = "length" + length.value = str(len(vals) * 4) + + const_node = helper.make_node("Constant", [], ["const_out"], 
value=tensor) + add_node = helper.make_node("Add", ["input", "const_out"], ["output"]) + graph = helper.make_graph([const_node, add_node], "test_graph", [input_], [output]) + model = helper.make_model(graph) + + model_blob_path = os.path.join(blobs_dir, "guid1") + save(model, model_blob_path) + + # Now create symlinks in snapshots + model_symlink_path = os.path.join(snapshots_dir, "model.onnx") + data_symlink_path = os.path.join(snapshots_dir, "data.bin") + + try: + os.symlink(model_blob_path, model_symlink_path) + os.symlink(data_blob_path, data_symlink_path) + except (OSError, NotImplementedError) as e: + self.skipTest(f"Skipping symlink test: symlink creation is not supported in this environment: {e}") + + sess = ort.InferenceSession(model_symlink_path, providers=["CPUExecutionProvider"]) + + input_data = np.zeros(10, dtype=np.float32) + res = sess.run(["output"], {"input": input_data}) + expected = np.array([float(i) for i in range(10)], dtype=np.float32) + np.testing.assert_allclose(res[0], expected) + + finally: + shutil.rmtree(self.temp_dir) + + def test_symlink_with_data_in_model_sub_dir(self): + # working directory structure (data is in model sub directory): + # temp_dir/ + # blobs/ + # guid1 + # data/guid2 + # snapshots/version/ + # model.onnx -> ../../blobs/guid1 + # data.bin -> ../../blobs/data/guid2 + + self.temp_dir = tempfile.mkdtemp() + try: + blobs_dir = os.path.join(self.temp_dir, "blobs") + os.makedirs(blobs_dir) + data_dir = os.path.join(blobs_dir, "data") + os.makedirs(data_dir) + + snapshots_dir = os.path.join(self.temp_dir, "snapshots", "version") + os.makedirs(snapshots_dir) + + # Create real files in blobs + # We'll use the helper to create the model, but we need to control where files end up. 
+ # Let's manually create the data file in blobs + data_blob_path = os.path.join(data_dir, "guid2") + vals = [float(i) for i in range(10)] + with open(data_blob_path, "wb") as f: + f.writelines(struct.pack("f", v) for v in vals) + + # Create model in blobs (referencing "data.bin" as external data) + # When loaded from snapshots/version/model.onnx, ORT looks for snapshots/version/data.bin + + input_ = helper.make_tensor_value_info("input", TensorProto.FLOAT, [10]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [10]) + tensor = helper.make_tensor("external_data", TensorProto.FLOAT, [10], vals) + tensor.data_location = TensorProto.EXTERNAL + tensor.ClearField("float_data") + tensor.ClearField("raw_data") + + k = tensor.external_data.add() + k.key = "location" + k.value = "data.bin" # Relative path + + offset = tensor.external_data.add() + offset.key = "offset" + offset.value = "0" + + length = tensor.external_data.add() + length.key = "length" + length.value = str(len(vals) * 4) + + const_node = helper.make_node("Constant", [], ["const_out"], value=tensor) + add_node = helper.make_node("Add", ["input", "const_out"], ["output"]) + graph = helper.make_graph([const_node, add_node], "test_graph", [input_], [output]) + model = helper.make_model(graph) + + model_blob_path = os.path.join(blobs_dir, "guid1") + save(model, model_blob_path) + + # Now create symlinks in snapshots + model_symlink_path = os.path.join(snapshots_dir, "model.onnx") + data_symlink_path = os.path.join(snapshots_dir, "data.bin") + + try: + os.symlink(model_blob_path, model_symlink_path) + os.symlink(data_blob_path, data_symlink_path) + except (OSError, NotImplementedError) as e: + self.skipTest(f"Skipping symlink test: symlink creation is not supported in this environment: {e}") + + sess = ort.InferenceSession(model_symlink_path, providers=["CPUExecutionProvider"]) + + input_data = np.zeros(10, dtype=np.float32) + res = sess.run(["output"], {"input": input_data}) + expected = 
np.array([float(i) for i in range(10)], dtype=np.float32) + np.testing.assert_allclose(res[0], expected) + + finally: + shutil.rmtree(self.temp_dir) + + def test_symlink_with_data_not_in_model_sub_dir(self): + # working directory structure (data is not in model directory or its sub directories): + # temp_dir/ + # model/ + # guid1 + # data/ + # guid2 + # snapshots/version/ + # model.onnx -> ../../model/guid1 + # data.bin -> ../../data/guid2 + + self.temp_dir = tempfile.mkdtemp() + try: + model_dir = os.path.join(self.temp_dir, "model") + os.makedirs(model_dir) + data_dir = os.path.join(self.temp_dir, "data") + os.makedirs(data_dir) + + snapshots_dir = os.path.join(self.temp_dir, "snapshots", "version") + os.makedirs(snapshots_dir) + + # Create real files in data_dir + # We'll use the helper to create the model, but we need to control where files end up. + # Let's manually create the data file in data_dir + data_blob_path = os.path.join(data_dir, "guid2") + vals = [float(i) for i in range(10)] + with open(data_blob_path, "wb") as f: + f.writelines(struct.pack("f", v) for v in vals) + + # Create model in model_dir (referencing "data.bin" as external data) + # When loaded from snapshots/version/model.onnx, ORT looks for snapshots/version/data.bin + + input_ = helper.make_tensor_value_info("input", TensorProto.FLOAT, [10]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [10]) + tensor = helper.make_tensor("external_data", TensorProto.FLOAT, [10], vals) + tensor.data_location = TensorProto.EXTERNAL + tensor.ClearField("float_data") + tensor.ClearField("raw_data") + + k = tensor.external_data.add() + k.key = "location" + k.value = "data.bin" # Relative path + + offset = tensor.external_data.add() + offset.key = "offset" + offset.value = "0" + + length = tensor.external_data.add() + length.key = "length" + length.value = str(len(vals) * 4) + + const_node = helper.make_node("Constant", [], ["const_out"], value=tensor) + add_node = 
helper.make_node("Add", ["input", "const_out"], ["output"]) + graph = helper.make_graph([const_node, add_node], "test_graph", [input_], [output]) + model = helper.make_model(graph) + + model_blob_path = os.path.join(model_dir, "guid1") + save(model, model_blob_path) + + # Now create symlinks in snapshots + model_symlink_path = os.path.join(snapshots_dir, "model.onnx") + data_symlink_path = os.path.join(snapshots_dir, "data.bin") + + try: + os.symlink(model_blob_path, model_symlink_path) + os.symlink(data_blob_path, data_symlink_path) + except (OSError, NotImplementedError) as e: + self.skipTest(f"Skipping symlink test: symlink creation is not supported in this environment: {e}") + + with self.assertRaises(Exception) as cm: + ort.InferenceSession(model_symlink_path, providers=["CPUExecutionProvider"]) + + # We expect an error about external data not under model directory or the real model directory. + self.assertIn("External data path validation failed", str(cm.exception)) + finally: + shutil.rmtree(self.temp_dir) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py index 7e0a8496b8bfb..70f8ca127e184 100644 --- a/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py +++ b/onnxruntime/test/python/quantization/test_qnn_preprocess_model.py @@ -55,15 +55,26 @@ def build_model(self, shape, scale_val, bias_val): bias_const = onnx.numpy_helper.from_array(np.array(bias_val, dtype=np.float32), "bias_const") two_const = onnx.numpy_helper.from_array(np.array(2.0, dtype=np.float32), "two_const") - m_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m_rm0_out"]) - m_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m_rm0_out"], ["m_sub_out"]) - m_pow_node = onnx.helper.make_node("Pow", ["m_sub_out", "two_const"], ["m_pow_out"]) - m_rm1_node = onnx.helper.make_node("ReduceMean", 
["m_pow_out", "axes_const"], ["m_rm1_out"]) - m_add0_node = onnx.helper.make_node("Add", ["m_rm1_out", "eps_const"], ["m_add0_out"]) - m_sqrt_node = onnx.helper.make_node("Sqrt", ["m_add0_out"], ["m_sqrt_out"]) - m_div_node = onnx.helper.make_node("Div", ["m_sub_out", "m_sqrt_out"], ["m_div_out"]) - m_mul_node = onnx.helper.make_node("Mul", ["m_div_out", "scale_const"], ["m_mul_out"]) - m_add1_node = onnx.helper.make_node("Add", ["m_mul_out", "bias_const"], ["output"]) + m0_rm0_node = onnx.helper.make_node("ReduceMean", ["l2_seq_output", "axes_const"], ["m0_rm0_out"]) + m0_sub_node = onnx.helper.make_node("Sub", ["l2_seq_output", "m0_rm0_out"], ["m0_sub_out"]) + m0_pow_node = onnx.helper.make_node("Pow", ["m0_sub_out", "two_const"], ["m0_pow_out"]) + m0_rm1_node = onnx.helper.make_node("ReduceMean", ["m0_pow_out", "axes_const"], ["m0_rm1_out"]) + m0_add0_node = onnx.helper.make_node("Add", ["m0_rm1_out", "eps_const"], ["m0_add0_out"]) + m0_sqrt_node = onnx.helper.make_node("Sqrt", ["m0_add0_out"], ["m0_sqrt_out"]) + m0_div_node = onnx.helper.make_node("Div", ["m0_sub_out", "m0_sqrt_out"], ["m0_div_out"]) + m0_mul_node = onnx.helper.make_node("Mul", ["m0_div_out", "scale_const"], ["m0_mul_out"]) + m0_add1_node = onnx.helper.make_node("Add", ["m0_mul_out", "bias_const"], ["m0_add1_out"]) + + # Alternate ReduceMean sequence + m1_rm0_node = onnx.helper.make_node("ReduceMean", ["m0_add1_out", "axes_const"], ["m1_rm0_out"]) + m1_sub_node = onnx.helper.make_node("Sub", ["m0_add1_out", "m1_rm0_out"], ["m1_sub_out"]) + m1_mul0_node = onnx.helper.make_node("Mul", ["m1_sub_out", "m1_sub_out"], ["m1_mul0_out"]) + m1_rm1_node = onnx.helper.make_node("ReduceMean", ["m1_mul0_out", "axes_const"], ["m1_rm1_out"]) + m1_add0_node = onnx.helper.make_node("Add", ["m1_rm1_out", "eps_const"], ["m1_add0_out"]) + m1_sqrt_node = onnx.helper.make_node("Sqrt", ["m1_add0_out"], ["m1_sqrt_out"]) + m1_div_node = onnx.helper.make_node("Div", ["m1_sub_out", "m1_sqrt_out"], ["m1_div_out"]) + 
m1_mul1_node = onnx.helper.make_node("Mul", ["m1_div_out", "scale_const"], ["m1_mul1_out"]) + m1_add1_node = onnx.helper.make_node("Add", ["m1_mul1_out", "bias_const"], ["output"]) graph = onnx.helper.make_graph( [ @@ -76,15 +87,24 @@ def build_model(self, shape, scale_val, bias_val): l2_clip_node, l2_expand_node, l2_div_node, - m_rm0_node, - m_sub_node, - m_pow_node, - m_rm1_node, - m_add0_node, - m_sqrt_node, - m_div_node, - m_mul_node, - m_add1_node, + m0_rm0_node, + m0_sub_node, + m0_pow_node, + m0_rm1_node, + m0_add0_node, + m0_sqrt_node, + m0_div_node, + m0_mul_node, + m0_add1_node, + m1_rm0_node, + m1_sub_node, + m1_mul0_node, + m1_rm1_node, + m1_add0_node, + m1_sqrt_node, + m1_div_node, + m1_mul1_node, + m1_add1_node, ], "qnn_f32_model", [root_inp], @@ -119,8 +139,8 @@ def test_all_fusions(self): fused_model = onnx.load_model("model.qnn_pp.onnx") - # 3 fused Ops: Gelu, LpNorm, LayerNorm - self.assertEqual(len(fused_model.graph.node), 3) + # 4 fused Ops: Gelu, LpNorm, LayerNorm of two patterns + self.assertEqual(len(fused_model.graph.node), 4) expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} for node in fused_model.graph.node: self.assertIn(node.op_type, expected_op_types) @@ -167,8 +187,8 @@ def test_external_data(self): fused_model = onnx.load_model("model.qnn_pp.onnx", load_external_data=False) - # 3 fused Ops: Gelu, LpNorm, LayerNorm - self.assertEqual(len(fused_model.graph.node), 3) + # 4 fused Ops: Gelu, LpNorm, LayerNorm of two patterns + self.assertEqual(len(fused_model.graph.node), 4) expected_op_types = {"Gelu", "LpNormalization", "LayerNormalization"} for node in fused_model.graph.node: self.assertIn(node.op_type, expected_op_types) diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py new file mode 100644 index 0000000000000..a7a130654407a --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_2bits.py 
@@ -0,0 +1,151 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest + +import numpy as np +import numpy.typing as npt + + +def dequantize_blockwise_2bits(quant_values, scale, zero_point, valid_len): + blob_size = quant_values.shape[0] + block_size = blob_size * 4 + + quant_float = np.zeros((block_size), dtype=scale.dtype) + for b in range(blob_size): + v = quant_values[b] + quant_float[4 * b] = ((v & 0x3) - zero_point) * scale if 4 * b < valid_len else 0.0 + quant_float[4 * b + 1] = (((v >> 2) & 0x3) - zero_point) * scale if 4 * b + 1 < valid_len else 0.0 + quant_float[4 * b + 2] = (((v >> 4) & 0x3) - zero_point) * scale if 4 * b + 2 < valid_len else 0.0 + quant_float[4 * b + 3] = (((v >> 6) & 0x3) - zero_point) * scale if 4 * b + 3 < valid_len else 0.0 + return quant_float + + +def quantize_blockwise_2bits_ref(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) != 2: + raise ValueError("Current int2 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + blob_size = block_size // 4 + k_blocks = (rows + block_size - 1) // block_size + padded_rows = k_blocks * block_size + pad_len = padded_rows - rows + matrix_float_padded = matrix_float + if pad_len > 0: + matrix_float_padded = np.pad(matrix_float, ((0, pad_len), (0, 0)), "constant") + + packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float_padded.dtype) + zero_point = np.full((cols, (k_blocks + 3) // 4), 0xAA, dtype="uint8") + + matrix_float_padded = np.transpose(matrix_float_padded) + for n in range(cols): + for k_id in range(0, rows, block_size): + if is_symmetric: + amax_idx 
= np.argmax(np.abs(matrix_float_padded[n, k_id : k_id + block_size])) + bmax = np.float32(matrix_float_padded[n, k_id + amax_idx]) + scale = bmax / (-2.0) + zp = 2 + else: + vmin = np.min(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmax = np.max(np.float32(matrix_float_padded[n, k_id : k_id + block_size])) + vmin = min(vmin, 0.0) + vmax = max(vmax, 0.0) + scale = (vmax - vmin) / ((1 << 2) - 1) + zero_point_fp = vmin + if scale != 0.0: + zero_point_fp = 0.0 - vmin / scale + zp = min(3, max(0, round(zero_point_fp))) + + reciprocal_scale = 1.0 / scale if scale != 0 else 0.0 + block_idx = k_id // block_size + scales[n, block_idx] = scale + zp_pair = zero_point[n, block_idx // 4] + zp_idx = block_idx % 4 + zp_masks = [0xFC, 0xF3, 0xCF, 0x3F] + zero_point[n, block_idx // 4] = (zp_pair & zp_masks[zp_idx]) | (zp << (zp_idx * 2)) + + blk_int0 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int1 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 1 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int2 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 2 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + blk_int3 = np.clip( + np.round(np.float32(matrix_float_padded[n, k_id + 3 : k_id + block_size : 4] * reciprocal_scale + zp)), + 0, + 3, + ).astype("uint8") + packed[n, block_idx] = np.bitwise_or( + np.bitwise_or(blk_int0, np.left_shift(blk_int1, 2)), + np.bitwise_or(np.left_shift(blk_int2, 4), np.left_shift(blk_int3, 6)), + ) + + return (packed, scales, zero_point) + + +def quantize_blockwise_2bits_target(matrix_float: npt.ArrayLike, block_size: int, is_symmetric: bool): + if len(matrix_float.shape) != 2: + raise ValueError("Current int2 block quantization only supports 2D tensors!") + rows, cols = matrix_float.shape + + k_blocks = (rows + block_size - 1) // 
block_size + packed = np.zeros((cols, k_blocks, block_size // 4), dtype="uint8") + scales = np.zeros((cols, k_blocks), dtype=matrix_float.dtype) + zero_point = np.full((cols, (k_blocks + 3) // 4), 0xAA, dtype="uint8") + from onnxruntime.capi._pybind_state import quantize_matmul_2bits # noqa: PLC0415 + + quantize_matmul_2bits(packed, matrix_float, scales, zero_point, block_size, cols, rows, is_symmetric) + return (packed, scales, zero_point) + + +class TestQuantizeBlockwise2Bits(unittest.TestCase): + def test_quantize_blockwise_2bits(self): + for rows, cols in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + for is_symmetric in [True, False]: + matrix_float = np.random.rand(rows, cols).astype(type) + quant_value_ref, scales_ref, zero_point_ref = quantize_blockwise_2bits_ref( + matrix_float, block_size, is_symmetric + ) + quant_value, scales, zero_point = quantize_blockwise_2bits_target( + matrix_float, block_size, is_symmetric + ) + assert np.allclose(scales_ref, scales) + assert np.allclose(zero_point_ref, zero_point) + for c in range(quant_value_ref.shape[0]): + for k in range(quant_value_ref.shape[1]): + zp_shift = (k % 4) * 2 + assert np.allclose( + dequantize_blockwise_2bits( + quant_value_ref[c, k], + scales_ref[c, k], + (zero_point_ref[c, k // 4] >> zp_shift) & 0x3, + min(block_size, rows - k * block_size), + ), + dequantize_blockwise_2bits( + quant_value[c, k], + scales[c, k], + (zero_point[c, k // 4] >> zp_shift) & 0x3, + min(block_size, rows - k * block_size), + ), + atol=1.2 * abs(scales[c, k]), + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py index e7ea83dd00297..afefc4e616a87 100644 --- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -59,6 
+59,9 @@ def setUp(self): torch.manual_seed(0) pytorch_export_contrib_ops.register() + def tearDown(self): + pytorch_export_contrib_ops.unregister() + def run_test( self, model, @@ -101,6 +104,7 @@ def run_test( input_names=input_names, output_names=output_names, custom_opsets=custom_opsets, + dynamo=False, ) # compute onnxruntime output prediction @@ -143,12 +147,13 @@ def test_gelu_is_fused_by_default(self): f, opset_version=self.opset_version, custom_opsets={"com.microsoft": 1}, + dynamo=False, ) f.seek(0) onnx_model = onnx.load(f) - node = onnx_model.graph.node[0] - self.assertEqual(node.op_type, "Gelu") - self.assertEqual(node.domain, "com.microsoft") + # Default GELU should be mapped to ORT contrib Gelu for performance. + gelu_nodes = [n for n in onnx_model.graph.node if n.op_type == "Gelu" and n.domain == "com.microsoft"] + self.assertEqual(len(gelu_nodes), 1) @parameterized.parameterized.expand([("default_approximate", "none"), ("tanh_approximate", "tanh")]) @unittest.skipIf(_torch_version_lower_than("1.12"), "Gelu's approximate parameter unsupported in PyTorch < 1.12") @@ -230,8 +235,8 @@ def forward(self, input): # IR version 4 style export. ONNXExporterTest_opset9_IRv4 = type( "TestONNXRuntime_opset9_IRv4", - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, keep_initializers_as_inputs=False), + (ONNXExporterTest,), + dict(keep_initializers_as_inputs=False), ) diff --git a/onnxruntime/test/python/transformers/benchmark_qmoe.py b/onnxruntime/test/python/transformers/benchmark_qmoe.py new file mode 100644 index 0000000000000..b96c9cdcf5c3a --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_qmoe.py @@ -0,0 +1,191 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +import os +import sys +import time +import unittest + +import numpy +import torch + +# Add current directory to path to allow importing from test_qmoe_cpu +current_dir = os.path.dirname(os.path.abspath(__file__)) +if current_dir not in sys.path: + sys.path.append(current_dir) + +from test_qmoe_cpu import PhiMoEConfig, PhiMoESparseMoeBlock, TensorProto # noqa: E402 + +# Reduces number of tests to run for faster pipeline checks +pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" + + +@unittest.skipIf(pipeline_mode, "Skip benchmark in CI pipeline.") +class TestQMoESwiGLUBenchmark(unittest.TestCase): + """Benchmark tests for QMoE SwiGLU performance measurement.""" + + def test_qmoe_swiglu_throughput_benchmark(self): + """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" + print("\n=== QMoE SwiGLU Throughput Benchmark ===") + + # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) + configs = [ + ("Medium-4bit", 2880, 2880, 32, 4, 4), + ("Medium-8bit", 2880, 2880, 32, 4, 8), + ] + + batch_size = 1 + sequence_length = 512 + num_runs = 1000 + + results = [] + + for config_name, hidden_size, intermediate_size, num_experts, top_k, quant_bits in configs: + torch.manual_seed(42) + numpy.random.seed(42) + + torch_output = None + ort_output = None + + print(f"\nTesting {config_name}:") + print(f" Hidden: {hidden_size}, Intermediate: {intermediate_size}") + print(f" Experts: {num_experts}, Top-K: {top_k}, Quant: {quant_bits}-bit") + + try: + # Create config and model + config = PhiMoEConfig( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + num_local_experts=num_experts, + num_experts_per_tok=top_k, + ) + + qmoe_swiglu = PhiMoESparseMoeBlock( + config, + batch_size=batch_size, + sequence_length=sequence_length, + quant_bits=quant_bits, + onnx_dtype=TensorProto.FLOAT, + ) + + # Create test input with 
fixed sequence length to match ONNX model + full_hidden_states = torch.randn(batch_size, sequence_length, hidden_size).to(torch.float32) + + # For TTFT simulation, we'll measure single forward pass time + # This represents the time to process one token in autoregressive generation + + # Warm up with full context + for _ in range(3): + _ = qmoe_swiglu.forward(full_hidden_states) + + # Benchmark PyTorch TTFT (Time to First Token) + # Measure time for a single forward pass (represents token generation time) + torch.manual_seed(42) + + start_time = time.time() + for _ in range(num_runs): + torch_output = qmoe_swiglu.forward(full_hidden_states) + end_time = time.time() + torch_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second (throughput) + # For sequence generation, this represents the rate at which we can generate tokens + torch_tokens_per_sec = 1000.0 / torch_ttft_ms # 1 token / (time_ms / 1000) + + print(f" PyTorch TTFT: {torch_ttft_ms:.3f} ms (per token generation time)") + print(f" PyTorch Throughput: {torch_tokens_per_sec:.1f} tokens/sec") + + # Benchmark ONNX Runtime + ort_ttft_ms = 0 + ort_tokens_per_sec = 0 + speedup = 0 + throughput_ratio = 0 + max_diff = 0 + + model_updated = qmoe_swiglu.recreate_onnx_model() + if model_updated and qmoe_swiglu.ort_sess is not None: + # Warm up ORT with full context + for _ in range(3): + _ = qmoe_swiglu.ort_forward(full_hidden_states) + + torch.manual_seed(42) + + # Measure ONNX Runtime TTFT (Time to First Token) + start_time = time.time() + for _ in range(num_runs): + ort_output = qmoe_swiglu.ort_forward(full_hidden_states) + end_time = time.time() + ort_ttft_ms = (end_time - start_time) / num_runs * 1000 + + # Calculate tokens per second for ONNX Runtime + ort_tokens_per_sec = 1000.0 / ort_ttft_ms # 1 token / (time_ms / 1000) + + speedup = torch_ttft_ms / ort_ttft_ms if ort_ttft_ms > 0 else 0 + throughput_ratio = ort_tokens_per_sec / torch_tokens_per_sec if torch_tokens_per_sec > 0 else 
0 + + print(f" ONNX RT TTFT: {ort_ttft_ms:.3f} ms (per token generation time)") + print(f" ONNX RT Throughput: {ort_tokens_per_sec:.1f} tokens/sec") + print(f" TTFT Speedup: {speedup:.2f}x") + print(f" Throughput Gain: {throughput_ratio:.2f}x") + else: + print(" ONNX RT: Not available") + + # Calculate max difference if both outputs available + if torch_output is not None and ort_output is not None: + max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max().item() + print(f" Max diff: {max_diff:.6f}") + + results.append( + { + "config": config_name, + "torch_ttft_ms": torch_ttft_ms, + "torch_tokens_per_sec": torch_tokens_per_sec, + "ort_ttft_ms": ort_ttft_ms, + "ort_tokens_per_sec": ort_tokens_per_sec, + "speedup": speedup, + "throughput_ratio": throughput_ratio, + "max_diff": max_diff, + } + ) + + except Exception as e: + print(f" Error: {e}") + continue + + # Summary + print("\n=== Token Generation Time & Throughput Summary ===") + print( + f"{'Config':<15} {'PT Time':<10} {'PT tok/s':<10} {'ORT Time':<11} {'ORT tok/s':<11} {'Time Gain':<10} {'Throughput':<11} {'Max Diff':<10}" + ) + print("-" * 105) + for result in results: + config = result["config"] + torch_ttft = result["torch_ttft_ms"] + torch_tps = result["torch_tokens_per_sec"] + ort_ttft = result["ort_ttft_ms"] + ort_tps = result["ort_tokens_per_sec"] + speedup = result["speedup"] + throughput_ratio = result["throughput_ratio"] + max_diff = result["max_diff"] + + ort_ttft_str = f"{ort_ttft:.3f}" if ort_ttft > 0 else "N/A" + ort_tps_str = f"{ort_tps:.1f}" if ort_tps > 0 else "N/A" + speedup_str = f"{speedup:.2f}x" if speedup > 0 else "N/A" + throughput_str = f"{throughput_ratio:.2f}x" if throughput_ratio > 0 else "N/A" + + print( + f"{config:<15} {torch_ttft:<10.3f} {torch_tps:<10.1f} {ort_ttft_str:<11} {ort_tps_str:<11} {speedup_str:<10} {throughput_str:<11} {max_diff:<10.6f}" + ) + + print("\nNotes:") + print("- Time: Token generation time in ms (lower is better)") + print("- tok/s: Tokens per 
second throughput (higher is better)") + print("- Time Gain: ORT speedup for latency (higher is better)") + print("- Throughput: ORT throughput improvement (higher is better)") + + +if __name__ == "__main__": + benchmark = TestQMoESwiGLUBenchmark() + benchmark.test_qmoe_swiglu_throughput_benchmark() diff --git a/onnxruntime/test/python/transformers/parity_utilities.py b/onnxruntime/test/python/transformers/parity_utilities.py index fa16f0e67a523..04a1ed06773e7 100644 --- a/onnxruntime/test/python/transformers/parity_utilities.py +++ b/onnxruntime/test/python/transformers/parity_utilities.py @@ -92,6 +92,7 @@ def export_onnx(model, onnx_model_path, float16, hidden_size, device): dynamic_axes=dynamic_axes, opset_version=11, do_constant_folding=True, + dynamo=False, ) print("exported:", onnx_model_path) diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py index 11ae1401ff8ed..a63e2653f2fbc 100644 --- a/onnxruntime/test/python/transformers/test_gelu_fusions.py +++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py @@ -75,17 +75,22 @@ def test_fusions(self, test_case, dynamo): dummy_input = torch.ones(3, dtype=torch.float32) test_name = f"{operator}_{source}" onnx_path = f"{test_name}.onnx" + + # For Torch 2.10+, torch.nn.functional.gelu(approximate="tanh") exports as Gelu node. + # So we force opset_version=18 here. 
torch.onnx.export( model, (dummy_input,), onnx_path, input_names=["input"], output_names=["output"], - dynamo=dynamo, + opset_version=18, + dynamo=False, optimize=True, # Only meaningful when dynamo is True ) optimizer = optimize_model(onnx_path, "bert") # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") + os.remove(onnx_path) # Remove the associated .data file (dynamo) data_path = onnx_path + ".data" diff --git a/onnxruntime/test/python/transformers/test_gpt2_benchmark.py b/onnxruntime/test/python/transformers/test_gpt2_benchmark.py index 2d9bc035fe4fd..40be872250f1a 100644 --- a/onnxruntime/test/python/transformers/test_gpt2_benchmark.py +++ b/onnxruntime/test/python/transformers/test_gpt2_benchmark.py @@ -9,7 +9,6 @@ import os import unittest -import coloredlogs import pytest from parity_utilities import find_transformers_source @@ -50,6 +49,6 @@ def test_gpt2_int8(self): if __name__ == "__main__": - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s") logging.getLogger("transformers").setLevel(logging.ERROR) unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py b/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py index e179d3d087120..bda99abbb7287 100644 --- a/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py +++ b/onnxruntime/test/python/transformers/test_gpt2_to_onnx.py @@ -7,7 +7,6 @@ import logging import unittest -import coloredlogs import pytest from parity_utilities import find_transformers_source @@ -58,6 +57,6 @@ def test_auto_mixed_precision(self): if __name__ == "__main__": - coloredlogs.install(fmt="%(message)s") + logging.basicConfig(format="%(message)s") logging.getLogger("transformers").setLevel(logging.ERROR) unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gqa.py b/onnxruntime/test/python/transformers/test_gqa.py index b3a5c15718ffb..e800c22f92efb 100644 --- a/onnxruntime/test/python/transformers/test_gqa.py +++ 
b/onnxruntime/test/python/transformers/test_gqa.py @@ -22,21 +22,36 @@ from onnx import TensorProto, helper from parameterized import parameterized -from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_build_info +from onnxruntime import ( + InferenceSession, + SessionOptions, + get_available_providers, + get_build_info, +) # Set seed for reproducibility torch.manual_seed(0) random.seed(69) +try: + from rotary_flash import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + # Reduces number of tests to run for faster pipeline checks pipeline_mode = os.getenv("PIPELINE_MODE", "1") == "1" # Number of values per parameter (compared to pipeline mode) param_count = int(os.getenv("PARAM_COUNT", "3")) if not pipeline_mode else 2 -# When quick build is used, flash attention only supports fp16 and head_size=128 -quick_build = ", quick-build=1, " in get_build_info() +# When quick build is used, flash attention only supports head_size=128 +quick_build = ", quick-build=" in get_build_info() + +enable_debug_print = quick_build + +enable_deterministic_check = True +enable_quantized_kv_tests = True # ################################################################################################# # Configuration and Helper Classes # ################################################################################################# @@ -52,6 +67,14 @@ "int4": TensorProto.UINT8, } +TORCH_DTYPE_TO_ONNX_MAP = { + torch.float32: TensorProto.FLOAT, + torch.float16: TensorProto.FLOAT16, + torch.bfloat16: TensorProto.BFLOAT16, + torch.int32: TensorProto.INT32, + torch.int8: TensorProto.INT8, +} + TORCH_DTYPE_MAP = { "float32": torch.float32, "float16": torch.float16, @@ -156,8 +179,8 @@ def forward(self, x, cos, sin, pos, interleaved): # Triton-based implementation for CUDA def rotary_embedding_cuda(*args, **kwargs): - from rotary_flash import apply_rotary_emb # noqa: PLC0415 - + if apply_rotary_emb is None: + raise ImportError("rotary_flash 
not found") return apply_rotary_emb(*args, **kwargs) @@ -262,7 +285,10 @@ def create_gqa_node_and_io( helper.make_tensor_value_info("total_sequence_length", TensorProto.INT32, [1]), ] - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if not config.packed: graph_input.extend( @@ -431,12 +457,19 @@ def gqa_prompt_func( bind_tensor(io_binding, "key", new_k, device, ort_type) bind_tensor(io_binding, "value", new_v, device, ort_type) - # 3. Bind 'past_key', 'past_value' (if share_buffer and passed as k/v) + # 3. Bind 'past_key', 'past_value' if share_buffer: # cache_ort_type corresponds to config.kv_cache_type - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] - bind_tensor(io_binding, "past_key", k, device, cache_ort_type) - bind_tensor(io_binding, "past_value", v, device, cache_ort_type) + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + k_to_bind = k if share_buffer else k[:, :, :0, :] + v_to_bind = v if share_buffer else v[:, :, :0, :] + bind_tensor(io_binding, "past_key", k_to_bind, device, cache_ort_type) + bind_tensor(io_binding, "past_value", v_to_bind, device, cache_ort_type) + + # Scales are bound below in section 6 # 4. 
Bind scalars/1D tensors # seqlens_k is INT32 @@ -487,8 +520,16 @@ def gqa_prompt_func( # Determine dtype for cache tensors cache_dtype = out_dtype cache_ort_type = ort_type - if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP + else: + is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP + + if is_valid_type: + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # We bind output to the input buffer 'k' / 'v' (in-place update) @@ -559,7 +600,10 @@ def gqa_past_func( # 3. Bind 'past_key', 'past_value' # These are required inputs for past_func # cache_ort_type corresponds to config.kv_cache_type - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # If sharing buffer, we bind 'past_key' to the large buffer 'k' @@ -615,14 +659,22 @@ def gqa_past_func( if share_buffer: present_seqlen = config.buffer_sequence_length else: - present_seqlen = total_seq_len + present_seqlen = total_seq_len # For past_func, total seq len is accumulated present_dims = [config.batch_size, config.kv_num_heads, present_seqlen, config.head_size] cache_dtype = out_dtype cache_ort_type = ort_type - if config.kv_cache_type in ONNX_TENSOR_TYPE_MAP: - cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] + if isinstance(config.kv_cache_type, torch.dtype): + is_valid_type = config.kv_cache_type in TORCH_DTYPE_TO_ONNX_MAP + else: + is_valid_type = config.kv_cache_type in ONNX_TENSOR_TYPE_MAP + + if is_valid_type: + if isinstance(config.kv_cache_type, 
torch.dtype): + cache_ort_type = TORCH_DTYPE_TO_ONNX_MAP[config.kv_cache_type] + else: + cache_ort_type = ONNX_TENSOR_TYPE_MAP[config.kv_cache_type] if share_buffer: # In-place update to k/v buffers @@ -754,9 +806,9 @@ def parity_check_gqa_prompt( causal, rtol, atol, + std=0.2, ): torch.manual_seed(0) - std = 0.02 q = ( torch.randn( config.batch_size, @@ -873,24 +925,56 @@ def parity_check_gqa_prompt( # seqlens_k for GQA op is past_seq_len + seq_len - 1 ort_seqlens = cache_seqlens - 1 - out, present_k, present_v = gqa_prompt_func( - q=q_ort, - k=k_ort, - v=v_ort, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens, - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = gqa_prompt_func( + q=q_ort, + k=k_ort, + v=v_ort, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens, + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None else None + else: + if present_k is not None: + try: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + except AssertionError as e: + print(e) + raise e + if present_v is not None: + try: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) + except AssertionError as e: + print(e) + raise e + try: + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + 
except AssertionError as e: + max_diff = (out - first_out).abs().max().item() + print(f"Output mismatch max diff: {max_diff}") + with open("/tmp/gqa_diff_info.txt", "w") as f: + f.write(f"Max Diff: {max_diff}\n") + print(e) + raise e out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() @@ -917,9 +1001,12 @@ def parity_check_gqa_prompt( k_cache_ref_np = k_cache_ref_np[:, :, : config.kv_sequence_length, :] v_cache_ref_np = v_cache_ref_np[:, :, : config.kv_sequence_length, :] + print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) @@ -932,6 +1019,7 @@ def parity_check_gqa_past( causal, rtol, atol, + std=0.2, ): if ort_type == TensorProto.FLOAT16: torch_type = torch.float16 @@ -940,7 +1028,6 @@ def parity_check_gqa_past( else: torch_type = torch.float32 torch.manual_seed(0) - std = 0.02 # --- Test Data Generation --- q = ( torch.randn( @@ -966,10 +1053,10 @@ def parity_check_gqa_past( ) v = torch.randn_like(k) * std - # Random past sequence lengths. This tests paddings in decoding. 
+ # past cache sequence length is in [1, past_kv_sequence_length] cache_seqlens = torch.randint( - 0, - config.past_kv_sequence_length - config.q_sequence_length + 1, + 1, + config.past_kv_sequence_length + 1, (config.batch_size,), device=device, dtype=torch.long, @@ -1056,27 +1143,50 @@ def parity_check_gqa_past( new_k_ort, new_v_ort = None, None ort_seqlens = cache_seqlens + config.q_sequence_length - 1 - out, present_k, present_v = gqa_past_func( - q=q_ort, - k=k, - v=v, - config=config, - new_k=new_k_ort, - new_v=new_v_ort, - cos=cos, - sin=sin, - seqlens_k=ort_seqlens.int(), - position_ids=position_ids, - attention_bias=attention_bias, - head_sink=head_sink, - ep=ep, - device=device, - share_buffer=config.share_buffer, - ort_type=ort_type, - ) + num_runs = 2 if enable_deterministic_check else 1 + for i in range(num_runs): + out, present_k, present_v = gqa_past_func( + q=q_ort, + k=k, + v=v, + config=config, + new_k=new_k_ort, + new_v=new_v_ort, + cos=cos, + sin=sin, + seqlens_k=ort_seqlens.int(), + position_ids=position_ids, + attention_bias=attention_bias, + head_sink=head_sink, + ep=ep, + device=device, + share_buffer=config.share_buffer, + ort_type=ort_type, + ) + if i == 0: + first_out = out.clone() + first_present_k = present_k.clone() if present_k is not None else None + first_present_v = present_v.clone() if present_v is not None else None + else: + torch.testing.assert_close(out, first_out, rtol=0, atol=0, msg="Output mismatch between two runs") + if present_k is not None: + torch.testing.assert_close( + present_k, first_present_k, rtol=0, atol=0, msg="present_k mismatch between two runs" + ) + if present_v is not None: + torch.testing.assert_close( + present_v, first_present_v, rtol=0, atol=0, msg="present_v mismatch between two runs" + ) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out_np = out.to(torch.float32).detach().cpu().numpy() + if enable_debug_print: + print(f"[DEBUG] out_np 
non-zeros: {numpy.count_nonzero(out_np)} / {out_np.size}") + print(f"[DEBUG] out_ref_np non-zeros: {numpy.count_nonzero(out_ref_np)} / {out_ref_np.size}") + + if numpy.count_nonzero(out_ref_np) > 0 and numpy.count_nonzero(out_np) == 0: + raise RuntimeError("Output is all zeros") + # --- Comparison --- # Compare KV cache # Transpose reference back to BNSH to match ORT output @@ -1090,9 +1200,12 @@ def parity_check_gqa_past( k_cache_ref_np = k_cache_ref_np[:, :, :total_len, :] v_cache_ref_np = v_cache_ref_np[:, :, :total_len, :] + print_diff_statistics(torch.tensor(present_k_np - k_cache_ref_np), "present_k") numpy.testing.assert_allclose(present_k_np, k_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(present_v_np - v_cache_ref_np), "present_v") numpy.testing.assert_allclose(present_v_np, v_cache_ref_np, rtol=rtol, atol=atol) + print_diff_statistics(torch.tensor(out_np - out_ref_np), "out") numpy.testing.assert_allclose(out_np, out_ref_np, rtol=rtol, atol=atol) @@ -1240,6 +1353,51 @@ def parity_test_gqa_padding_prompt(): torch.testing.assert_close(out_ort, out_ref, rtol=1e-2, atol=1e-2) +# ################################################################################################# +# Test Utilities +# ################################################################################################# + + +def print_diff_statistics(diff_tensor: torch.Tensor, prefix: str = ""): + """ + Print percentile statistics (75%, 95%, 99%) for a difference tensor. + This helps assess parity quality beyond just max difference. + + Args: + diff_tensor: Tensor containing absolute differences between expected and actual outputs. + prefix: Optional prefix string for the output message. 
+ """ + if not enable_debug_print: + return + + diff_flat = diff_tensor.flatten().float() + if diff_flat.numel() == 0: + print(f"{prefix}Diff statistics: empty tensor") + return + + # Compute percentiles + sorted_diff, _ = torch.sort(diff_flat) + n = sorted_diff.numel() + + p75_idx = min(int(n * 0.75), n - 1) + p90_idx = min(int(n * 0.90), n - 1) + p95_idx = min(int(n * 0.95), n - 1) + p99_idx = min(int(n * 0.99), n - 1) + p999_idx = min(int(n * 0.999), n - 1) + + p75 = sorted_diff[p75_idx].item() + p90 = sorted_diff[p90_idx].item() + p95 = sorted_diff[p95_idx].item() + p99 = sorted_diff[p99_idx].item() + p999 = sorted_diff[p999_idx].item() + max_val = sorted_diff[-1].item() + mean_val = diff_flat.mean().item() + + print( + f"{prefix} Diff stats - mean: {mean_val:.6f}, p75: {p75:.6f}, p90: {p90:.6f}, p95: {p95:.6f}, p99: {p99:.6f}, p999: {p999:.6f}, max: {max_val:.6f}" + ) + + # ################################################################################################# # Test Case Generators # ################################################################################################# @@ -1260,11 +1418,11 @@ def get_softmax_options(allow_head_sink: bool = True): return options -def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): +def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True, allow_local: bool = True): batches = [3, 1, 5] seqs = [(35, 35), (1, 1), (64, 64), (128, 128), (240, 240), (2000, 2000)] heads = [(6, 3), (3, 1), (32, 8)] - h_sizes = [128] if quick_build else [128, 32, 64, 256] + h_sizes = [128] if quick_build else [128, 32, 64, 80, 160, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) @@ -1291,7 +1449,7 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): b = batches[combo_index % len(batches)] sq, skv = seqs[combo_index % len(seqs)] n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, skv)] + lws_opts = [-1, 
random.randint(1, skv)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1327,19 +1485,21 @@ def gqa_cuda_prompt_test_cases(allow_head_sink: bool = True): yield name, config -def gqa_cuda_past_test_cases(allow_head_sink: bool = True): +def gqa_cuda_past_test_cases( + allow_head_sink: bool = True, allow_local: bool = True, enforce_share_buffer: bool = False +): batches = [2, 1, 3] - # s: new sequence length, s2: past sequence length + # s: new sequence length, s2: past sequence length`` seqs = [(1, 1), (1, 128), (1, 2048), (1, 5000)] subsequent_prompt_seqs = [(3, 256)] heads = [(32, 8), (6, 3), (9, 9)] - h_sizes = [128] if quick_build else [128, 40, 64, 256] + h_sizes = [128] if quick_build else [128, 40, 64, 80, 256] smmoth_softmax__head_sink = get_softmax_options(allow_head_sink) rotary_opts = list(get_cuda_rotary_options()) packed_opts = [False, True] # For past test: pipeline tests share_buffer=True only, comprehensive tests both - share_buffer_opts = [True] if pipeline_mode else [True, False] + share_buffer_opts = [True] if pipeline_mode or enforce_share_buffer else [True, False] softcap_opts = [0.0, 50.0] # Use new strategy for both modes: iterate over key code path parameters @@ -1367,7 +1527,7 @@ def gqa_cuda_past_test_cases(allow_head_sink: bool = True): b = 1 # Force batch=1 for subsequent prompt n, n2 = heads[combo_index % len(heads)] - lws_opts = [-1, random.randint(1, s2)] + lws_opts = [-1, random.randint(1, s2)] if allow_local else [-1] lws = lws_opts[combo_index % len(lws_opts)] softcap = softcap_opts[combo_index % len(softcap_opts)] use_smooth_softmax, has_head_sink = smmoth_softmax__head_sink[ @@ -1419,9 +1579,7 @@ def has_cuda_device(min_capability: int = 80): return major * 10 + minor >= min_capability -def has_flash_attention(bf16: bool = False): - if bf16 and quick_build: - return False +def 
has_flash_attention(): return has_cuda_device(80) @@ -1460,7 +1618,7 @@ def test_gqa_past_flash_attention(self, name, config): ) -@unittest.skipIf(not has_flash_attention(bf16=True), "Flash Attention is not available, skipping tests.") +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") class TestFlashGQABF16(unittest.TestCase): @parameterized.expand(gqa_cuda_prompt_test_cases()) def test_gqa_prompt_flash_attention_bf16(self, name, config): @@ -1561,102 +1719,10 @@ def test_gqa_padding_prompt_memory_efficient_attention(self): parity_test_gqa_padding_prompt() -# ################################################################################################# -# Fused Kernel Parity Tests (ORT_DISABLE_FUSED_KV and ORT_DISABLE_FLASH_DECODE) -# ################################################################################################# - - -def fused_kernel_test_cases(): - """Test cases specifically for fused vs unfused kernel parity.""" - configs = [ - # Decoding with RoPE and shared buffer - GQAConfig( - batch_size=2, - q_sequence_length=1, - kv_sequence_length=1, - num_heads=16, - kv_num_heads=4, - head_size=128, - past_kv_sequence_length=128, - buffer_sequence_length=256, - rotary=True, - packed=False, - share_buffer=True, - ), - # Packed QKV decoding with RoPE - GQAConfig( - batch_size=2, - q_sequence_length=1, - kv_sequence_length=1, - num_heads=8, - kv_num_heads=2, - head_size=128, - past_kv_sequence_length=64, - buffer_sequence_length=128, - rotary=True, - packed=True, - share_buffer=True, - ), - # Subsequent prompt with RoPE - GQAConfig( - batch_size=1, - q_sequence_length=4, - kv_sequence_length=4, - num_heads=8, - kv_num_heads=4, - head_size=128, - past_kv_sequence_length=32, - buffer_sequence_length=64, - rotary=True, - packed=False, - share_buffer=True, - ), - ] - for i, config in enumerate(configs): - yield f"fused_config_{i}", config - - @unittest.skipIf(not has_flash_attention(), "Flash Attention is not 
available, skipping tests.") class TestFusedKernelParity(unittest.TestCase): """Tests that verify fused kernels produce the same results as unfused kernels.""" - @parameterized.expand(fused_kernel_test_cases()) - def test_fused_kv_parity(self, name, config): - """Test ORT_DISABLE_FUSED_KV: fused vs unfused KV append kernels.""" - os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" - - # Run with fused kernels (default) - if "ORT_DISABLE_FUSED_KV" in os.environ: - del os.environ["ORT_DISABLE_FUSED_KV"] - - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) - - # Run with unfused kernels - os.environ["ORT_DISABLE_FUSED_KV"] = "1" - - parity_check_gqa_past( - config=config, - ep="CUDAExecutionProvider", - device="cuda", - torch_type=torch.float16, - ort_type=TensorProto.FLOAT16, - causal=True, - rtol=rtol["fp16"], - atol=atol["fp16"], - ) - - # Clean up - del os.environ["ORT_DISABLE_FUSED_KV"] - def test_flash_decode_parity(self): """Test ORT_DISABLE_FLASH_DECODE: fast decode vs standard path.""" os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "0" @@ -1709,5 +1775,50 @@ def test_flash_decode_parity(self): del os.environ["ORT_DISABLE_FLASH_DECODE"] +@unittest.skipIf(not has_flash_attention(), "Flash Attention is not available, skipping tests.") +class TestGQARegressions(unittest.TestCase): + """Specific regression tests for historical bugs.""" + + def test_gqa_rope_separate_qkv_bug(self): + """ + Regression test for separate QKV + RoPE + FlashAttention bug. + The bug caused q_out to be nullptr when unpacking separate QKV with only Q rotation (standard GQA), + leading to unrotated Q being used in Attention. 
+ """ + if "CUDAExecutionProvider" not in get_available_providers(): + self.skipTest("CUDA required") + + # Config that triggers the path: Prompt phase, Separate QKV inputs, RoPE enabled + config = GQAConfig( + batch_size=1, + num_heads=4, + kv_num_heads=4, + head_size=128, + q_sequence_length=16, + kv_sequence_length=16, + past_kv_sequence_length=0, + buffer_sequence_length=16, + rotary=True, + rotary_interleaved=False, + share_buffer=True, + ) + + torch_type = torch.float16 + ort_type = TensorProto.FLOAT16 + device = "cuda" + + parity_check_gqa_prompt( + config=config, + ep="CUDAExecutionProvider", + device=device, + torch_type=torch_type, + ort_type=ort_type, + causal=True, + rtol=1e-3, + atol=1e-3, + std=1.0, + ) + + if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py index 444d86da75ba6..c07eb39e6df75 100644 --- a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py @@ -253,6 +253,7 @@ def export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_head dynamic_axes=dynamic_axes, opset_version=11, do_constant_folding=True, + dynamo=False, ) print("exported:", onnx_model_path) diff --git a/onnxruntime/test/python/transformers/test_phi_vision.py b/onnxruntime/test/python/transformers/test_phi_vision.py index d276366706af9..5a5fa926eb255 100644 --- a/onnxruntime/test/python/transformers/test_phi_vision.py +++ b/onnxruntime/test/python/transformers/test_phi_vision.py @@ -208,6 +208,7 @@ def export(self, model, inputs): "input": {0: "batch", 1: "seq"}, "attention_mask": {0: "batch", 2: "seq", 3: "seq"}, }, + dynamo=False, ) else: torch.onnx.export( @@ -217,6 +218,7 @@ def export(self, model, inputs): export_params=True, opset_version=14, do_constant_folding=True, + dynamo=False, ) def 
tearDown(self): diff --git a/onnxruntime/test/python/transformers/test_qmoe_cpu.py b/onnxruntime/test/python/transformers/test_qmoe_cpu.py index 90ebb148a26a5..8415c7b08b77c 100644 --- a/onnxruntime/test/python/transformers/test_qmoe_cpu.py +++ b/onnxruntime/test/python/transformers/test_qmoe_cpu.py @@ -23,9 +23,11 @@ # normalization on the selected experts. This provides proper weight distribution # while maintaining computational efficiency. # -------------------------------------------------------------------------- +import os import time import unittest from collections import OrderedDict +from contextlib import contextmanager import numpy import torch @@ -76,6 +78,8 @@ class TensorProtoPlaceholder: ort_provider = ["CPUExecutionProvider"] +ORT_USE_MLAS_Q4_GEMM_MOE = "ORT_USE_MLAS_Q4_GEMM_MOE" + torch.manual_seed(42) numpy.random.seed(42) @@ -364,7 +368,7 @@ def create_cpu_moe_onnx_graph( use_swiglu=False, use_quant=False, quant_bits=4, - swiglu_interleaved=False, + swiglu_fusion=0, block_size=0, ): if not has_onnx: @@ -400,10 +404,10 @@ def create_cpu_moe_onnx_graph( "router_probs", # 1 "fc1_experts_weights", # 2 "fc1_scales", # 3 - "", # 4: fc1_bias + "fc1_experts_bias" if fc1_bias is not None else "", # 4 "fc2_experts_weights", # 5 "fc2_scales", # 6 - "", # 7: fc2_bias + "fc2_experts_bias" if fc2_bias is not None else "", # 7 "", # 8: fc3_weights "", # 9: fc3_scales "", # 10: fc3_bias @@ -442,11 +446,10 @@ def create_cpu_moe_onnx_graph( normalize_routing_weights=normalize_routing, activation_type=activation, # Add new attributes with backwards-compatible default values - swiglu_fusion=1 if use_swiglu else 0, # 1 if using SwiGLU activation + swiglu_fusion=swiglu_fusion, swiglu_limit=7.0, activation_alpha=1.702, activation_beta=1.0, - swiglu_interleaved=1 if swiglu_interleaved else 0, # Enable this attribute domain="com.microsoft", ), ] @@ -559,6 +562,30 @@ def create_cpu_moe_onnx_graph( ) ) + if fc1_bias is not None: + fc1_bias_np = 
fc1_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc1_experts_bias", + onnx_dtype, + list(fc1_bias.shape), + fc1_bias_np.flatten().tolist(), + raw=False, + ) + ) + + if fc2_bias is not None: + fc2_bias_np = fc2_bias.detach().cpu().numpy().astype(ort_to_numpy_type_map[onnx_dtype]) + initializers.append( + helper.make_tensor( + "fc2_experts_bias", + onnx_dtype, + list(fc2_bias.shape), + fc2_bias_np.flatten().tolist(), + raw=False, + ) + ) + graph_inputs = [ helper.make_tensor_value_info("input", onnx_dtype, [sequence_length, hidden_size]), ] @@ -626,7 +653,7 @@ def __init__( self.num_experts_per_token = num_experts_per_token -def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): +def swiglu(x: torch.Tensor, alpha: float = 1.702, beta: float = 1.0, limit: float = 7.0): dim = x.shape[-1] x = x.view(-1, dim // 2, 2) x_glu, x_linear = x[..., 0], x[..., 1] @@ -635,8 +662,8 @@ def swiglu(x: torch.Tensor, alpha: float = 1.702, limit: float = 7.0): x_glu = x_glu.clamp(max=limit) x_linear = x_linear.clamp(min=-limit, max=limit) - y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + 1) - return y + y = x_glu * torch.sigmoid(alpha * x_glu) * (x_linear + beta) + return y.view(-1, dim // 2) class MoEBlockSparseTop2MLP(nn.Module): @@ -855,7 +882,7 @@ def ort_forward(self, hidden_states: torch.Tensor, enable_performance_test=False e = time.time() time_ms = (e - s) / repeat * 1000 is_swiglu = hasattr(self, "use_swiglu") and self.use_swiglu - is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + is_interleaved = hasattr(self, "swiglu_fusion") and self.swiglu_fusion == 1 act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" print(f"ORT Performance - {act_type} {self.quant_bits}-bit: {time_ms:.3f} ms/inference") @@ -868,62 +895,80 @@ def recreate_onnx_model(self): """Recreate the ONNX model with the current weights to reflect any changes to the 
quantization code.""" w1_list, w2_list = [], [] + w1_bias_list, w2_bias_list = [], [] w1_scale_list, w2_scale_list = [], [] w1_zp_list, w2_zp_list = [], [] is_4_bit = self.quant_bits == 4 for i in range(self.num_experts): - if self.block_size > 0: - # Use block-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( - self.experts[i].w1.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( - self.experts[i].w2.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + if hasattr(self.experts[i], "w3"): + w1, w3 = self.experts[i].w1.weight, self.experts[i].w3.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias + w3_bias = getattr(self.experts[i].w3, "bias", None) + + # Combine and interleave w1 and w3 for the fused kernel + w1_combined = torch.cat([w1, w3], dim=0) # [2*inter, hidden] + if getattr(self, "swiglu_fusion", 0) == 1: + w1_combined = w1_combined.view(2, -1, self.hidden_dim).transpose(0, 1).reshape(-1, self.hidden_dim) + + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1_combined, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + else: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( + w1_combined, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( + w2, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + + if w1_bias is not None and w3_bias is not None: + b1_combined = torch.cat([w1_bias, w3_bias], dim=0) + if getattr(self, "swiglu_fusion", 0) == 1: + b1_combined = b1_combined.view(2, -1).transpose(0, 1).reshape(-1) + w1_bias_list.append(b1_combined.detach().cpu()) + elif w1_bias is not None: + 
w1_bias_list.append(w1_bias.detach().cpu()) else: - # Use row-wise quantization - w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( - self.experts[i].w1.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( - self.experts[i].w2.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) + # PhiMoESwiGLUMLP already has interleaved weights in w1 + w1 = self.experts[i].w1.weight + w2 = self.experts[i].w2.weight + w1_bias = self.experts[i].w1.bias - if self.use_swiglu: - if self.swiglu_interleaved: - pass + if self.block_size > 0: + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant_blockwise( + w1, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant_blockwise( + w2, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant + ) else: - if self.block_size > 0: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant_blockwise( - self.experts[i].w3.weight, self.block_size, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - else: - w3_scale, pre_qweight3, w3_qdq, w3_zp = quant_dequant( - self.experts[i].w3.weight, is_4_bit, asymmetric=self.use_asymmetric_quant - ) - - gate_weights = pre_qweight1 - value_weights = pre_qweight3 - gate_scales = w1_scale - value_scales = w3_scale - gate_zp = w1_zp - value_zp = w3_zp - - pre_qweight1 = torch.cat([gate_weights, value_weights], dim=0) - w1_scale = torch.cat([gate_scales, value_scales], dim=0) - if w1_zp is not None and w3_zp is not None: - w1_zp = torch.cat([gate_zp, value_zp], dim=0) - - if self.swiglu_interleaved: - self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) + w1_scale, pre_qweight1, w1_qdq, w1_zp = quant_dequant( + w1, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + w2_scale, pre_qweight2, w2_qdq, w2_zp = quant_dequant( + w2, is_4_bit, asymmetric=self.use_asymmetric_quant + ) + if w1_bias is not None: + w1_bias_list.append(w1_bias.detach().cpu()) + 
if self.use_swiglu: + if getattr(self, "swiglu_fusion", 0) == 1: + self.experts[i].w1.weight = nn.Parameter(w1_qdq.contiguous().clone()) else: intermediate_size = self.experts[i].w1.weight.shape[0] gate_dequant = w1_qdq[:intermediate_size].contiguous().clone() value_dequant = w1_qdq[intermediate_size:].contiguous().clone() - self.experts[i].w1.weight.data = gate_dequant - self.experts[i].w3.weight.data = value_dequant + if hasattr(self.experts[i], "w3"): + self.experts[i].w1.weight.data = gate_dequant + self.experts[i].w3.weight.data = value_dequant + else: + self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() else: self.experts[i].w1.weight.data = w1_qdq.contiguous().clone() @@ -931,6 +976,9 @@ def recreate_onnx_model(self): w1_list.append(pre_qweight1) w2_list.append(pre_qweight2) + + if self.experts[i].w2.bias is not None: + w2_bias_list.append(self.experts[i].w2.bias) w1_scale_list.append(w1_scale) w2_scale_list.append(w2_scale) if w1_zp is not None: @@ -963,9 +1011,9 @@ def recreate_onnx_model(self): onnx_dtype=self.onnx_dtype, fc1_experts_weights=self.moe_experts_weight1, fc2_experts_weights=self.moe_experts_weight2, - # Biases are not used in QMoE - fc1_bias=None, - fc2_bias=None, + # Pass collected biases + fc1_bias=torch.stack(w1_bias_list, dim=0) if w1_bias_list else None, + fc2_bias=torch.stack(w2_bias_list, dim=0) if w2_bias_list else None, # Scales are used for dequantization fc1_scales=moe_experts_weight_scale1, fc2_scales=moe_experts_weight_scale2, @@ -975,7 +1023,7 @@ def recreate_onnx_model(self): use_swiglu=self.use_swiglu, use_quant=True, # Always use QMoE quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved if hasattr(self, "swiglu_interleaved") else False, + swiglu_fusion=getattr(self, "swiglu_fusion", 0), block_size=self.block_size, # Add block_size for block-wise quantization ) except Exception: @@ -1020,7 +1068,7 @@ def parity_check(self): max_diff = (torch_output.cpu() - ort_output.cpu()).abs().max() is_swiglu = 
hasattr(self, "use_swiglu") and self.use_swiglu - is_interleaved = hasattr(self, "swiglu_interleaved") and self.swiglu_interleaved + is_interleaved = getattr(self, "swiglu_fusion", 0) == 1 act_type = f"SwiGLU(interleaved={is_interleaved})" if is_swiglu else "SiLU" quant_type = "Asymmetric" if self.use_asymmetric_quant else "Symmetric" block_type = f"Block({self.block_size})" if self.block_size > 0 else "Row" @@ -1047,24 +1095,6 @@ def parity_check(self): ) print("Torch sample:", torch_output.cpu().reshape(-1, hidden_dim)[i, k].item()) print("ORT sample:", ort_output.cpu().reshape(-1, hidden_dim)[i, k].item()) - # Print routing and per-expert contributions for this token from the PyTorch reference - try: - hidden_states_flat = hidden_state.view(-1, hidden_dim) - token_vec = hidden_states_flat[i : i + 1] - gate_logits = self.gate(token_vec) - topk_vals, topk_experts = torch.topk(gate_logits, self.top_k, dim=-1) - topk_soft = F.softmax(topk_vals, dim=1) - print("Gate logits:", gate_logits.detach().cpu().numpy()) - print("Selected experts:", topk_experts.detach().cpu().numpy()) - print("Routing weights:", topk_soft.detach().cpu().numpy()) - # Compute per-expert contributions for selected experts - for idx_e, e in enumerate(topk_experts[0].tolist()): - expert_layer = self.experts[e] - expert_out = expert_layer(token_vec) - contrib = expert_out[0, k].item() * topk_soft[0, idx_e].item() - print(f"Expert {e} contrib at hidden {k}: {contrib}") - except Exception as _: - pass ort_dtype_quant_bits_tolerance_map = { "FP32:0": (5e-3, 1e-3), @@ -1111,6 +1141,43 @@ def small_test_cases(): yield batch_size, sequence_length +def with_mlas_q4_mode(test_cases): + expanded_cases = [] + for case in test_cases: + quant_bits = case[2] + if quant_bits == 4: + expanded_cases.append((*case, None)) + expanded_cases.append((*case, False)) + expanded_cases.append((*case, True)) + else: + expanded_cases.append((*case, None)) + return expanded_cases + + +@contextmanager +def scoped_env_var(name: 
str, value: str): + previous = os.environ.get(name) + os.environ[name] = value + try: + yield + finally: + if previous is None: + os.environ.pop(name, None) + else: + os.environ[name] = previous + + +def run_parity_with_mlas_q4_mode(test_runner, enable_mlas_q4_gemm: bool | None): + if enable_mlas_q4_gemm is None: # No env var + test_runner() + else: + env_value = "1" if enable_mlas_q4_gemm else "0" + mode = "enabled" if enable_mlas_q4_gemm else "disabled" + print(f"DirectQ4 mode ({ORT_USE_MLAS_Q4_GEMM_MOE}) is {mode}") + with scoped_env_var(ORT_USE_MLAS_Q4_GEMM_MOE, env_value): + test_runner() + + class SwigluMoEBlock(SparseMoeBlockORTHelper): def __init__( self, @@ -1128,7 +1195,7 @@ def __init__( self.num_experts = config.num_local_experts self.top_k = config.num_experts_per_token self.use_swiglu = True - self.swiglu_interleaved = True + self.swiglu_fusion = 1 self.block_size = block_size use_quant = self.quant_bits > 0 @@ -1232,7 +1299,7 @@ def __init__( self.top_k = config.num_experts_per_tok self.router_jitter_noise = config.router_jitter_noise self.use_swiglu = True - self.swiglu_interleaved = True + self.swiglu_fusion = 1 self.block_size = block_size use_quant = self.quant_bits > 0 @@ -1314,7 +1381,8 @@ def __init__( use_swiglu=self.use_swiglu, use_quant=use_quant, quant_bits=self.quant_bits, - swiglu_interleaved=self.swiglu_interleaved, + # swiglu_fusion=1 means fused and interleaved, which is the standard for QMoE. 
+ swiglu_fusion=getattr(self, "swiglu_fusion", 0), block_size=self.block_size, ) @@ -1354,8 +1422,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states -disable_cpu_qmoe_tests = False - # Define test cases for different MoE types phi3_test_cases = [ (1, 32, 4), @@ -1373,10 +1439,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestPhiQMoECPU(unittest.TestCase): - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters to ensure different inputs for each test base_seed = 2000 # Different base seed from other tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1411,10 +1476,10 @@ def test_phi3_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_test_cases) - def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(phi3_test_cases)) + def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): base_seed = 3000 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1436,10 +1501,12 @@ def test_phi3_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quan onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - 
@parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1468,10 +1535,12 @@ def test_phi3_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - phi3_moe.parity_check() + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(phi3_blockwise_test_cases) - def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(phi3_blockwise_test_cases)) + def test_phi3_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1489,10 +1558,8 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le block_size=block_size, use_asymmetric_quant=True, ) - phi3_moe.parity_check() - + run_parity_with_mlas_q4_mode(phi3_moe.parity_check, enable_mlas_q4_gemm) -disable_cpu_qmoe_tests = False swiglu_test_cases = [ (1, 32, 4), @@ -1510,10 +1577,9 @@ def test_phi3_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_le ] -@unittest.skipIf(disable_cpu_qmoe_tests, "Skipping qMoE cpu tests") class TestSwigluQMoECPU(unittest.TestCase): - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): # Create unique seed based on test parameters 
to ensure different inputs for each test base_seed = 1000 # Different base seed from regular MoE tests param_hash = hash((batch_size, sequence_length, quant_bits)) @@ -1547,10 +1613,10 @@ def test_swiglu_qmoe_parity_cpu(self, batch_size, sequence_length, quant_bits): self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_test_cases) - def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits): + @parameterized.expand(with_mlas_q4_mode(swiglu_test_cases)) + def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, enable_mlas_q4_gemm): base_seed = 1100 param_hash = hash((batch_size, sequence_length, quant_bits)) unique_seed = base_seed + abs(param_hash) % 1000 @@ -1572,10 +1638,12 @@ def test_swiglu_qmoe_asymmetric_parity_cpu(self, batch_size, sequence_length, qu onnx_dtype=TensorProto.FLOAT, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(42) numpy.random.seed(42) @@ -1603,10 +1671,12 @@ def test_swiglu_qmoe_blockwise_parity_cpu(self, batch_size, sequence_length, qua self.assertFalse(torch.isnan(torch_result).any()) self.assertFalse(torch.isinf(torch_result).any()) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) - @parameterized.expand(swiglu_blockwise_test_cases) - def 
test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_length, quant_bits, block_size): + @parameterized.expand(with_mlas_q4_mode(swiglu_blockwise_test_cases)) + def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu( + self, batch_size, sequence_length, quant_bits, block_size, enable_mlas_q4_gemm + ): torch.manual_seed(43) numpy.random.seed(43) @@ -1624,7 +1694,7 @@ def test_swiglu_qmoe_blockwise_asymmetric_parity_cpu(self, batch_size, sequence_ block_size=block_size, use_asymmetric_quant=True, ) - swiglu_moe.parity_check() + run_parity_with_mlas_q4_mode(swiglu_moe.parity_check, enable_mlas_q4_gemm) @unittest.skipIf(True, "Skipping QMoE CPU benchmark tests") @@ -1633,9 +1703,6 @@ class TestQMoESwiGLUBenchmark(unittest.TestCase): def test_qmoe_swiglu_throughput_benchmark(self): """Comprehensive throughput benchmark for QMoE SwiGLU across different configurations.""" - if disable_cpu_qmoe_tests: - self.skipTest("QMoE CPU tests disabled") - print("\n=== QMoE SwiGLU Throughput Benchmark ===") # Test configurations: (name, hidden_size, intermediate_size, num_experts, top_k, quant_bits) diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index e3ca8e6b6ac9c..e90a14f8d7d61 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -471,8 +471,9 @@ def export(self, model, inputs, input_names, output_names, dynamic_axes): input_names=input_names, output_names=output_names, dynamic_axes=dynamic_axes, - opset_version=17, + opset_version=18, do_constant_folding=True, + dynamo=False, verbose=False, ) @@ -530,9 +531,7 @@ def test_hf_whisper_encoder_self_attention(self, precision, ep): use_gpu=True, only_onnxruntime=False, ) - name = f"hf_{precision}_encoder_self_attention.onnx" - # optimized_model.save_model_to_file(name) # Uncomment for debugging purposes - self.verify_fusion(optimized_model, name) + 
self.verify_fusion(optimized_model, f"hf_{precision}_encoder_self_attention.onnx") @parameterized.expand( [ diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 38ef0daefe51a..e472cbcee12d6 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -21,6 +22,7 @@ #include "core/common/common.h" #include "core/common/narrow.h" #include "core/graph/constants.h" +#include "core/framework/plugin_ep_stream.h" #include "core/session/onnxruntime_c_api.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_lite_custom_op.h" @@ -478,6 +480,94 @@ TEST(CApiTest, dim_param) { ASSERT_EQ(strcmp(dim_param, ""), 0); } +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a dense OrtValue tensor. +TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_DenseTensor) { + Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); + + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(), + x_shape.data(), x_shape.size()); + Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); +} + +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a scalar OrtValue tensor. 
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_Scalar) { + Ort::MemoryInfo info_cpu = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemTypeDefault); + + std::vector x_shape = {}; // Scalar (no shape) + std::array x_values = {1.0f}; + Ort::Value x_value = Ort::Value::CreateTensor(info_cpu, x_values.data(), x_values.size(), + x_shape.data(), x_shape.size()); + Ort::TensorTypeAndShapeInfo type_shape_info = x_value.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + x_value.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); + ASSERT_EQ(shape.shape, nullptr); + ASSERT_EQ(shape.shape_len, 0); +} + +#if !defined(DISABLE_SPARSE_TENSORS) +// Tests calling OrtApi::GetTensorElementTypeAndShapeDataReference for a sparse OrtValue tensor. 
+TEST(CApiTest, Value_GetTensorElementTypeAndShapeDataReference_SparseTensor) { + std::vector common_shape{9, 9}; + std::vector A_values{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, + 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, + 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, + 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, + 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, + 50.0, 51.0, 52.0, 53.0}; + + // 2 - D index + std::vector indices_shape{gsl::narrow(A_values.size()), 2}; + std::vector A_indices{0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1, + 1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1, + 2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3, + 5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5, + 4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5, + 6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3, + 6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7, + 4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4, + 8, 5}; + + Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault); + Ort::Value::Shape ort_dense_shape{common_shape.data(), common_shape.size()}; + Ort::Value::Shape ort_values_shape{&indices_shape[0], 1U}; + auto value_sparse = Ort::Value::CreateSparseTensor(info, A_values.data(), ort_dense_shape, ort_values_shape); + value_sparse.UseCooIndices(A_indices.data(), A_indices.size()); + + Ort::TensorTypeAndShapeInfo type_shape_info = value_sparse.GetTensorTypeAndShapeInfo(); + + ONNXTensorElementDataType elem_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + Ort::Value::Shape shape{}; + value_sparse.GetTensorElementTypeAndShapeDataReference(elem_type, shape); + + ASSERT_EQ(elem_type, type_shape_info.GetElementType()); + + std::vector expected_shape = type_shape_info.GetShape(); + gsl::span actual_shape(shape.shape, shape.shape_len); + ASSERT_EQ(actual_shape, gsl::span(expected_shape)); +} +#endif // !defined(DISABLE_SPARSE_TENSORS) + static std::pair LoadAndGetInputShapePresent(const ORTCHAR_T* const model_url) { Ort::Session session(*ort_env, model_url, Ort::SessionOptions{}); 
const auto input_num = session.GetInputCount(); @@ -4827,3 +4917,204 @@ TEST(CApiTest, ModelWithExternalDataOutsideModelDirectoryShouldFailToLoad) { exception_message.find("model") != std::string::npos) << "Exception message should indicate external data or security issue. Got: " << exception_message; } + +TEST(CApiTest, InMemoryModel_ExternalDataOutsideWorkingDirectory_FailToLoad) { + // Attempt to create an ORT session with the malicious model (loaded from bytes). + // This should fail due to the use of an external file path that is not under current working directory. + // i.e. ../../../../etc/passwd + constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_arbitrary_external_file.onnx"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); + Ort::SessionOptions session_options; + + // Load model contents into array + std::ifstream model_file_stream(model_path, std::ios::in | std::ios::binary); + ASSERT_TRUE(model_file_stream.good()); + model_file_stream.seekg(0, std::ios::end); + const auto file_contents_size = onnxruntime::narrow(model_file_stream.tellg()); + model_file_stream.seekg(0, std::ios::beg); + std::vector file_contents(file_contents_size, 0); + model_file_stream.read(&file_contents[0], file_contents_size); + model_file_stream.close(); + + bool exception_thrown = false; + std::string exception_message; + + try { + // This should throw an exception due to malicious external data + Ort::Session session(env, file_contents.data(), file_contents_size, session_options); + } catch (const Ort::Exception& e) { + exception_thrown = true; + exception_message = e.what(); + } catch (const std::exception& e) { + exception_thrown = true; + exception_message = e.what(); + } + + // Verify that loading the model failed + EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data path"; + + // Verify that the exception message indicates security or external data issues + EXPECT_TRUE(exception_message.find("External data path") != 
std::string::npos && + exception_message.find("escapes working directory") != std::string::npos) + << "Exception message should indicate external data or security issue. Got: " << exception_message; +} + +TEST(CApiTest, InMemoryModel_SessionConfigExternalFileFolder_ExternalDataOutsideModelDirectory_FailToLoad) { + // Attempt to create an ORT session with the malicious model (loaded from bytes). + // A valid external file folder path is explicitly set via session options. + // However, this should still fail due to the use of an external file path that escapes the set directory. + // i.e. ../../../../etc/passwd + constexpr const ORTCHAR_T* model_path = TSTR("testdata/test_arbitrary_external_file.onnx"); + + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "test"); + Ort::SessionOptions session_options; + session_options.AddConfigEntry(kOrtSessionOptionsModelExternalInitializersFileFolderPath, "testdata"); + + // Load model contents into array + std::ifstream model_file_stream(model_path, std::ios::in | std::ios::binary); + ASSERT_TRUE(model_file_stream.good()); + model_file_stream.seekg(0, std::ios::end); + const auto file_contents_size = onnxruntime::narrow(model_file_stream.tellg()); + model_file_stream.seekg(0, std::ios::beg); + std::vector file_contents(file_contents_size, 0); + model_file_stream.read(&file_contents[0], file_contents_size); + model_file_stream.close(); + + bool exception_thrown = false; + std::string exception_message; + + try { + // This should throw an exception due to malicious external data + Ort::Session session(env, file_contents.data(), file_contents_size, session_options); + } catch (const Ort::Exception& e) { + exception_thrown = true; + exception_message = e.what(); + } catch (const std::exception& e) { + exception_thrown = true; + exception_message = e.what(); + } + + // Verify that loading the model failed + EXPECT_TRUE(exception_thrown) << "Expected model loading to fail due to malicious external data path"; + + // Verify that the exception 
message indicates security or external data issues + EXPECT_TRUE(exception_message.find("External data path") != std::string::npos && + exception_message.find("escapes both model directory") != std::string::npos && + exception_message.find("and real model directory") != std::string::npos) + << "Exception message should indicate external data or security issue. Got: " << exception_message; +} + +#ifdef ORT_ENABLE_STREAM +#if USE_CUDA + +namespace { +struct TestCudaStreamOverrideUsed : onnxruntime::Stream { + TestCudaStreamOverrideUsed(onnxruntime::Stream* stream) + : onnxruntime::Stream(stream->GetHandle(), stream->GetDevice()), real_stream(stream) {} + + std::unique_ptr CreateNotification(size_t num_consumers) override { + return real_stream->CreateNotification(num_consumers); + } + + TestCudaStreamOverrideUsed(const TestCudaStreamOverrideUsed&) = delete; + TestCudaStreamOverrideUsed& operator=(const TestCudaStreamOverrideUsed&) = delete; + + void Flush() override { + flush_count++; + real_stream->Flush(); + } + + onnxruntime::Status CleanUpOnRunEnd() override { return real_stream->CleanUpOnRunEnd(); } + + onnxruntime::Stream* real_stream; + size_t flush_count{0}; +}; +} // namespace + +TEST(CApiTest, TestSyncStreamOverride) { +#ifdef _WIN32 + auto cuda_lib = ORT_TSTR("onnxruntime_providers_cuda.dll"); +#else + auto cuda_lib = ORT_TSTR("onnxruntime_providers_cuda.so"); +#endif + + if (!std::filesystem::exists(cuda_lib)) { + GTEST_SKIP() << "CUDA library was not found"; + } + + constexpr const char* cuda_ep_name = "ORT Cuda"; + ort_env->RegisterExecutionProviderLibrary(cuda_ep_name, cuda_lib); + auto ep_devices = ort_env->GetEpDevices(); + + Ort::ConstEpDevice cuda_device; + for (const auto& device : ep_devices) { + if (device.Device().Type() == OrtHardwareDeviceType_GPU && + device.Device().VendorId() == 0x10DE) { // NVIDIA vendor ID + cuda_device = device; + break; + } + } + + if (!cuda_device) { + GTEST_SKIP() << "No CUDA device found, skipping test."; + } + + // 
Create session with CUDA EP using C++ public API in Ort:: namespace + { + // Create a stream on CUDA Device + const auto sync_stream = cuda_device.CreateSyncStream(); + TestCudaStreamOverrideUsed cuda_override_stream(sync_stream); + + Ort::SessionOptions session_options; + session_options.AppendExecutionProvider_V2(*ort_env, {cuda_device}, Ort::KeyValuePairs{}); + + Ort::Session session(*ort_env, MODEL_URI, session_options); + + constexpr const std::array input_names = {"X"}; + constexpr const std::array output_names = {"Y"}; + constexpr const std::array input_shape = {3LL, 2LL}; + float x_value[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + auto input_value = Ort::Value::CreateTensor( + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU), + x_value, std::size(x_value), input_shape.data(), input_shape.size()); + Ort::Value ort_inputs[] = {std::move(input_value)}; + + Ort::RunOptions run_options; + run_options.SetSyncStream(reinterpret_cast(&cuda_override_stream)); + + auto output_values = session.Run(run_options, + input_names.data(), ort_inputs, std::size(ort_inputs), + output_names.data(), output_names.size()); + + ASSERT_GT(cuda_override_stream.flush_count, 0U) + << "Expected the custom CUDA stream override to be used during session run."; + } + + ort_env->UnregisterExecutionProviderLibrary(cuda_ep_name); +} +#endif +#endif + +#if !defined(ORT_MINIMAL_BUILD) +TEST(CApiTest, GetEpGraphAssignmentInfo_NotEnabledError) { + // Test that calling OrtApi::Session_GetEpGraphAssignmentInfo() without enabling the appropriate + // session configuration option returns an error. 
+ + Ort::SessionOptions options; + // Do not set: + // options.AddConfigEntry(kOrtSessionOptionsRecordEpGraphAssignmentInfo, "1"); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/mul_1.onnx"), options); + try { + session.GetEpGraphAssignmentInfo(); + ASSERT_TRUE(false) << "Call to Session_GetEpGraphAssignmentInfo should have failed"; + } catch (const Ort::Exception& ex) { + ASSERT_EQ(ex.GetOrtErrorCode(), ORT_FAIL); + + std::ostringstream oss; + oss << "Session configuration entry '" << kOrtSessionOptionsRecordEpGraphAssignmentInfo << "' must be set to \"1\""; + ASSERT_THAT(ex.what(), testing::HasSubstr(oss.str())); + } +} +#endif diff --git a/onnxruntime/test/testdata/custom_mul.onnx b/onnxruntime/test/testdata/custom_mul.onnx new file mode 100644 index 0000000000000..87bb64764a669 Binary files /dev/null and b/onnxruntime/test/testdata/custom_mul.onnx differ diff --git a/onnxruntime/test/testdata/custom_mul.py b/onnxruntime/test/testdata/custom_mul.py new file mode 100644 index 0000000000000..2639648561fe1 --- /dev/null +++ b/onnxruntime/test/testdata/custom_mul.py @@ -0,0 +1,45 @@ +import onnx + + +def create_custom_mul_model(): + # === Inputs === + x = onnx.helper.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [3, 2]) + w = onnx.helper.make_tensor_value_info("W", onnx.TensorProto.FLOAT, [3, 2]) + + # === Output === + y = onnx.helper.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [3, 2]) + + # === Custom Node: Custom_Mul === + # Replace "Mul" with your custom op name and domain + custom_node = onnx.helper.make_node( + op_type="Custom_Mul", # <-- custom op name + inputs=["X", "W"], + outputs=["Y"], + domain="test", # <-- custom domain + ) + + # === Graph === + graph = onnx.helper.make_graph( + nodes=[custom_node], + name="CustomMulGraph", + inputs=[x, w], + outputs=[y], + ) + + # === Model (opset version 13 or later is fine) === + model = onnx.helper.make_model( + graph, + opset_imports=[ + onnx.helper.make_opsetid("", 13), # standard ONNX domain + 
onnx.helper.make_opsetid("test", 1), + ], # your custom domain + producer_name="custom_mul_builder", + ) + + return model + + +# ===== Save the Model ===== +model = create_custom_mul_model() +onnx.save(model, "custom_mul.onnx") +print("Saved custom_mul.onnx") diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 60d28c491dfd1..3565208833266 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -785,7 +785,10 @@ "^test_reduce_max_empty_set_cpu", // DNNL result in "(shapes (2, 1, 4), (1, 0, 1) mismatch)". this is the same for test_reduce_min_empty_set which is already in the list "^test_reduce_min_empty_set_cpu", "^test_resize_upsample_sizes_nearest_not_smaller_cpu", - "^test_clip_min_greater_than_max_cpu" + "^test_clip_min_greater_than_max_cpu", + // Fail since v1.20.1 (new matmul 1D tests) + "^test_matmul_1d_1d_cpu", + "^test_matmul_4d_1d_cpu" ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/onnxruntime/test/unittest_util/base_tester.cc b/onnxruntime/test/unittest_util/base_tester.cc index d8bfd425f1f1a..2e0459103a7c9 100644 --- a/onnxruntime/test/unittest_util/base_tester.cc +++ b/onnxruntime/test/unittest_util/base_tester.cc @@ -424,7 +424,7 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, bool SetEpsForAllNodes(Graph& graph, const std::vector>& execution_providers, const std::vector>* custom_registries, - const std::function& ep_uses_kernel_registry_fn) { + const std::function& ep_only_uses_kernel_registry_fn) { const OpSchemaKernelTypeStrResolver kernel_type_str_resolver{}; const KernelRegistry::TypeConstraintMap type_constraint_map{}; @@ -440,7 +440,7 @@ bool SetEpsForAllNodes(Graph& graph,
node.SetExecutionProviderType(provider_type); - if (!ep_uses_kernel_registry_fn(*ep)) { + if (!ep_only_uses_kernel_registry_fn(*ep)) { found = true; break; } @@ -659,7 +659,12 @@ void BaseTester::RunWithConfig(size_t* number_of_pre_packed_weights_counter, #endif kDnnlExecutionProvider, kTensorrtExecutionProvider, +#ifdef USE_NV + // Only include NV TRT RTX EP when ORT is built with the provider-bridge + // version of the EP (i.e., USE_NV is defined). This allows use of the plugin EP version of the EP + // when ORT is not built with any provider-bridge EPs. kNvTensorRTRTXExecutionProvider, +#endif kOpenVINOExecutionProvider, kDmlExecutionProvider, kAclExecutionProvider, @@ -830,12 +835,15 @@ void BaseTester::ExecuteModelForEps( ASSERT_TRUE(!execution_providers.empty()) << "Empty execution providers vector."; if (try_assign_ep_for_nodes) { - auto ep_uses_kernel_registry = [](const IExecutionProvider& ep) { + auto ep_only_uses_kernel_registry = [](const IExecutionProvider& ep) { const auto& provider_type = ep.Type(); - constexpr std::array kEpsThatDoNotUseKernelRegistry{ + constexpr std::array kEpsThatCompileNodes{ kOpenVINOExecutionProvider, - kTensorrtExecutionProvider, + kTensorrtExecutionProvider, // uses kernel registry for Memcpy* nodes only +#ifdef USE_NV + kNvTensorRTRTXExecutionProvider, // uses kernel registry for Memcpy* nodes only +#endif kNnapiExecutionProvider, kVSINPUExecutionProvider, kCoreMLExecutionProvider, @@ -844,24 +852,33 @@ void BaseTester::ExecuteModelForEps( kSnpeExecutionProvider, }; - // check list of known EPs that do not use a kernel registry - if (const auto ep_it = std::find(kEpsThatDoNotUseKernelRegistry.begin(), kEpsThatDoNotUseKernelRegistry.end(), + // check list of known EPs that compile nodes + if (const auto ep_it = std::find(kEpsThatCompileNodes.begin(), kEpsThatCompileNodes.end(), provider_type); - ep_it != kEpsThatDoNotUseKernelRegistry.end()) { + ep_it != kEpsThatCompileNodes.end()) { return false; } - // assume that a dynamic
plugin EP which does not return a kernel registry does not use one - if (provider_type == dynamic_plugin_ep_infra::GetEpName() && - ep.GetKernelRegistry() == nullptr) { - return false; + const OrtEp* ort_ep = ep.GetOrtEp(); + + if (ort_ep != nullptr) { // This is a plugin EP + + if (ep.GetKernelRegistry() == nullptr) { + // assume that a dynamic plugin EP which does not return a kernel registry does not use one + return false; + } + + if (ort_ep->Compile != nullptr) { + // assume that a plugin EP that compiles nodes does not use a kernel registry for all nodes + return false; + } } // otherwise, assume that the EP uses a kernel registry return true; }; - if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_uses_kernel_registry)) { + if (!SetEpsForAllNodes(model.MainGraph(), execution_providers, custom_registries, ep_only_uses_kernel_registry)) { std::string providers; for (const auto& ep : execution_providers) { providers.append(ep->Type() + " "); diff --git a/onnxruntime/test/unittest_util/base_tester.h b/onnxruntime/test/unittest_util/base_tester.h index 58b67a0d67d3c..79a74ef1651c5 100644 --- a/onnxruntime/test/unittest_util/base_tester.h +++ b/onnxruntime/test/unittest_util/base_tester.h @@ -700,6 +700,10 @@ class BaseTester { const int64_t expected_values_count = T::CalcNumInt4Pairs(shape.Size()); ORT_ENFORCE(expected_values_count == values_count, values_count, " input values doesn't match tensor size of ", expected_values_count); + } else if constexpr (std::is_same_v || std::is_same_v) { + const int64_t expected_values_count = T::CalcNumInt2Quads(shape.Size()); + ORT_ENFORCE(expected_values_count == values_count, values_count, + " input values doesn't match tensor size of ", expected_values_count); } #if !defined(DISABLE_FLOAT4_TYPES) else if constexpr (std::is_same_v) { diff --git a/onnxruntime/test/unittest_util/checkers.cc b/onnxruntime/test/unittest_util/checkers.cc index 7b2a5a4a4ff2f..88a6241bf7ee3 100644 --- 
a/onnxruntime/test/unittest_util/checkers.cc +++ b/onnxruntime/test/unittest_util/checkers.cc @@ -9,6 +9,7 @@ #include "core/graph/constants.h" #include "core/framework/TensorSeq.h" #include "core/framework/int4.h" +#include "core/framework/int2.h" #include "core/framework/float4.h" #include "test/unittest_util/framework_test_utils.h" #include "test/unittest_util/conversion.h" @@ -259,6 +260,44 @@ struct TensorCheck { } }; +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + const Int2x4* cur_expected; + const Int2x4* cur_actual; + const auto size = narrow(actual.Shape().Size()); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < size; ++i) { + size_t r = i >> 2; + size_t c = i & 0x3; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + +template <> +struct TensorCheck { + void operator()(const Tensor& expected, const Tensor& actual, const ValidateOutputParams& params, + const std::string& /*provider_type*/) const { + ORT_UNUSED_PARAMETER(params); + const UInt2x4* cur_expected; + const UInt2x4* cur_actual; + const auto size = narrow(actual.Shape().Size()); + cur_expected = expected.Data(); + cur_actual = actual.Data(); + + for (size_t i = 0; i < size; ++i) { + size_t r = i >> 2; + size_t c = i & 0x3; + EXPECT_EQ(cur_expected[r].GetElem(c), cur_actual[r].GetElem(c)) << "i:" << i; + } + } +}; + template <> struct TensorCheck { void operator()(const Tensor& expected, @@ -536,7 +575,7 @@ void Check(std::string_view name, const OrtValue& expected, const Tensor utils::MLTypeCallDispatcher IsResultExactlyMatch(const Tenso return std::make_pair(COMPARE_RESULT::SUCCESS, ""); } +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = 
static_cast(expected_value.Shape().Size()); + const Int2x4* expected_output = expected_value.Data(); + const Int2x4* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 2; + size_t c = di & 0x3; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << static_cast(expected_output[r].GetElem(c)) + << ", got " << static_cast(real_output[r].GetElem(c)); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + +template <> +std::pair IsResultExactlyMatch(const Tensor& outvalue, + const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const UInt2x4* expected_output = expected_value.Data(); + const UInt2x4* real_output = outvalue.Data(); + for (size_t di = 0; di != size1; ++di) { + size_t r = di >> 2; + size_t c = di & 0x3; + + if (expected_output[r].GetElem(c) != real_output[r].GetElem(c)) { + std::ostringstream oss; + oss << "expected " << static_cast(expected_output[r].GetElem(c)) + << ", got " << static_cast(real_output[r].GetElem(c)); + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} + std::pair CompareFloat16Result(const Tensor& outvalue, const Tensor& expected_value, double per_sample_tolerance, double relative_per_sample_tolerance, @@ -356,6 +396,10 @@ std::pair CompareTwoTensors(const Tensor& outvalue, return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return IsResultExactlyMatch(outvalue, expected_tensor); } else if (outvalue.IsDataType()) { return CompareFloat16Result(outvalue, expected_tensor, per_sample_tolerance, 
relative_per_sample_tolerance, post_processing); diff --git a/requirements.txt b/requirements.txt index 2fd9362c949dd..ff8cc04d6f219 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,3 @@ -coloredlogs flatbuffers numpy >= 1.21.6 packaging diff --git a/samples/cxx/CMakeLists.txt b/samples/cxx/CMakeLists.txt new file mode 100644 index 0000000000000..875e37c64eda2 --- /dev/null +++ b/samples/cxx/CMakeLists.txt @@ -0,0 +1,50 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +cmake_minimum_required(VERSION 3.28) + +project(onnxruntime_sample CXX) + +set(CMAKE_CXX_STANDARD 20) + +foreach(VAR IN ITEMS ORT_LIBRARY_DIR ORT_HEADER_DIR) + if (NOT DEFINED ${VAR}) + message(FATAL_ERROR "Required variable ${VAR} is not set. " + "Set ORT_LIBRARY_DIR to the ONNX Runtime lib directory and " + "ORT_HEADER_DIR to the ONNX Runtime include directory.") + endif() +endforeach() + +# Resolve to absolute paths +get_filename_component(ORT_LIBRARY_DIR "${ORT_LIBRARY_DIR}" ABSOLUTE) +get_filename_component(ORT_HEADER_DIR "${ORT_HEADER_DIR}" ABSOLUTE) + +# +# onnxruntime_sample_program +# +block() +add_executable(onnxruntime_sample_program) + +target_sources(onnxruntime_sample_program PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/main.cc) + +target_include_directories(onnxruntime_sample_program PRIVATE ${ORT_HEADER_DIR}) + +target_link_directories(onnxruntime_sample_program PRIVATE ${ORT_LIBRARY_DIR}) +target_link_libraries(onnxruntime_sample_program PRIVATE onnxruntime) + +# Copy ONNX Runtime shared libraries next to the executable. +# Collect shared library files from the ORT library directory based on platform. 
+if (WIN32) + file(GLOB ORT_SHARED_LIBS "${ORT_LIBRARY_DIR}/*.dll") +elseif (APPLE) + file(GLOB ORT_SHARED_LIBS "${ORT_LIBRARY_DIR}/*.dylib") +else() + file(GLOB ORT_SHARED_LIBS "${ORT_LIBRARY_DIR}/*.so" "${ORT_LIBRARY_DIR}/*.so.*") +endif() + +add_custom_command(TARGET onnxruntime_sample_program POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different + ${ORT_SHARED_LIBS} + $ +) +endblock() diff --git a/samples/cxx/README.md b/samples/cxx/README.md new file mode 100644 index 0000000000000..1904c082cef7a --- /dev/null +++ b/samples/cxx/README.md @@ -0,0 +1,92 @@ +# ONNX Runtime C++ Sample + +A minimal C++ program demonstrating basic ONNX Runtime inference. It loads an ONNX model that adds two float tensors (`C = A + B`), runs inference, and verifies the result. + +## Prerequisites + +- CMake 3.28 or later +- C++20 compatible compiler (e.g., Visual Studio 2022) +- An ONNX Runtime release package (download from [GitHub releases](https://github.com/microsoft/onnxruntime/releases)) +- For model generation: + - Python with the `onnx` package + +## Directory Structure + +``` +samples/cxx/ +├── CMakeLists.txt # Build configuration +├── main.cc # Sample program source +├── add_model.onnx # ONNX model (C = A + B) +├── generate_model.py # Script to generate the ONNX model +└── README.md # This file +``` + +## Steps + +### 1. Extract the ONNX Runtime package + +Download and extract an ONNX Runtime release archive. For example: + +``` +tar -xf onnxruntime-win-x64-1.25.0.zip +``` + +This creates a directory like `onnxruntime-win-x64-1.25.0/` containing `include/` and `lib/` subdirectories. + +### 2. [Optional] Generate the ONNX model + +``` +cd samples/cxx +pip install onnx +python generate_model.py +``` + +This creates `add_model.onnx` in the current directory. + +### 3. Configure and build + +From the `samples/cxx` directory: + +**Windows:** +``` +cmake -S . 
-B build ^ + -DORT_HEADER_DIR:PATH=path\to\onnxruntime-win-x64-1.25.0\include ^ + -DORT_LIBRARY_DIR:PATH=path\to\onnxruntime-win-x64-1.25.0\lib +cmake --build build --config Release +``` + +**Linux / macOS:** +``` +cmake -S . -B build \ + -DORT_HEADER_DIR:PATH=path/to/onnxruntime-linux-x64-1.25.0/include \ + -DORT_LIBRARY_DIR:PATH=path/to/onnxruntime-linux-x64-1.25.0/lib +cmake --build build --config Release +``` + +Adjust the paths to match your extracted package name and location. + +The build automatically copies the ONNX Runtime shared libraries next to the executable. + +#### CMake Variables + +| Variable | Description | +|---|---| +| `ORT_HEADER_DIR` | Path to the ONNX Runtime `include` directory | +| `ORT_LIBRARY_DIR` | Path to the ONNX Runtime `lib` directory | + +### 4. Run + +**Windows:** +``` +build\Release\onnxruntime_sample_program.exe +``` + +**Linux / macOS:** +``` +./build/onnxruntime_sample_program +``` + +You can also pass a model path as an argument: +``` +onnxruntime_sample_program path/to/add_model.onnx +``` diff --git a/samples/cxx/add_model.onnx b/samples/cxx/add_model.onnx new file mode 100644 index 0000000000000..36308c1372a22 Binary files /dev/null and b/samples/cxx/add_model.onnx differ diff --git a/samples/cxx/generate_model.py b/samples/cxx/generate_model.py new file mode 100644 index 0000000000000..9ac70ab29deb4 --- /dev/null +++ b/samples/cxx/generate_model.py @@ -0,0 +1,42 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""Generate a simple ONNX model that computes C = A + B. 
+ +Inputs: + A : float tensor of shape [1, 3] + B : float tensor of shape [1, 3] + +Output: + C : float tensor of shape [1, 3] + +Usage: + pip install onnx + python generate_model.py +""" + +from onnx import TensorProto, helper, save_model +from onnx.checker import check_model + + +def main(): + # Define inputs and output + a = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 3]) + b = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 3]) + c = helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, 3]) + + # Create the Add node + add_node = helper.make_node("Add", inputs=["A", "B"], outputs=["C"]) + + # Build the graph and model + graph = helper.make_graph([add_node], "add_graph", [a, b], [c]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + # Validate and save + check_model(model) + save_model(model, "add_model.onnx") + print("Saved add_model.onnx") + + +if __name__ == "__main__": + main() diff --git a/samples/cxx/main.cc b/samples/cxx/main.cc new file mode 100644 index 0000000000000..4e31e033ab8c7 --- /dev/null +++ b/samples/cxx/main.cc @@ -0,0 +1,170 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Sample program demonstrating basic ONNX Runtime C++ API usage. +// Loads a simple ONNX model (C = A + B), runs inference, and prints the result. +// +// Generate the model first: python generate_model.py + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "onnxruntime_cxx_api.h" + +// Throw std::runtime_error if `condition` is false. Includes file and line info. +#define THROW_IF_NOT(condition) \ + do { \ + if (!(condition)) { \ + throw std::runtime_error(std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ": " + \ + "check failed: " #condition); \ + } \ + } while (0) + +int main(int argc, char* argv[]) { + try { + // ----------------------------------------------------------------------- + // 1. 
Initialize the ONNX Runtime environment + // ----------------------------------------------------------------------- + Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "onnxruntime_sample"); + std::cout << "ONNX Runtime version: " << Ort::GetVersionString() << "\n\n"; + + // ----------------------------------------------------------------------- + // 2. Create session options (could add execution providers here) + // ----------------------------------------------------------------------- + Ort::SessionOptions session_options; + session_options.SetIntraOpNumThreads(1); + session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC); + + // ----------------------------------------------------------------------- + // 3. Load the ONNX model from a file + // Generate with: python generate_model.py + // ----------------------------------------------------------------------- + const std::filesystem::path model_path = (argc > 1) ? argv[1] : "add_model.onnx"; + std::cout << "Loading model: " << model_path.string() << "\n"; + + Ort::Session session(env, model_path.native().c_str(), session_options); + + // ----------------------------------------------------------------------- + // 4. 
Query model metadata: input/output names and shapes + // ----------------------------------------------------------------------- + Ort::AllocatorWithDefaultOptions allocator; + + const size_t num_inputs = session.GetInputCount(); + const size_t num_outputs = session.GetOutputCount(); + std::cout << "Model inputs: " << num_inputs << "\n"; + std::cout << "Model outputs: " << num_outputs << "\n"; + + // Collect input/output names + std::vector input_names; + std::vector output_names; + + for (size_t i = 0; i < num_inputs; ++i) { + auto name = session.GetInputNameAllocated(i, allocator); + std::cout << " Input " << i << ": " << name.get() << "\n"; + input_names.emplace_back(name.get()); + } + for (size_t i = 0; i < num_outputs; ++i) { + auto name = session.GetOutputNameAllocated(i, allocator); + std::cout << " Output " << i << ": " << name.get() << "\n"; + output_names.emplace_back(name.get()); + } + std::cout << "\n"; + + // ----------------------------------------------------------------------- + // 5. Prepare input tensors + // ----------------------------------------------------------------------- + // Our model expects two float tensors of shape [1, 3]. + constexpr int64_t batch_size = 1; + constexpr int64_t num_elements = 3; + const std::array input_shape = {batch_size, num_elements}; + + std::array input_a = {1.0f, 2.0f, 3.0f}; + std::array input_b = {4.0f, 5.0f, 6.0f}; + + auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); + + auto tensor_a = Ort::Value::CreateTensor( + memory_info, input_a.data(), input_a.size(), + input_shape.data(), input_shape.size()); + + auto tensor_b = Ort::Value::CreateTensor( + memory_info, input_b.data(), input_b.size(), + input_shape.data(), input_shape.size()); + + THROW_IF_NOT(tensor_a.IsTensor()); + THROW_IF_NOT(tensor_b.IsTensor()); + + // The Run() API expects arrays of C strings for input/output names. 
+ std::vector input_name_ptrs; + std::vector output_name_ptrs; + for (const auto& n : input_names) input_name_ptrs.push_back(n.c_str()); + for (const auto& n : output_names) output_name_ptrs.push_back(n.c_str()); + + std::array input_tensors{std::move(tensor_a), std::move(tensor_b)}; + + // ----------------------------------------------------------------------- + // 6. Run inference + // ----------------------------------------------------------------------- + std::cout << "Running inference...\n"; + + Ort::RunOptions run_options; + auto output_tensors = session.Run( + run_options, + input_name_ptrs.data(), input_tensors.data(), input_tensors.size(), + output_name_ptrs.data(), output_name_ptrs.size()); + + // ----------------------------------------------------------------------- + // 7. Process output + // ----------------------------------------------------------------------- + THROW_IF_NOT(!output_tensors.empty() && output_tensors[0].IsTensor()); + + const float* output_data = output_tensors[0].GetTensorData(); + auto type_info = output_tensors[0].GetTensorTypeAndShapeInfo(); + size_t output_count = type_info.GetElementCount(); + + std::cout << "\nInputs:\n"; + std::cout << " A = ["; + for (size_t i = 0; i < input_a.size(); ++i) { + std::cout << (i ? ", " : "") << input_a[i]; + } + std::cout << "]\n"; + + std::cout << " B = ["; + for (size_t i = 0; i < input_b.size(); ++i) { + std::cout << (i ? ", " : "") << input_b[i]; + } + std::cout << "]\n"; + + std::cout << "\nOutput (A + B):\n"; + std::cout << " C = ["; + for (size_t i = 0; i < output_count; ++i) { + std::cout << (i ? ", " : "") << output_data[i]; + } + std::cout << "]\n"; + + // Verify correctness + bool correct = true; + for (size_t i = 0; i < num_elements; ++i) { + if (output_data[i] != input_a[i] + input_b[i]) { + correct = false; + break; + } + } + std::cout << "\nResult: " << (correct ? "PASS" : "FAIL") << "\n"; + + return correct ? 
EXIT_SUCCESS : EXIT_FAILURE; + } catch (const Ort::Exception& e) { + std::cerr << "ONNX Runtime error: " << e.what() << "\n"; + return EXIT_FAILURE; + } catch (const std::exception& e) { + std::cerr << "Error: " << e.what() << "\n"; + return EXIT_FAILURE; + } +} diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index e7e5cbe5ea031..de64183e1bb18 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -896,6 +896,8 @@ def generate_build_tree( if not args.no_kleidiai: cmake_args += ["-Donnxruntime_USE_KLEIDIAI=ON"] + if args.use_qmx: + cmake_args += ["-Donnxruntime_USE_QMX_KLEIDIAI_COEXIST=ON"] if args.enable_arm_neon_nchwc: cmake_args += ["-Donnxruntime_USE_ARM_NEON_NCHWC=ON"] @@ -1759,7 +1761,7 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): # Install cpu only version of torch when cuda is not enabled in Linux. extra = [] if args.use_cuda and is_linux() else ["--index-url", "https://download.pytorch.org/whl/cpu"] run_subprocess( - [sys.executable, "-m", "pip", "install", "torch==2.8.0", "torchvision==0.23.0", *extra], + [sys.executable, "-m", "pip", "install", "torch==2.10.0", "torchvision==0.25.0", *extra], cwd=cwd, dll_path=dll_path, python_path=python_path, @@ -1816,6 +1818,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): onnx_test = False if onnx_test: + log.info("Testing Symlink ONNX Model and External Data") + run_subprocess([sys.executable, "onnxruntime_test_python_symlink_data.py"], cwd=cwd, dll_path=dll_path) + # Disable python onnx tests for TensorRT and CANN EP, because many tests are # not supported yet. if args.use_tensorrt or args.use_cann: @@ -1833,11 +1838,9 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): [sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path ) - # onnx package does not support python 3.14 yet so skip the transformers tests for python 3.14. 
- # we can remove this check when onnx package supports python 3.14. if args.enable_transformers_tool_test and (sys.version_info.major, sys.version_info.minor) < ( 3, - 14, + 15, ): import google.protobuf # noqa: PLC0415 import numpy # noqa: PLC0415 diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py index 33d6c39de1aad..f32666f65cc38 100644 --- a/tools/ci_build/build_args.py +++ b/tools/ci_build/build_args.py @@ -372,7 +372,7 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None: """Adds arguments for WebAssembly (WASM) platform builds.""" parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.") parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.") - parser.add_argument("--emsdk_version", default="4.0.21", help="Specify version of emsdk.") + parser.add_argument("--emsdk_version", default="4.0.23", help="Specify version of emsdk.") parser.add_argument( "--enable_wasm_jspi", action="store_true", help="Enable WebAssembly JavaScript Promise Integration." ) @@ -763,6 +763,12 @@ def add_execution_provider_args(parser: argparse.ArgumentParser) -> None: "--no_kleidiai", action="store_true", help="Disable KleidiAI integration (used with ACL/ArmNN)." ) + # --- Qualcomm QMX Library --- + qmx_group = parser.add_argument_group("QMX kernel library") + qmx_group.add_argument( + "--use_qmx", action="store_true", help="Enable Qualcomm QMX kernel to coexist with Arm KleidiAI." 
+ ) + # --- RKNPU --- rknpu_group = parser.add_argument_group("RKNPU Execution Provider") rknpu_group.add_argument("--use_rknpu", action="store_true", help="Enable RKNPU EP.") diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines-cuda13.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines-cuda13.yml new file mode 100644 index 0000000000000..aee2f18a774cb --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines-cuda13.yml @@ -0,0 +1,145 @@ +parameters: +- name: RunOnnxRuntimeTests + displayName: Run Tests? + type: boolean + default: true + +- name: UseIncreasedTimeoutForTests + displayName: Increase timeout for tests? Set it to false if you are doing an Onnx Runtime release. + type: boolean + default: false + +- name: IsReleaseBuild + displayName: Is a release build? Set it to true if you are doing an ONNX Runtime release. + type: boolean + default: false + +- name: PreReleaseVersionSuffixString + displayName: Suffix added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the type of pre-release package. + type: string + values: + - alpha + - beta + - rc + - none + default: none + +- name: PreReleaseVersionSuffixNumber + displayName: Number added to pre-release package version. Only used if IsReleaseBuild is true. Denotes the sequence of a pre-release package. + type: number + default: 0 + +- name: AdditionalBuildFlag + displayName: Build flags to append to build command + type: string + default: '--use_azure' + +- name: NugetPackageSuffix + displayName: Suffix to append to nuget package + type: string + default: 'NONE' + +# these 2 parameters are used for debugging. 
+- name: SpecificArtifact + displayName: Use Specific Artifact (Debugging only) + type: boolean + default: false + +- name: BuildId + displayName: Pipeline BuildId, you could find it in the URL + type: string + default: '0' + +- name: CudaVersion + displayName: CUDA version + type: string + default: '13.0' + values: + - 13.0 + +resources: + repositories: + - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step + type: github + endpoint: ort-examples + name: microsoft/onnxruntime-inference-examples + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release + +variables: +- template: templates/common-variables.yml +- name: win_trt_home + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda13 }} +- name: win_cuda_home + value: $(Agent.TempDirectory)\v13.0 +- name: win_cudnn_home + value: $(Agent.TempDirectory)\9.14.0.64_cuda13 +- name: CudaArchs + value: '75-real;80-real;86-real;89-real;90-real;100-real;120-real;120-virtual' + +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". 
+ template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines + parameters: + settings: + networkIsolationPolicy: Permissive + featureFlags: + binskimScanAllExtensions: true + sdl: + binskim: + enabled: true + scanOutputDirectoryOnly: true + sourceAnalysisPool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + componentgovernance: + ignoreDirectories: '$(Build.Repository.LocalPath)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/benchmark,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11/tests,$(Build.Repository.LocalPath)/cmake/external/onnxruntime-extensions,$(Build.Repository.LocalPath)/js/react_native/e2e/node_modules,$(Build.Repository.LocalPath)/js/node_modules,$(Build.Repository.LocalPath)/onnxruntime-inference-examples,$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/benchmark,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11/tests,$(Build.SourcesDirectory)/cmake/external/onnxruntime-extensions,$(Build.SourcesDirectory)/js/react_native/e2e/node_modules,$(Build.SourcesDirectory)/js/node_modules,$(Build.SourcesDirectory)/onnxruntime-inference-examples,$(Build.BinariesDirectory)' + sourceRepositoriesToScan: + exclude: + - repository: onnxruntime-inference-examples + spotBugs: + enabled: false + justificationForDisabling: "Getting ##[error]1. SpotBugs Error gdn.unknownFormatResult - File: spotbugs.xml, which indicates that SpotBugs found one or more errors, which are not handled by the Guardian right now." 
+ codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL is taking nearly 6 hours resulting in timeouts in our production pipelines' + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + + stages: + - template: stages/set_packaging_variables_stage.yml + parameters: + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + + - template: templates/linux-cpu-packaging-pipeline.yml + + - template: stages/nuget-combine-cuda-stage.yml + parameters: + CudaVersion: ${{ parameters.CudaVersion }} + RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} + UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} + win_trt_home: ${{ variables.win_trt_home }} + win_cuda_home: ${{ variables.win_cuda_home }} + DoEsrp: true + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} + buildJava: true + buildNodejs: true + SpecificArtifact: ${{ parameters.SpecificArtifact }} + BuildId: ${{ parameters.BuildId }} + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + CudaArchs: ${{ variables.CudaArchs }} + win_cudnn_home: ${{ variables.win_cudnn_home }} diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml index 7242c5fe7b6a6..5ddac928b32d3 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-test-pipelines.yml @@ -104,9 +104,18 @@ stages: - template: nuget/templates/test_macos.yml parameters: - AgentPool: macOS-14 + AgentPool: 'AcesShared' + UseHostedVmImage: 
'false' + PoolDemands: 'ImageOverride -equals ACES_VM_SharedPool_Sequoia' ArtifactSuffix: 'CPU' +- template: nodejs/templates/test_macos.yml + parameters: + AgentPool: 'AcesShared' + UseHostedVmImage: 'false' + PoolDemands: 'ImageOverride -equals ACES_VM_SharedPool_Sequoia' + StageSuffix: 'MacOS_ARM64' + - template: nodejs/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win-CPU-VS2022-Latest' @@ -117,10 +126,6 @@ stages: AgentPool: 'onnxruntime-Ubuntu2204-AMD-CPU' StageSuffix: 'Linux_CPU_x64' -- template: nodejs/templates/test_macos.yml - parameters: - StageSuffix: 'macOS_CPU_x64' - - template: nuget/templates/test_win.yml parameters: AgentPool: 'onnxruntime-Win2022-GPU-A10' @@ -155,7 +160,61 @@ stages: NugetPackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Linux' CudaVersion: 12.8 +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: onnxruntime-linux-aarch64 + artifactPipelineResource: build + previousStageName: Setup + platform: linux-aarch64 + agentPool: onnxruntime-linux-ARM64-CPU-2019 +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: onnxruntime-linux-x64 + artifactPipelineResource: build + previousStageName: Setup + platform: linux-x64 + agentPool: onnxruntime-Ubuntu2204-AMD-CPU + +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: onnxruntime-osx-arm64 + artifactPipelineResource: build + previousStageName: Setup + platform: osx-arm64 + agentPool: + name: AcesShared + os: macOS + demands: + - ImageOverride -equals ACES_VM_SharedPool_Sequoia + agentSetupSteps: + - template: templates/setup-build-tools.yml + parameters: + host_cpu_arch: arm64 + +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: onnxruntime-win-arm64 + artifactPipelineResource: build + previousStageName: Setup + platform: win-arm64 + agentPool: onnxruntime-qnn-windows-vs-2022-arm64 + +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: 
onnxruntime-win-arm64x + artifactPipelineResource: build + previousStageName: Setup + platform: win-arm64x + agentPool: onnxruntime-qnn-windows-vs-2022-arm64 + +- template: templates/test-binary-archive-stage.yml + parameters: + artifactName: onnxruntime-win-x64 + artifactPipelineResource: build + previousStageName: Setup + platform: win-x64 + agentPool: onnxruntime-Win-CPU-VS2022-Latest # Run GPU tests. - stage: Windows_Packaging_cuda_Testing @@ -225,7 +284,7 @@ stages: - checkout: self clean: true submodules: none - + - download: build artifact: 'Windows_Packaging_tensorrt_build_artifacts' displayName: 'Download Windows GPU Packages Build' @@ -246,7 +305,7 @@ stages: versionSpec: "17" jdkArchitectureOption: x64 jdkSourceOption: 'PreInstalled' - + - task: PythonScript@0 displayName: 'Update CTest Path References' inputs: diff --git a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index b4012b74196ee..ec3e8a9621e4c 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -23,11 +23,6 @@ parameters: type: number default: 0 -- name: PackageName - displayName: What is the package name? Override using an environment variable CustomPackageName. 
- type: string - default: 'Microsoft.ML.OnnxRuntime.Foundry' - variables: - template: templates/common-variables.yml - name: ReleaseVersionSuffix @@ -121,7 +116,7 @@ extends: buildArch: x64 msbuildPlatform: arm64 packageName: arm64 - buildparameter: --arm64ec --buildasx --caller_framework WinAI + buildparameter: --arm64 --buildasx --caller_framework WinAI runTests: false buildJava: false buildNodejs: false @@ -137,141 +132,8 @@ extends: AdditionalBuildFlags: '--use_webgpu --skip_tests' DoEsrp: true - - stage: NugetPackaging - dependsOn: [Windows_Packaging_CUDA, Windows_Packaging_CPU_arm64, ManagedNugetPackaging, MacOS_C_API_Package_Publish] - jobs: - - job: CreateNugetPackage - pool: 'Onnxruntime-Win2022-GPU-A10' - timeoutInMinutes: 120 - steps: - - checkout: self - clean: true - submodules: none - - - task: UsePythonVersion@0 - inputs: - versionSpec: '3.12' - addToPath: true - - task: PipAuthenticate@1 - displayName: 'Pip Authenticate' - inputs: - artifactFeeds: 'Lotus' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - managed nuget' - inputs: - artifactName: 'onnxruntime-managed-nuget' - targetPath: '$(Build.BinariesDirectory)/managed-nuget' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - win-x64' - inputs: - artifactName: 'onnxruntime-win-x64-cuda' - targetPath: '$(Build.BinariesDirectory)/win-x64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - win-arm64' - inputs: - artifactName: 'onnxruntime-win-arm64' - targetPath: '$(Build.BinariesDirectory)/win-arm64' - - - task: DownloadPipelineArtifact@0 - displayName: 'Download Pipeline Artifact - osx' - inputs: - artifactName: 'onnxruntime-osx' - targetPath: '$(Build.BinariesDirectory)/osx' - - - task: PowerShell@2 - displayName: 'Create osx directories' - inputs: - targetType: 'inline' - script: | - mkdir -p $(Build.BinariesDirectory)/osx-arm64 - Move-Item -Path 
$(Build.BinariesDirectory)/osx/onnxruntime-osx-arm64* -Destination $(Build.BinariesDirectory)/osx-arm64 - - - task: PowerShell@2 - displayName: 'List all files downloaded' - inputs: - targetType: 'inline' - script: | - $files = Get-ChildItem $(Build.BinariesDirectory) -Recurse - foreach ($file in $files) { - Write-Host "File: $($file.FullName)" - if ($file -like "*onnxruntime*") { - Write-Host "File onnxruntime: $($file.FullName) - Size: $($file.Length)" - } - } - $dirs = Get-ChildItem $(Build.BinariesDirectory) -Directory - foreach ($dir in $dirs) { - Write-Host "Directory: $($dir.FullName)" - } - $osx_arm64_archive = Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64* - if ($osx_arm64_archive.Count -eq 0) { - Write-Host "No osx-arm64 archive found." - } else { - Write-Host "osx-arm64 archive found: $($osx_arm64_archive[0].FullName)" - } - workingDirectory: $(Build.BinariesDirectory) - - - task: PowerShell@2 - displayName: 'Extract Nuget Package Version' - inputs: - targetType: 'inline' - script: | - $nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/managed-nuget -Filter Microsoft.ML.OnnxRuntime.Managed.*.nupkg -Recurse) - $package_name = $nupkgs[0].Name - $version_length = $package_name.Length - "Microsoft.ML.OnnxRuntime.Managed.".Length - ".nupkg".Length - $package_version = $package_name.Substring("Microsoft.ML.OnnxRuntime.Managed.".Length, $version_length) - Write-Host "##vso[task.setvariable variable=package_version;]$package_version" - workingDirectory: $(Build.BinariesDirectory) - - - task: PowerShell@2 - displayName: 'Extract Archives' - inputs: - targetType: 'inline' - script: | - Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 - Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-arm64*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 - $osx_arm64_archive = (Get-ChildItem -Path 
$(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64*)[0].FullName - tar -xzf $osx_arm64_archive -C $(Build.BinariesDirectory)/osx-arm64 2>$null - $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Filter onnxruntime-win-x64-cuda*)[0].FullName - $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Filter onnxruntime-win-arm64*)[0].FullName - $osx_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64*)[0].FullName - Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" - Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" - Write-Host "##vso[task.setvariable variable=osx_x64;]$osx_x64" - Write-Host "##vso[task.setvariable variable=osx_arm64;]$osx_arm64" - workingDirectory: $(Build.BinariesDirectory) - - - task: PowerShell@2 - displayName: 'Get Package Name' - inputs: - targetType: 'inline' - script: | - if ($env:CustomPackageName) { - Write-Host "##vso[task.setvariable variable=PackageName;]$env:CustomPackageName" - Write-Host "PackageName: $env:CustomPackageName" - } else { - Write-Host "##vso[task.setvariable variable=PackageName;]${{ parameters.PackageName }}" - Write-Host "PackageName: ${{ parameters.PackageName }}" - } - workingDirectory: $(Build.BinariesDirectory) - - - task: PythonScript@0 - displayName: 'Generate Nuget Package' - inputs: - scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' - arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --osx_arm64 "$(osx_arm64)" --osx_x64 "$(osx_x64)" --package_version "$(package_version)" --package_name "$(PackageName)"' - - - task: NuGetCommand@2 - displayName: 'Pack Nuget Package' - inputs: - command: 'pack' - packagesToPack: '$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec' - 
packDestination: $(Build.ArtifactStagingDirectory)\ - - - task: 1ES.PublishPipelineArtifact@1 - displayName: 'Publish Artifact: Nuget' - inputs: - artifactName: '${{ parameters.PackageName }}' - targetPath: '$(Build.ArtifactStagingDirectory)' + - template: templates/foundry-local-nuget-packaging.yml + parameters: + DependsOn: [Setup, Windows_Packaging_CUDA, Windows_Packaging_CPU_arm64, ManagedNugetPackaging, MacOS_C_API_Package_Publish] + DoEsrp: true + PackageName: 'Microsoft.ML.OnnxRuntime.Foundry' diff --git a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml index 2868963b637a8..0b63a4f5b83c1 100644 --- a/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml +++ b/tools/ci_build/github/azure-pipelines/dml-nuget-packaging.yml @@ -75,9 +75,13 @@ extends: DoNugetPack: 'true' DoEsrp: ${{ parameters.DoEsrp }} NuPackScript: | + python -m pip install setuptools msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} /p:CurrentData=$(BuildDate) /p:CurrentTime=$(BuildTime) + if errorlevel 1 exit /b 1 copy $(Build.SourcesDirectory)\csharp\src\Microsoft.ML.OnnxRuntime\bin\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) - copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.nupkg $(Build.ArtifactStagingDirectory) + if errorlevel 1 exit /b 1 + powershell -ExecutionPolicy Bypass -File $(Build.SourcesDirectory)\tools\ci_build\github\windows\select_dml_package.ps1 -SourceDir "$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo" -IsReleaseBuild "${{ parameters.IsReleaseBuild }}" -Action copy -DestinationDir "$(Build.ArtifactStagingDirectory)" + if errorlevel 1 exit /b 1 mkdir $(Build.ArtifactStagingDirectory)\testdata copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\custom_op_library.* 
$(Build.ArtifactStagingDirectory)\testdata @@ -94,13 +98,17 @@ extends: DoEsrp: ${{ parameters.DoEsrp }} RunTests: 'false' NuPackScript: | + python -m pip install setuptools msbuild $(Build.SourcesDirectory)\csharp\OnnxRuntime.CSharp.proj /p:Configuration=RelWithDebInfo /p:TargetArchitecture=arm64 /t:CreatePackage /p:OrtPackageId=Microsoft.ML.OnnxRuntime.DirectML /p:IsReleaseBuild=${{ parameters.IsReleaseBuild }} - cd $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\ - ren Microsoft.ML.OnnxRuntime.DirectML.* win-dml-arm64.zip + if errorlevel 1 exit /b 1 + powershell -ExecutionPolicy Bypass -File $(Build.SourcesDirectory)\tools\ci_build\github\windows\select_dml_package.ps1 -SourceDir "$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo" -IsReleaseBuild "${{ parameters.IsReleaseBuild }}" -Action rename -NewName "win-dml-arm64.zip" + if errorlevel 1 exit /b 1 copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\win-dml-arm64.zip $(Build.ArtifactStagingDirectory) + if errorlevel 1 exit /b 1 mkdir $(Build.ArtifactStagingDirectory)\testdata copy $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\custom_op_library.* $(Build.ArtifactStagingDirectory)\testdata - template: stages/nuget_dml_packaging_stage.yml parameters: - DoEsrp: ${{ parameters.DoEsrp }} \ No newline at end of file + DoEsrp: ${{ parameters.DoEsrp }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} diff --git a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml index 9d831df54096a..8fcde9e88edd7 100644 --- a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml +++ b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml @@ -4,172 +4,62 @@ resources: source: 'Zip-Nuget-Java-Nodejs Packaging Pipeline' trigger: true branch: main + repositories: + - repository: 1esPipelines + type: git + name: 1ESPipelineTemplates/1ESPipelineTemplates + ref: refs/tags/release variables: mavenVersion: '3.9.8' 
-stages: -- template: templates/final-jar-testing-win.yml +extends: + # The pipeline extends the 1ES PT which will inject different SDL and compliance tasks. + # For non-production pipelines, use "Unofficial" as defined below. + # For productions pipelines, use "Official". + template: v1/1ES.Official.PipelineTemplate.yml@1esPipelines parameters: - PoolName: 'onnxruntime-Win-CPU-VS2022-Latest' - -- template: templates/final-jar-testing-linux.yml - parameters: - OS: Linux - PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - -- template: templates/final-jar-testing-linux.yml - parameters: - OS: MacOS - PoolName: 'macOS-14' - -- stage: GPU_JAR_Testing - dependsOn: [] - jobs: - - job: Final_Jar_Testing_Windows_GPU - workspace: - clean: all - pool: 'onnxruntime-Win2022-GPU-A10' - timeoutInMinutes: 60 - variables: - - name: runCodesignValidationInjection - value: false - - steps: - - template: templates/set-version-number-variables-step.yml - - - template: templates/jobs/download_win_gpu_library.yml + featureFlags: + binskimScanAllExtensions: true + sdl: + binskim: + enabled: true + scanOutputDirectoryOnly: true + sourceAnalysisPool: + name: onnxruntime-Win-CPU-VS2022-Latest + os: windows + componentgovernance: + ignoreDirectories: 
'$(Build.Repository.LocalPath)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/benchmark,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11,$(Build.Repository.LocalPath)/cmake/external/onnx/third_party/pybind11/tests,$(Build.Repository.LocalPath)/cmake/external/onnxruntime-extensions,$(Build.Repository.LocalPath)/js/react_native/e2e/node_modules,$(Build.Repository.LocalPath)/js/node_modules,$(Build.Repository.LocalPath)/onnxruntime-inference-examples,$(Build.SourcesDirectory)/cmake/external/emsdk/upstream/emscripten/tests,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/benchmark,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11,$(Build.SourcesDirectory)/cmake/external/onnx/third_party/pybind11/tests,$(Build.SourcesDirectory)/cmake/external/onnxruntime-extensions,$(Build.SourcesDirectory)/js/react_native/e2e/node_modules,$(Build.SourcesDirectory)/js/node_modules,$(Build.SourcesDirectory)/onnxruntime-inference-examples,$(Build.BinariesDirectory)' + spotBugs: + enabled: false + justificationForDisabling: "Getting ##[error]1. SpotBugs Error gdn.unknownFormatResult - File: spotbugs.xml, which indicates that SpotBugs found one or more errors, which are not handled by the Guardian right now." 
+ codeql: + compiled: + enabled: false + justificationForDisabling: 'CodeQL is taking nearly 6 hours resulting in timeouts in our production pipelines' + tsa: + enabled: true + codeSignValidation: + enabled: true + break: true + policheck: + enabled: true + exclusionsFile: '$(Build.SourcesDirectory)\tools\ci_build\policheck_exclusions.xml' + + stages: + - template: templates/final-jar-testing-win.yml parameters: - CudaVersion: 12.8 - DownloadCUDA: true - DownloadTRT: true - - - template: templates/setup-maven.yml - - - task: Maven@4 - displayName: 'Download Java Dependencies' - inputs: - mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' - goals: 'dependency:copy-dependencies' - options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' - publishJUnitTestResults: false - javaHomeOption: 'JDKVersion' - jdkVersionOption: '1.17' - mavenVersionOption: 'Default' - - download: build - artifact: 'onnxruntime-java-gpu' - displayName: 'Download Final Jar' - - script: | - move $(Pipeline.Workspace)\build\onnxruntime-java-gpu\*.jar $(Pipeline.Workspace)\build\onnxruntime-java\ - - - task: PowerShell@2 - displayName: 'Run Java Tests with PowerShell' - inputs: - targetType: 'inline' - script: | - # Exit script on any error - $ErrorActionPreference = "Stop" - - cd $(Pipeline.Workspace)/build/onnxruntime-java - del *.asc - del *.sha256 - del *.sha512 - del *.pom - del *.sha1 - del *.pom - cd .. - mkdir tests - cd tests - jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar - del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar - dir $(Pipeline.Workspace)/build/tests - Write-Host "Running JUnit Tests..." 
- & java -DUSE_CUDA=1 ` - -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` - --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" - - - task: PublishTestResults@2 - displayName: 'Publish Test Results' - inputs: - testResultsFormat: 'JUnit' - testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' - failTaskOnFailedTests: true + PoolName: 'onnxruntime-Win-CPU-VS2022-Latest' + - template: templates/final-jar-testing-linux.yml + parameters: + OS: linux + PoolName: 'onnxruntime-Ubuntu2204-AMD-CPU' - - job: Final_Jar_Testing_Linux_GPU - workspace: - clean: all - pool: - name: 'Onnxruntime-Linux-GPU-A10' - variables: - - name: runCodesignValidationInjection - value: false - - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 - timeoutInMinutes: 60 - steps: - - checkout: self - submodules: false - - - template: templates/set-version-number-variables-step.yml - - - bash: | - sudo apt-get install -y msopenjdk-17 - dpkg -l msopenjdk-17 - - - bash: | - echo "Downloading and installing Maven $(mavenVersion) for Linux..." - MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" - # Download Maven binary - wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz - - # Extract to the temp directory - mkdir -p ${MAVEN_DIR} - tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) - - # Add Maven's bin directory to the PATH for subsequent tasks in the job - echo "##vso[task.prependpath]${MAVEN_DIR}/bin" - displayName: 'Install Maven (Linux)' - - - script: | - echo "Maven is now on the PATH." 
- mvn --version - - - download: build - artifact: 'onnxruntime-java-gpu' - displayName: 'Download Final Jar' - - # Rename the downloaded folder - - script: | - mv $(Pipeline.Workspace)/build/onnxruntime-java-gpu $(Pipeline.Workspace)/build/onnxruntime-java - - - task: Maven@4 - displayName: 'Download Dependencies' - inputs: - mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' - goals: 'dependency:copy-dependencies' - options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' - publishJUnitTestResults: false - javaHomeOption: 'Path' - jdkDirectory: '/usr/lib/jvm/msopenjdk-17-amd64' - jdkVersionOption: 'Default' - mavenVersionOption: 'Default' - - # Now all the jars are in the $(Pipeline.Workspace)/build folder - - - template: templates/get-docker-image-steps.yml + - template: templates/final-jar-testing-linux.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 - Context: tools/ci_build/github/linux/docker/ - DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ variables.docker_base_image }} --build-arg TRT_VERSION=${{ variables.linux_trt_version }}" - Repository: onnxruntimeubi8packagestest + OS: macOS + PoolName: 'AcesShared' + PoolDemands: 'ImageOverride -equals ACES_VM_SharedPool_Sequoia' - - bash: | - docker run --network=none --rm \ - --gpus all \ - --volume $(Build.SourcesDirectory):/onnxruntime_src \ - --volume $(Pipeline.Workspace)/build:/build \ - --volume /data/models:/build/models:ro \ - onnxruntimeubi8packagestest \ - /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) - displayName: 'Test' + - template: templates/final-jar-testing-gpu.yml diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml index c00cbb06f26fd..4bfb9c630fede 100644 --- 
a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-daily-perf-pipeline.yml @@ -142,7 +142,7 @@ jobs: workingDirectory: '$(Build.SourcesDirectory)/onnxruntime/python/tools/tensorrt/perf/' condition: always() - - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas] coloredlogs' + - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas]' displayName: 'Install dashboard dependencies' - script: | @@ -165,7 +165,7 @@ jobs: - ${{ if eq(parameters.PostToDashboard, true) }}: - - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas] coloredlogs' + - script: 'python3 -m pip install pandas azure-kusto-data[pandas] azure-kusto-ingest[pandas]' displayName: 'Install dashboard dependencies' - script: | @@ -191,4 +191,4 @@ jobs: pathtoPublish: '$(Build.SourcesDirectory)/Artifact' artifactName: 'result-$(Build.BuildNumber)' - - template: templates/clean-agent-build-directory-step.yml \ No newline at end of file + - template: templates/clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml index ae595bbf0c96b..cd41fc575020b 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test.yml @@ -6,12 +6,20 @@ steps: - task: PowerShell@2 - displayName: 'Move Artifact Directory' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT')) + displayName: 'Move Artifact Directory (Windows)' inputs: targetType: 'inline' script: | Move-Item -Path "$(Pipeline.Workspace)/build/NPM_packages" -Destination "$(Build.BinariesDirectory)/nodejs-artifact" +- task: CmdLine@2 + condition: and(succeeded(), ne(variables['Agent.OS'], 'Windows_NT')) + displayName: 'Move Artifact Directory 
(POSIX)' + inputs: + script: | + mv "$(Pipeline.Workspace)/build/NPM_packages" "$(Build.BinariesDirectory)/nodejs-artifact" + - script: mkdir e2e_test workingDirectory: '$(Build.BinariesDirectory)' @@ -38,4 +46,4 @@ steps: npm init -y npm install $(NpmPackageFilesForTest) --onnxruntime-node-install-cuda=skip node -p "require('onnxruntime-node')" - workingDirectory: '$(Build.BinariesDirectory)/e2e_test' \ No newline at end of file + workingDirectory: '$(Build.BinariesDirectory)/e2e_test' diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index 4dd19ce2c250c..7e184492fab59 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -1,5 +1,9 @@ parameters: StageSuffix: '' + AgentPool : 'macOS-15' + UseHostedVmImage: 'true' + PoolDemands: '' + stages: - stage: Nodejs_Test_MacOS_${{ parameters.StageSuffix }} dependsOn: @@ -11,7 +15,12 @@ stages: clean: all timeoutInMinutes: 120 pool: - vmImage: 'macOS-15' + ${{ if eq(parameters.UseHostedVmImage, 'true') }}: + vmImage: ${{ parameters.AgentPool }} + ${{ else }}: + name: ${{ parameters.AgentPool }} + ${{ if ne(parameters.PoolDemands, '') }}: + demands: ${{ parameters.PoolDemands }} variables: - name: OnnxRuntimeBuildDirectory diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index 02613871d61ff..2548eebeb9d42 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -49,8 +49,8 @@ stages: clean: true submodules: none - - - template: ../../templates/setup-build-tools.yml + + - template: ../../templates/setup-build-tools.yml parameters: host_cpu_arch: 'x64' diff --git 
a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index 1d122d64b1211..5fc52e2c76468 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -1,6 +1,10 @@ parameters: + AgentPool : 'macOS-15' + UseHostedVmImage: 'true' IsMacOS : 'true' ArtifactSuffix: '' + PoolDemands: '' + stages: - stage: NuGet_Test_MacOS dependsOn: @@ -11,7 +15,12 @@ stages: workspace: clean: all pool: - vmImage: 'macOS-15' + ${{ if eq(parameters.UseHostedVmImage, 'true') }}: + vmImage: ${{ parameters.AgentPool }} + ${{ else }}: + name: ${{ parameters.AgentPool }} + ${{ if ne(parameters.PoolDemands, '') }}: + demands: ${{ parameters.PoolDemands }} variables: - name: OnnxRuntimeBuildDirectory @@ -27,18 +36,36 @@ stages: - script: | mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact - mv $(Pipeline.Workspace)/build/onnxruntime-osx $(Build.BinariesDirectory)/testdata + + # Artifact is a folder containing tgz. Extract it to testdata. + mkdir -p $(Build.BinariesDirectory)/testdata + for archive in $(Pipeline.Workspace)/build/onnxruntime-osx/*.tgz; do + tar -xzf "$archive" -C $(Build.BinariesDirectory)/testdata + done + + # Ensure libcustom_op_library.dylib is where EndToEndTests expects it (testdata/testdata) + mkdir -p $(Build.BinariesDirectory)/testdata/testdata + find $(Build.BinariesDirectory)/testdata -name "libcustom_op_library.dylib" -exec cp {} $(Build.BinariesDirectory)/testdata/testdata/ \; + - template: get-nuget-package-version-as-variable.yml parameters: packageFolder: '$(Build.BinariesDirectory)/nuget-artifact' + - script: | + git submodule update --init cmake/external/onnx + cd cmake/external/onnx + git fetch origin v1.13.1 --depth=1 + git checkout v1.13.1 + cd ../../.. 
+ displayName: 'Initialize ONNX submodule for test data (pinned to v1.13.1 since new data types like float8 is not supported in nuget)' + - script: | $(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh \ $(Build.BinariesDirectory)/nuget-artifact \ $(NuGetPackageVersionNumber) \ true - + if [ $? -ne 0 ]; then echo "Failed to run test" exit 1 @@ -48,4 +75,5 @@ stages: OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) DisableContribOps: $(DisableContribOps) DisableMlOps: $(DisableMlOps) - IsReleaseBuild: $(IsReleaseBuild) \ No newline at end of file + IsReleaseBuild: $(IsReleaseBuild) + ORT_LOADER_VERBOSITY: 1 diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 0481a356cf9a1..2a8e222a9e192 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -96,6 +96,5 @@ extends: QnnSdk: ${{ parameters.QnnSdk }} IsReleaseBuild: ${{ parameters.IsReleaseBuild }} DoEsrp: ${{ parameters.DoEsrp }} - ArtifactName: 'drop-nuget-qnn-arm64x' StageName: 'OnnxRuntime_QNN_Nuget_Win_Arm64x' build_config: ${{ parameters.build_config }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml index 0400b4b2233ec..4b07e4173e6c9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-combine-cuda-stage.yml @@ -57,6 +57,7 @@ stages: CudaVersion: ${{ parameters.CudaVersion }} buildJava: ${{ parameters.buildJava }} buildNodejs: ${{ parameters.buildNodejs }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - ${{ if eq(parameters.buildNodejs, 'true') }}: - template: nodejs-linux-packaging-stage.yml @@ -71,6 +72,7 @@ stages: win_trt_home: ${{ 
parameters.win_trt_home }} win_cuda_home: ${{ parameters.win_cuda_home }} buildJava: ${{ parameters.buildJava }} + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} CudaArchs: ${{ parameters.CudaArchs }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml index 55e78ae79b208..a0f578e22f910 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-packaging-stage.yml @@ -186,6 +186,7 @@ stages: # 1* stands for version number. we use it to filter Gpu.Windows and Gpu.Linux packages PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.1*nupkg' VerifyNugetSigning: false + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - template: ../templates/validate-package.yml parameters: @@ -194,6 +195,7 @@ stages: PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Windows.*nupkg' PlatformsSupported: 'win-x64' VerifyNugetSigning: false + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - template: ../templates/validate-package.yml parameters: @@ -202,6 +204,7 @@ stages: PackageName: 'Microsoft.ML.OnnxRuntime.Gpu.Linux.*nupkg' PlatformsSupported: 'linux-x64' VerifyNugetSigning: false + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} - task: MSBuild@1 displayName: 'Clean C#' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index 486efa85877cb..0bce2a052ff1b 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -6,6 +6,9 @@ parameters: type: boolean - name: buildNodejs type: boolean +- 
name: IsReleaseBuild + type: boolean + default: false stages: - stage: Linux_C_API_Packaging_GPU @@ -202,6 +205,7 @@ stages: ScriptPath: '$(Build.SourcesDirectory)/onnxruntime/tools/nuget/validate_package.py' PlatformsSupported: 'linux-x64' VerifyNugetSigning: false + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} workingDirectory: '$(Build.ArtifactStagingDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index 9aaf055001bbd..b4bdcaf937bb7 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -55,6 +55,9 @@ parameters: - name: win_cudnn_home type: string default: '' +- name: IsReleaseBuild + type: boolean + default: false stages: # Windows CUDA without TensorRT Packaging @@ -179,10 +182,12 @@ stages: ScriptPath: '$(Build.SourcesDirectory)\onnxruntime\tools\nuget\validate_package.py' PlatformsSupported: 'win-x64' VerifyNugetSigning: false + IsReleaseBuild: ${{ parameters.IsReleaseBuild }} workingDirectory: '$(Build.ArtifactStagingDirectory)' - task: BatchScript@1 displayName: 'Test C API application for GPU package' + condition: and(succeeded(), ne(${{parameters.CudaVersion}}, '13.0')) inputs: filename: $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet\run_capi_application.bat arguments: $(Build.SourcesDirectory)\onnxruntime $(Build.ArtifactStagingDirectory)\onnxruntime-win-x64-gpu-$(OnnxRuntimeVersion).zip $(Build.SourcesDirectory)\onnxruntime-inference-examples\c_cxx\squeezenet diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml index 6eb7c52712671..f767ef110561a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml +++ 
b/tools/ci_build/github/azure-pipelines/stages/py-cpu-packaging-stage.yml @@ -66,131 +66,17 @@ stages: - stage: Python_Packaging_Windows_CPU dependsOn: [] jobs: - - job: Windows_py_Wheels - pool: - name: 'onnxruntime-Win-CPU-VS2022-Latest' - os: windows - templateContext: - sdl: - codeSignValidation: - enabled: true - # TODO: check why pyd file was not signed - break: false - additionalTargetsGlobPattern: f|**\*.pyd - psscriptanalyzer: - enabled: true - binskim: - enabled: true - scanOutputDirectoryOnly: true - outputs: - - output: pipelineArtifact - targetPath: $(Build.ArtifactStagingDirectory) - artifactName: onnxruntime-win-$(PythonVersion) - strategy: - matrix: - Python311_x64: - PythonVersion: '3.11' - Python312_x64: - PythonVersion: '3.12' - Python313_x64: - PythonVersion: '3.13' - Python314_x64: - PythonVersion: '3.14' - variables: - OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' - ExtraParam: ${{ parameters.build_py_parameters }} - timeoutInMinutes: 180 - workspace: - clean: all - - steps: - - checkout: self - clean: true - submodules: recursive - - - template: ../templates/setup-build-tools.yml - parameters: - host_cpu_arch: 'x64' - python_version: $(PythonVersion) - - - template: ../templates/set-nightly-build-option-variable-step.yml - - - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt - env: - TMPDIR: "$(Agent.TempDirectory)" - - - task: PythonScript@0 - displayName: 'Build' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: > - --config ${{ parameters.cmake_build_type }} - --enable_lto - --build_dir $(Build.SourcesDirectory)\build - --skip_submodule_sync - --cmake_generator "Visual Studio 17 2022" - --enable_pybind - --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache - ${{ parameters.build_py_parameters }} - --parallel --use_binskim_compliant_compile_flags --update --build - $(TelemetryOption) - - - ${{if 
or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: - - template: ../templates/publish-symbolrequestprod-api.yml - parameters: - ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: - symbolExpiryTime: 60 - includePublicSymbolServer: true - symbolsArtifactName: onnxruntime_cpu_win_x64_$(PythonVersion) - symbolsVersion: $(Build.BuildId) - symbolProject: 'ONNX Runtime' - subscription: 'OnnxrunTimeCodeSign_20240611' - searchPattern: | - $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime.pdb - $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_shared.pdb - $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_pybind11_state.pdb - - # Esrp signing - - template: ../templates/win-esrp-dll.yml - parameters: - FolderPath: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: true - Pattern: '*.pyd,*.dll' - - - task: PythonScript@0 - displayName: 'Build wheel' - inputs: - scriptPath: '$(Build.SourcesDirectory)\setup.py' - arguments: 'bdist_wheel ${{ parameters.build_py_parameters }} $(NightlyBuildOption)' - workingDirectory: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - - - task: CopyFiles@2 - displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' - inputs: - SourceFolder: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\dist' - Contents: '*.whl' - TargetFolder: '$(Build.ArtifactStagingDirectory)' - - - script: | - 7z x *.whl - workingDirectory: '$(Build.ArtifactStagingDirectory)' - displayName: 'unzip the package' - + - template: 
../templates/py-win-cpu.yml + parameters: + architecture: 'x64' + build_py_parameters: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} - - powershell: | - if ("$(PythonVersion)" -notcontains "3.14") { - python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq - Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} - Remove-Item -Recurse -Force onnxruntime - if ("$(ExtraParam)" -contains "--use_azure") { - $env:path="$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\_deps\vcpkg-src\installed\x64-windows\bin;$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\_deps\vcpkg-src\installed\x86-windows\bin;$env:path" - python onnxruntime_test_python_azure.py - } - python onnx_backend_test_series.py - } - workingDirectory: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' - displayName: 'Run Python Tests' + - template: ../templates/py-win-cpu.yml + parameters: + architecture: 'arm64' + build_py_parameters: ${{ parameters.build_py_parameters }} + cmake_build_type: ${{ parameters.cmake_build_type }} - ${{ if eq(parameters.enable_mac_cpu, true) }}: - stage: Python_Packaging_MacOS diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index 7e47227c23d5b..385cee35eb95d 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -175,8 +175,6 @@ stages: - stage: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests dependsOn: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build - # Skip this stage for Python 3.14 for now until onnx package support python 3.14. 
- condition: and(succeeded(), ne('${{ parameters.PYTHON_VERSION }}', '3.14')) jobs: - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests workspace: diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-webgpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-webgpu-stage.yml index 8bd8521d80104..1897d94db76c7 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-webgpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-webgpu-stage.yml @@ -131,8 +131,6 @@ stages: - stage: Win_py_webgpu_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests dependsOn: Win_py_webgpu_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Build - # Skip this stage for Python 3.14 for now until onnx package support python 3.14. - condition: and(succeeded(), ne('${{ parameters.PYTHON_VERSION }}', '3.14')) jobs: - job: Win_py_webgpu_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }}_Tests workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/build-win-arm64x-steps.yml b/tools/ci_build/github/azure-pipelines/templates/build-win-arm64x-steps.yml new file mode 100644 index 0000000000000..50e7cbb13d6e1 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/build-win-arm64x-steps.yml @@ -0,0 +1,28 @@ +# Runs a Windows ARM64X build in `buildDirectory`. 
+ +parameters: + buildDirectory: '$(Build.BinariesDirectory)' + additionalBuildPyArgs: '' + +steps: +- task: PythonScript@0 + displayName: 'Build arm64 project for arm64x - generate the def & lib file for next build' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + ${{ parameters.additionalBuildPyArgs }} + --build_shared_lib + --arm64 + --buildasx + --build_dir="${{ parameters.buildDirectory }}/arm64" + +- task: PythonScript@0 + displayName: 'Build arm64ec project for arm64x - the real arm64x' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + ${{ parameters.additionalBuildPyArgs }} + --build_shared_lib + --arm64ec + --buildasx + --build_dir="${{ parameters.buildDirectory }}" diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows-qnn.yml deleted file mode 100644 index ab3e0ebaab39a..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows-qnn.yml +++ /dev/null @@ -1,141 +0,0 @@ -# sets up common build tools for the windows build machines before build - -parameters: -- name: DoEsrp - displayName: Run code sign tasks? Must be true if you are doing an Onnx Runtime release. - type: boolean - default: true - -- name: buildConfig - displayName: buildConfig - type: string - default: 'RelWithDebInfo' - -- name: artifactName - displayName: artifactName,like 'onnxruntime-win-x64-1.6.0' - type: string - default: '' - -- name: artifactNameNoVersionString - type: string - default: 'onnxruntime-win-x64' - -- name: commitId - displayName: commitId - type: string - default: '' - -- name: trtEnabled - displayName: Include TRT EP libraries? 
- type: boolean - default: true - -steps: - - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: - - template: publish-symbolrequestprod-api.yml - parameters: - ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: - symbolExpiryTime: 60 - includePublicSymbolServer: true - symbolsArtifactName: ${{parameters.artifactNameNoVersionString}} - symbolsVersion: $(Build.BuildId) - symbolProject: 'ONNX Runtime' - subscription: 'OnnxrunTimeCodeSign_20240611' - searchPattern: | - $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.pdb - $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_*.pdb - - - - task: CmdLine@2 - displayName: 'Copy build artifacts for zipping' - inputs: - script: | - mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}} - mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\include - - if exist $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.dll ( - echo "cuda context headers copied" - mkdir $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda - copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\resource.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers - copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\custom_op_context.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers - copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cuda\cuda_context.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda - copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cuda\cuda_resource.h 
$(Build.BinariesDirectory)\${{parameters.artifactName}}\include\core\providers\cuda - ) - - echo "Directories created" - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_shared.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_shared.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_shared.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_cuda.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - - # Copy WebGPU dependencies if required - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxcompiler.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\dxil.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - - # Copy QNN dependencies if required - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_qnn.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy 
$(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libQnnHtp*.so $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\libqnnhtp*.cat $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib /Y - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnCpu.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnGpu.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtp.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpPrepare.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV68Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV73Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV81Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSaver.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSystem.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\Qualcomm_LICENSE.pdf $(Build.BinariesDirectory)\${{parameters.artifactName}} - - # copy trt ep 
libraries only when trt ep is enabled - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.pdb $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime.lib $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib - copy $(Build.SourcesDirectory)\include\onnxruntime\core\session\onnxruntime_*.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include - copy $(Build.SourcesDirectory)\include\onnxruntime\core\framework\provider_options.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include - copy $(Build.SourcesDirectory)\include\onnxruntime\core\providers\cpu\cpu_provider_factory.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include - copy $(Build.SourcesDirectory)\orttraining\orttraining\training_api\include\onnxruntime_training*.h $(Build.BinariesDirectory)\${{parameters.artifactName}}\include - - REM copy the README, license and TPN - copy $(Build.SourcesDirectory)\README.md $(Build.BinariesDirectory)\${{parameters.artifactName}}\README.md - copy $(Build.SourcesDirectory)\docs\Privacy.md $(Build.BinariesDirectory)\${{parameters.artifactName}}\Privacy.md - copy $(Build.SourcesDirectory)\LICENSE $(Build.BinariesDirectory)\${{parameters.artifactName}}\LICENSE - copy $(Build.SourcesDirectory)\ThirdPartyNotices.txt 
$(Build.BinariesDirectory)\${{parameters.artifactName}}\ThirdPartyNotices.txt - copy $(Build.SourcesDirectory)\VERSION_NUMBER $(Build.BinariesDirectory)\${{parameters.artifactName}}\VERSION_NUMBER - @echo ${{parameters.commitId}} > $(Build.BinariesDirectory)\${{parameters.artifactName}}\GIT_COMMIT_ID - - workingDirectory: '$(Build.BinariesDirectory)\${{parameters.buildConfig}}' - - - ${{ if eq(parameters.DoEsrp, true) }}: - - template: win-esrp-dll.yml - parameters: - FolderPath: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' - DisplayName: 'ESRP - Sign Native dlls' - DoEsrp: ${{parameters.DoEsrp}} - Pattern: '*.dll,*.exe' - - - task: DeleteFiles@1 - displayName: 'Delete CodeSignSummary*.md' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' - Contents: 'CodeSignSummary*.md' - - - task: ArchiveFiles@2 - inputs: - rootFolderOrFile: '$(Build.BinariesDirectory)\${{parameters.artifactName}}' - includeRootFolder: true - archiveType: 'zip' # Options: zip, 7z, tar, wim - archiveFile: '$(Build.ArtifactStagingDirectory)\${{parameters.artifactName}}.zip' - replaceExistingArchive: true - - - task: 1ES.PublishPipelineArtifact@1 - inputs: - targetPath: '$(Build.ArtifactStagingDirectory)' - artifactName: '${{parameters.artifactNameNoVersionString}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml index 28a1960aac27b..5f9dd5677e7bc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-artifacts-package-and-publish-steps-windows.yml @@ -12,7 +12,7 @@ parameters: default: 'RelWithDebInfo' - name: artifactName - displayName: artifactName,like 'onnxruntime-win-x64-1.6.0' + displayName: artifactName, like 'onnxruntime-win-x64-1.6.0' type: string default: 
'' @@ -30,6 +30,11 @@ parameters: type: boolean default: true +- name: publishArtifactStagingDirectory + displayName: Whether to publish the artifact staging directory as an artifact named `artifactNameNoVersionString`. + type: boolean + default: false + steps: - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: - template: publish-symbolrequestprod-api.yml @@ -89,6 +94,7 @@ steps: copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnHtpV81Stub.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSaver.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\QnnSystem.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib + copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\Qualcomm_LICENSE.pdf $(Build.BinariesDirectory)\${{parameters.artifactName}} # copy trt ep libraries only when trt ep is enabled copy $(Build.BinariesDirectory)\${{parameters.buildConfig}}\${{parameters.buildConfig}}\onnxruntime_providers_tensorrt.dll $(Build.BinariesDirectory)\${{parameters.artifactName}}\lib @@ -133,3 +139,9 @@ steps: archiveType: 'zip' # Options: zip, 7z, tar, wim archiveFile: '$(Build.ArtifactStagingDirectory)\${{parameters.artifactName}}.zip' replaceExistingArchive: true + + - ${{ if parameters.publishArtifactStagingDirectory }}: + - task: 1ES.PublishPipelineArtifact@1 + inputs: + targetPath: '$(Build.ArtifactStagingDirectory)' + artifactName: '${{parameters.artifactNameNoVersionString}}' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 5025046a02b0e..448dbafcaaaac 100644 --- 
a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -163,6 +163,20 @@ stages: PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} +- template: win-ci.yml + parameters: + DoEsrp: true + stage_name_suffix: CPU_arm64x_${{ parameters.BuildVariant }} + buildArch: x64 + msbuildPlatform: arm64x + packageName: arm64x + buildparameter: ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} + runTests: false + buildJava: false + buildNodejs: false + PreReleaseVersionSuffixString: ${{ parameters.PreReleaseVersionSuffixString }} + PreReleaseVersionSuffixNumber: ${{ parameters.PreReleaseVersionSuffixNumber }} + - template: win-ci.yml parameters: DoEsrp: true @@ -203,6 +217,10 @@ stages: - input: pipelineArtifact artifactName: drop-onnxruntime-java-linux-aarch64 targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-linux-aarch64' + + - input: pipelineArtifact + artifactName: drop-onnxruntime-java-osx-arm64 + targetPath: '$(Build.BinariesDirectory)\java-artifact\onnxruntime-java-osx-arm64' outputs: - output: pipelineArtifact targetPath: $(Build.BinariesDirectory)\java-artifact\onnxruntime-java-win-x64 diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-gpu.yml new file mode 100644 index 0000000000000..22f3621c89c64 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-gpu.yml @@ -0,0 +1,158 @@ +stages: + - stage: GPU_JAR_Testing + dependsOn: [] + jobs: + - job: Final_Jar_Testing_Windows_GPU + templateContext: + type: validationJob + workspace: + clean: all + pool: + name: 'onnxruntime-Win2022-GPU-A10' + os: windows + timeoutInMinutes: 60 + variables: + - name: runCodesignValidationInjection + value: false + + steps: + - 
template: set-version-number-variables-step.yml + + - template: jobs/download_win_gpu_library.yml + parameters: + CudaVersion: 12.8 + DownloadCUDA: true + DownloadTRT: true + + - template: setup-maven.yml + + - task: Maven@4 + displayName: 'Download Java Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' + mavenVersionOption: 'Default' + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + - script: | + move $(Pipeline.Workspace)\build\onnxruntime-java-gpu\*.jar $(Pipeline.Workspace)\build\onnxruntime-java\ + + - task: PowerShell@2 + displayName: 'Run Java Tests with PowerShell' + inputs: + targetType: 'inline' + script: | + # Exit script on any error + $ErrorActionPreference = "Stop" + + cd $(Pipeline.Workspace)/build/onnxruntime-java + del *.asc + del *.sha256 + del *.sha512 + del *.pom + del *.sha1 + del *.pom + cd .. + mkdir tests + cd tests + jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + del $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar + dir $(Pipeline.Workspace)/build/tests + Write-Host "Running JUnit Tests..." 
+ & java -DUSE_CUDA=1 ` + -cp "$(Pipeline.Workspace)\build\tests;$(Pipeline.Workspace)\build\onnxruntime-java\*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)\build\tests ` + --fail-if-no-tests --disable-banner --reports-dir "$($env:Build_ArtifactStagingDirectory)/TestResults" + + - task: PublishTestResults@2 + displayName: 'Publish Test Results' + inputs: + testResultsFormat: 'JUnit' + testResultsFiles: '$(Build.ArtifactStagingDirectory)/TestResults/TEST-junit-jupiter.xml' + failTaskOnFailedTests: true + + + - job: Final_Jar_Testing_Linux_GPU + templateContext: + type: validationJob + workspace: + clean: all + pool: + name: 'Onnxruntime-Linux-GPU-A10' + os: linux + variables: + - name: runCodesignValidationInjection + value: false + - name: docker_base_image + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 + timeoutInMinutes: 60 + steps: + - checkout: self + submodules: false + + - template: set-version-number-variables-step.yml + + - bash: | + sudo apt-get install -y msopenjdk-17 + dpkg -l msopenjdk-17 + + - bash: | + echo "Downloading and installing Maven $(mavenVersion) for Linux..." + MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + + # Extract to the temp directory + mkdir -p ${MAVEN_DIR} + tar -xzf $(Agent.TempDirectory)/maven.tar.gz -C $(Agent.TempDirectory) + + # Add Maven's bin directory to the PATH for subsequent tasks in the job + echo "##vso[task.prependpath]${MAVEN_DIR}/bin" + displayName: 'Install Maven (Linux)' + + - script: | + echo "Maven is now on the PATH." 
+ mvn --version + + - download: build + artifact: 'onnxruntime-java-gpu' + displayName: 'Download Final Jar' + + # Rename the downloaded folder + - script: | + mv $(Pipeline.Workspace)/build/onnxruntime-java-gpu $(Pipeline.Workspace)/build/onnxruntime-java + + - task: Maven@4 + displayName: 'Download Dependencies' + inputs: + mavenPomFile: '$(Build.SourcesDirectory)/tools/ci_build/java/pom.xml' + goals: 'dependency:copy-dependencies' + options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' + publishJUnitTestResults: false + javaHomeOption: 'Path' + jdkDirectory: '/usr/lib/jvm/msopenjdk-17-amd64' + jdkVersionOption: 'Default' + mavenVersionOption: 'Default' + + # Now all the jars are in the $(Pipeline.Workspace)/build folder + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ variables.docker_base_image }} --build-arg TRT_VERSION=${{ variables.linux_trt_version }}" + Repository: onnxruntimeubi8packagestest + + - bash: | + docker run --network=none --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Pipeline.Workspace)/build:/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) + displayName: 'Test' diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml index 5a25232a90c39..bbb664a2de602 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-linux.yml @@ -4,22 +4,43 @@ parameters: - name: OS displayName: Operating System type: string 
+ values: + - linux + - macOS + - windows - name: PoolName type: string +- name: PoolDemands + type: string + default: '' + stages: - stage: Final_Jar_Testing_${{parameters.OS}} dependsOn: [] jobs: - job: Final_Jar_Testing_${{parameters.OS}} + templateContext: + type: validationJob workspace: clean: all - ${{ if eq(parameters.OS, 'MacOS') }}: + ${{ if eq(parameters.OS, 'macOS') }}: pool: - vmImage: 'macOS-15' - ${{ if eq(parameters.OS, 'Linux') }}: + os: macOS + # Use PoolName if provided, otherwise fallback to macOS-15 + ${{ if ne(parameters.PoolName, '') }}: + ${{ if contains(parameters.PoolName, '-') }}: + vmImage: ${{ parameters.PoolName }} + ${{ else }}: + name: ${{ parameters.PoolName }} + ${{ if ne(parameters.PoolDemands, '') }}: + demands: ${{ parameters.PoolDemands }} + ${{ else }}: + vmImage: 'macOS-15' + ${{ if eq(parameters.OS, 'linux') }}: pool: + os: linux name: ${{ parameters.PoolName }} variables: - name: runCodesignValidationInjection @@ -29,10 +50,15 @@ stages: - template: set-version-number-variables-step.yml - bash: | - echo "Downloading and installing Maven $(mavenVersion) for Linux..." + echo "Downloading and installing Maven $(mavenVersion)..." 
MAVEN_DIR="$(Agent.TempDirectory)/apache-maven-$(mavenVersion)" + # Download Maven binary - wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + if command -v wget &> /dev/null; then + wget https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz -O $(Agent.TempDirectory)/maven.tar.gz + else + curl -L -o $(Agent.TempDirectory)/maven.tar.gz https://archive.apache.org/dist/maven/maven-3/$(mavenVersion)/binaries/apache-maven-$(mavenVersion)-bin.tar.gz + fi # Extract to the temp directory mkdir -p ${MAVEN_DIR} @@ -40,13 +66,25 @@ stages: # Add Maven's bin directory to the PATH for subsequent tasks in the job echo "##vso[task.prependpath]${MAVEN_DIR}/bin" - displayName: 'Install Maven (Linux)' - condition: and(succeeded(), eq(variables['Agent.OS'], 'Linux')) + displayName: 'Install Maven' + condition: and(succeeded(), in(variables['Agent.OS'], 'Linux', 'Darwin')) - script: | echo "Maven is now on the PATH." mvn --version + - script: | + set -e -x + if ! 
/usr/libexec/java_home -v 17 >/dev/null 2>&1; then + brew install --cask temurin@17 + fi + JAVA_HOME=$(/usr/libexec/java_home -v 17) + echo "JAVA_HOME is set to: $JAVA_HOME" + echo "##vso[task.setvariable variable=JAVA_HOME]$JAVA_HOME" + echo "##vso[task.prependpath]$JAVA_HOME/bin" + displayName: 'Install JDK 17 (macOS)' + condition: and(succeeded(), eq(variables['Agent.OS'], 'Darwin')) + - download: build artifact: 'onnxruntime-java' displayName: 'Download Final Jar' @@ -58,12 +96,16 @@ stages: goals: 'dependency:copy-dependencies' options: '-DoutputDirectory=$(Pipeline.Workspace)/build/onnxruntime-java' publishJUnitTestResults: false - javaHomeOption: 'JDKVersion' - jdkVersionOption: '1.17' mavenVersionOption: 'Default' + ${{ if eq(parameters.OS, 'macOS') }}: + javaHomeOption: 'Path' + jdkDirectory: '$(JAVA_HOME)' + ${{ if eq(parameters.OS, 'linux') }}: + javaHomeOption: 'JDKVersion' + jdkVersionOption: '1.17' - task: Bash@3 - displayName: 'Run Java Tests on Linux/macOS' + displayName: 'Run Java Tests' condition: and(succeeded(), in(variables['Agent.OS'], 'Linux', 'Darwin')) inputs: targetType: 'inline' @@ -80,24 +122,54 @@ stages: cd .. mkdir tests cd tests + # 1. Diagnostics + echo "System Info:" + uname -a + if [[ "$(uname)" == "Darwin" ]]; then arch; fi + echo "Java Version" + java -version + + # 2. Extract jar xf $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar rm -f $(Pipeline.Workspace)/build/onnxruntime-java/testing.jar - ls $(Pipeline.Workspace)/build/tests + + # Identify main jar (avoiding sources and javadoc jars) + MAIN_JAR=$(ls $(Pipeline.Workspace)/build/onnxruntime-java/onnxruntime-*.jar | grep -v 'sources' | grep -v 'javadoc' | head -n 1) + echo "Extracting native libs from $MAIN_JAR" + jar xf $MAIN_JAR ai/onnxruntime/native + + ls -R $(Pipeline.Workspace)/build/tests/ai echo "Java Version" java -version - - # Set the correct library path based on the OS + + + # 3. 
Find with robustness os_name=$(uname) - if [[ "$os_name" == "Linux" ]]; then - echo "Platform: Linux. Setting LD_LIBRARY_PATH." - export LD_LIBRARY_PATH="$(pwd):$LD_LIBRARY_PATH" - java -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ - --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" - elif [[ "$os_name" == "Darwin" ]]; then - echo "Platform: macOS. Setting DYLD_LIBRARY_PATH." - export DYLD_LIBRARY_PATH="$(pwd):$DYLD_LIBRARY_PATH" - java -DUSE_WEBGPU=1 -DUSE_COREML=1 -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ - --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + if [[ "$os_name" == "Linux" ]]; then S_FILE="libonnxruntime.so"; else S_FILE="libonnxruntime.dylib"; fi + + echo "Searching for $S_FILE in $(pwd)..." + # Exclude .dSYM paths and find actual file + NATIVE_LIB_PATH=$(find $(pwd) -name "$S_FILE" -not -path "*.dSYM*" -type f | head -n 1) + + if [[ -n "$NATIVE_LIB_PATH" ]]; then + NATIVE_LIB_DIR=$(dirname "$NATIVE_LIB_PATH") + echo "Found native lib dir: $NATIVE_LIB_DIR" + + if [[ "$os_name" == "Linux" ]]; then + echo "Platform: Linux. Setting LD_LIBRARY_PATH." + export LD_LIBRARY_PATH="$NATIVE_LIB_DIR:$(pwd):$LD_LIBRARY_PATH" + java -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + elif [[ "$os_name" == "Darwin" ]]; then + echo "Platform: macOS. Setting DYLD_LIBRARY_PATH." 
+ export DYLD_LIBRARY_PATH="$NATIVE_LIB_DIR:$(pwd):$DYLD_LIBRARY_PATH" + java -DUSE_WEBGPU=1 -DUSE_COREML=1 -cp '$(Pipeline.Workspace)/build/tests:$(Pipeline.Workspace)/build/onnxruntime-java/*' org.junit.platform.console.ConsoleLauncher --scan-classpath=$(Pipeline.Workspace)/build/tests \ + --fail-if-no-tests --disable-banner --reports-dir "$(Build.ArtifactStagingDirectory)/TestResults" + fi + else + echo "Error: $S_FILE not found!" + ls -R ai + exit 1 fi diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml index de07e9e89dc81..73fe7e9797295 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing-win.yml @@ -11,6 +11,8 @@ stages: clean: all pool: name: ${{ parameters.PoolName }} + templateContext: + type: validationJob variables: - name: runCodesignValidationInjection value: false diff --git a/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml new file mode 100644 index 0000000000000..0ad230f835778 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/foundry-local-nuget-packaging.yml @@ -0,0 +1,149 @@ +parameters: + DoEsrp: false + StageName: 'FoundryLocalNugetPackaging' + DependsOn: [] + PackageName: 'Microsoft.ML.OnnxRuntime.Foundry' + +stages: +- stage: ${{ parameters.StageName }} + dependsOn: ${{ parameters.DependsOn }} + jobs: + - job: ${{ parameters.StageName }} + timeoutInMinutes: 120 + pool: + name: 'onnxruntime-Win2022-GPU-A10' + os: windows + templateContext: + sdl: + codeSignValidation: + enabled: true + break: true + psscriptanalyzer: + enabled: true + binskim: + enabled: true + scanOutputDirectoryOnly: true + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: 
"onnxruntime-foundry-nuget" + variables: + DoEsrp: ${{ parameters.DoEsrp }} + ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] + BuildDate: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Date.BuildDate']] + BuildTime: $[stageDependencies.Setup.Set_Variables.outputs['Set_Build_Time.BuildTime']] + + steps: + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - managed nuget' + inputs: + artifactName: 'onnxruntime-managed-nuget' + targetPath: '$(Build.BinariesDirectory)/managed-nuget' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-x64' + inputs: + artifactName: 'onnxruntime-win-x64-cuda' + targetPath: '$(Build.BinariesDirectory)/win-x64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - win-arm64' + inputs: + artifactName: 'onnxruntime-win-arm64' + targetPath: '$(Build.BinariesDirectory)/win-arm64' + + - task: DownloadPipelineArtifact@0 + displayName: 'Download Pipeline Artifact - osx' + inputs: + artifactName: 'onnxruntime-osx' + targetPath: '$(Build.BinariesDirectory)/osx' + + - task: UsePythonVersion@0 + inputs: + versionSpec: '3.12' + addToPath: true + + - task: PipAuthenticate@1 + displayName: 'Pip Authenticate' + inputs: + artifactFeeds: 'Lotus' + + - task: PowerShell@2 + displayName: 'Create osx directories' + inputs: + targetType: 'inline' + script: | + New-Item -ItemType Directory -Force -Path "$(Build.BinariesDirectory)/osx-arm64" | Out-Null + Move-Item -Path $(Build.BinariesDirectory)/osx/onnxruntime-osx-arm64* -Destination $(Build.BinariesDirectory)/osx-arm64 + + - task: PowerShell@2 + displayName: 'List all files downloaded' + inputs: + targetType: 'inline' + script: | + $files = Get-ChildItem $(Build.BinariesDirectory) -Recurse + foreach ($file in $files) { + Write-Host "File: $($file.FullName)" + if ($file -like "*onnxruntime*") { + Write-Host "File onnxruntime: 
$($file.FullName) - Size: $($file.Length)" + } + } + $dirs = Get-ChildItem $(Build.BinariesDirectory) -Directory + foreach ($dir in $dirs) { + Write-Host "Directory: $($dir.FullName)" + } + $osx_arm64_archive = Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64* + if ($osx_arm64_archive.Count -eq 0) { + Write-Host "No osx-arm64 archive found." + } else { + Write-Host "osx-arm64 archive found: $($osx_arm64_archive[0].FullName)" + } + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Extract Nuget Package Version' + inputs: + targetType: 'inline' + script: | + $nupkgs = (Get-ChildItem $(Build.BinariesDirectory)/managed-nuget -Filter Microsoft.ML.OnnxRuntime.Managed.*.nupkg -Recurse) + $package_name = $nupkgs[0].Name + $version_length = $package_name.Length - "Microsoft.ML.OnnxRuntime.Managed.".Length - ".nupkg".Length + $package_version = $package_name.Substring("Microsoft.ML.OnnxRuntime.Managed.".Length, $version_length) + Write-Host "##vso[task.setvariable variable=package_version;]$package_version" + workingDirectory: $(Build.BinariesDirectory) + + - task: PowerShell@2 + displayName: 'Extract Archives' + inputs: + targetType: 'inline' + script: | + Expand-Archive -Path $(Build.BinariesDirectory)/win-x64/onnxruntime-win-x64-cuda*.zip -DestinationPath $(Build.BinariesDirectory)/win-x64 + Expand-Archive -Path $(Build.BinariesDirectory)/win-arm64/onnxruntime-win-arm64*.zip -DestinationPath $(Build.BinariesDirectory)/win-arm64 + $osx_arm64_archive = (Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Filter onnxruntime-osx-arm64*)[0].FullName + tar -xzf $osx_arm64_archive -C $(Build.BinariesDirectory)/osx-arm64 2>$null + $win_x64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-x64 -Directory -Filter onnxruntime-win-x64-cuda*)[0].FullName + $win_arm64 = (Get-ChildItem -Path $(Build.BinariesDirectory)/win-arm64 -Directory -Filter onnxruntime-win-arm64*)[0].FullName + $osx_arm64 = 
(Get-ChildItem -Path $(Build.BinariesDirectory)/osx-arm64 -Directory -Filter onnxruntime-osx-arm64*)[0].FullName + Write-Host "##vso[task.setvariable variable=win_x64;]$win_x64" + Write-Host "##vso[task.setvariable variable=win_arm64;]$win_arm64" + Write-Host "##vso[task.setvariable variable=osx_arm64;]$osx_arm64" + workingDirectory: $(Build.BinariesDirectory) + + - task: PythonScript@0 + displayName: 'Generate Nuget Package' + inputs: + scriptPath: '$(Build.SourcesDirectory)/tools/nuget/generate_nuspec_for_custom_nuget.py' + arguments: '--nuspec_path "$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec" --root_dir "$(Build.SourcesDirectory)" --commit_id "$(Build.SourceVersion)" --win_arm64 "$(win_arm64)" --win_x64 "$(win_x64)" --osx_arm64 "$(osx_arm64)" --package_version "$(package_version)" --package_name "${{ parameters.PackageName }}"' + + - task: NuGetCommand@2 + displayName: 'Pack Nuget Package' + inputs: + command: 'pack' + packagesToPack: '$(Build.BinariesDirectory)/${{ parameters.PackageName }}.nuspec' + packDestination: $(Build.ArtifactStagingDirectory)\ + + - template: esrp_nuget.yml + parameters: + DisplayName: 'ESRP - sign NuGet package' + FolderPath: '$(Build.ArtifactStagingDirectory)' + DoEsrp: ${{ parameters.DoEsrp }} diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 6cb16313ef309..ef30a49b0fb83 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -96,15 +96,15 @@ jobs: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.21 ccache-git-emscripten-64bit - ./emsdk activate 4.0.21 ccache-git-emscripten-64bit + ./emsdk install 4.0.23 ccache-git-emscripten-64bit + ./emsdk activate 4.0.23 ccache-git-emscripten-64bit displayName: 'emsdk install and activate ccache for emscripten' - ${{if 
eq(parameters.WithCache, false)}}: - script: | set -ex cd '$(Build.SourcesDirectory)/cmake/external/emsdk' - ./emsdk install 4.0.21 - ./emsdk activate 4.0.21 + ./emsdk install 4.0.23 + ./emsdk activate 4.0.23 displayName: 'emsdk install and activate ccache for emscripten' - template: build-linux-wasm-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml index 8e454f2137ce8..795945a8581ba 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packaging-steps.yml @@ -26,6 +26,15 @@ steps: args: '-r $(Build.BinariesDirectory) -a onnxruntime-osx-${{ parameters.MacosArch }}-$(OnnxRuntimeVersion) -l libonnxruntime.$(OnnxRuntimeVersion).dylib -c Release -s $(Build.SourcesDirectory) -t $(Build.SourceVersion)' workingDirectory: '$(Build.BinariesDirectory)/Release' +- bash: | + mkdir -p $(Build.BinariesDirectory)/onnxruntime-osx-${{ parameters.MacosArch }}-$(OnnxRuntimeVersion)/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.BinariesDirectory)/onnxruntime-osx-${{ parameters.MacosArch }}-$(OnnxRuntimeVersion)/testdata/libcustom_op_library.dylib + # Copy to testdata/testdata so EndToEndTests can find it when running in Debug configuration + mkdir -p $(Build.BinariesDirectory)/testdata/testdata + cp $(Build.BinariesDirectory)/Release/libcustom_op_library.dylib $(Build.BinariesDirectory)/testdata/testdata/libcustom_op_library.dylib + displayName: 'Copy custom op library' + condition: succeeded() + - task: ArchiveFiles@2 inputs: rootFolderOrFile: '$(Build.BinariesDirectory)/onnxruntime-osx-${{ parameters.MacosArch }}-$(OnnxRuntimeVersion)' @@ -40,6 +49,14 @@ steps: targetPath: '$(Build.ArtifactStagingDirectory)' artifactName: 'onnxruntime-osx-${{ parameters.MacosArch }}' +- template: 
java-api-artifacts-package-and-publish-steps-posix.yml + parameters: + arch: 'osx-${{ parameters.MacosArch }}' + buildConfig: 'Release' + artifactName: 'onnxruntime-java-osx-${{ parameters.MacosArch }}' + libraryName: 'libonnxruntime.dylib' + nativeLibraryName: 'libonnxruntime4j_jni.dylib' + - template: nodejs-artifacts-package-and-publish-steps-posix.yml parameters: arch: arm64 diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index bfccaef1c9852..de16ce483a9f4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -45,9 +45,20 @@ jobs: set -e -x export ONNX_ML=1 export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=ON -DONNX_WERROR=OFF" - python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' + python3 -m pip install -r '$(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/requirements.txt' + + - script: | + set -e -x + if ! 
/usr/libexec/java_home -v 17 >/dev/null 2>&1; then + brew install --cask temurin@17 + fi + JAVA_HOME=$(/usr/libexec/java_home -v 17) + echo "JAVA_HOME is set to: $JAVA_HOME" + echo "##vso[task.setvariable variable=JAVA_HOME]$JAVA_HOME" + echo "##vso[task.prependpath]$JAVA_HOME/bin" + displayName: 'Install JDK 17' - template: mac-cpu-packaging-steps.yml parameters: MacosArch: arm64 - AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_nodejs --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 + AdditionalBuildFlags: ${{ parameters.AdditionalBuildFlags }} --build_java --build_nodejs --use_coreml --use_webgpu --cmake_extra_defines CMAKE_OSX_ARCHITECTURES=arm64 diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml index 9f0230c4b1141..981989f519ae4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-symbolrequestprod-api.yml @@ -52,7 +52,6 @@ steps: inputs: azureSubscription: ${{ parameters.subscription }} azurePowerShellVersion: LatestVersion - pwsh: true ScriptType: InlineScript Inline: | # Part 1: Generate an Azure Token @@ -69,7 +68,7 @@ steps: # Convert the SecureString token to a plain text string for the HTTP header # This is done just-in-time before its use. - $plainTextToken = $secureTokenObject | ConvertFrom-SecureString -AsPlainText + $plainTextToken = [System.Net.NetworkCredential]::new("", $secureTokenObject).Password Write-Host "Token converted to plain text for API call (will not be logged)." 
# Part 2: Publish Symbols using internal REST API diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index 9b15f389e5349..4ec074055fcc2 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -54,7 +54,7 @@ jobs: FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install coloredlogs flatbuffers numpy packaging protobuf sympy + python3 -m pip install flatbuffers numpy packaging protobuf sympy python3 -m pip install --no-index --find-links . $PYTHON_PACKAGE_NAME python3 -m pip show $PYTHON_PACKAGE_NAME python3 -c "import onnxruntime as ort; print(ort.__version__)" diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 3ea75701accb1..50229d91fbf03 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: PYTHON_VERSION type: string default: '3.11' - + - name: QNN_SDK displayName: QNN SDK Version type: string @@ -32,6 +32,8 @@ jobs: name: ${{ parameters.MACHINE_POOL }} os: windows hostArchitecture: Arm64 + demands: + - Agent.Version -equals 4.264.2 templateContext: sdl: codeSignValidation: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-cpu.yml new file mode 100644 index 0000000000000..326cfd7829f2f --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-cpu.yml @@ -0,0 +1,166 @@ +parameters: +- name: architecture + type: string + default: 'x64' + values: + - x64 + - arm64 + +- name: build_py_parameters + displayName: 
'Specify extra build parameters' + type: string + default: '--use_azure' + +- name: cmake_build_type + type: string + displayName: 'CMake build type for Windows. Only for Windows CPU packages.' + default: 'RelWithDebInfo' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +jobs: +- job: Windows_py_Wheels_${{parameters.architecture}} + ${{ if eq(parameters.architecture, 'arm64') }}: + pool: + name: 'onnxruntime-qnn-windows-vs-2022-arm64' + os: windows + hostArchitecture: Arm64 + demands: + - Agent.Version -equals 4.264.2 + ${{ else }}: + pool: + name: 'onnxruntime-Win-CPU-VS2022-Latest' + os: windows + templateContext: + sdl: + codeSignValidation: + enabled: true + # TODO: check why pyd file was not signed + break: false + additionalTargetsGlobPattern: f|**\*.pyd + psscriptanalyzer: + enabled: true + binskim: + enabled: true + scanOutputDirectoryOnly: true + ${{ if eq(parameters.architecture, 'arm64') }}: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: onnxruntime-win-$(PythonVersion)-arm64 + ${{ else }}: + outputs: + - output: pipelineArtifact + targetPath: $(Build.ArtifactStagingDirectory) + artifactName: onnxruntime-win-$(PythonVersion) + strategy: + matrix: + Python311_${{parameters.architecture}}: + PythonVersion: '3.11' + Python312_${{parameters.architecture}}: + PythonVersion: '3.12' + Python313_${{parameters.architecture}}: + PythonVersion: '3.13' + Python314_${{parameters.architecture}}: + PythonVersion: '3.14' + variables: + OnnxRuntimeBuildDirectory: '$(Build.BinariesDirectory)' + ExtraParam: ${{ parameters.build_py_parameters }} + timeoutInMinutes: 180 + workspace: + clean: all + + steps: + - checkout: self + clean: true + submodules: recursive + + - template: setup-build-tools.yml + parameters: + host_cpu_arch: ${{parameters.architecture}} + python_version: $(PythonVersion) + + - template: set-nightly-build-option-variable-step.yml + + - script: python -m pip install -r 
$(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt + env: + TMPDIR: "$(Agent.TempDirectory)" + + - task: PythonScript@0 + displayName: 'Build' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config ${{ parameters.cmake_build_type }} + --enable_lto + --build_dir $(Build.SourcesDirectory)\build + --skip_submodule_sync + --cmake_generator "Visual Studio 17 2022" + --enable_pybind + --enable_onnx_tests --use_vcpkg --use_vcpkg_ms_internal_asset_cache --build + ${{ parameters.build_py_parameters }} + --parallel --use_binskim_compliant_compile_flags --update + $(TelemetryOption) + + - ${{if or(eq(variables['Build.SourceBranch'], 'refs/heads/main'), startsWith(variables['Build.SourceBranch'], 'refs/heads/rel-'))}}: + - template: publish-symbolrequestprod-api.yml + parameters: + ${{if eq(variables['Build.SourceBranch'], 'refs/heads/main')}}: + symbolExpiryTime: 60 + includePublicSymbolServer: true + symbolsArtifactName: onnxruntime_cpu_win_${{ parameters.architecture }}_$(PythonVersion) + symbolsVersion: $(Build.BuildId) + symbolProject: 'ONNX Runtime' + subscription: 'OnnxrunTimeCodeSign_20240611' + searchPattern: | + $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime.pdb + $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_providers_shared.pdb + $(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime_pybind11_state.pdb + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + 
arguments: 'bdist_wheel ${{ parameters.build_py_parameters }} $(NightlyBuildOption)' + workingDirectory: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + + - powershell: | + python -m pip uninstall -y onnxruntime onnxruntime-gpu -qq + Get-ChildItem -Path $(Build.ArtifactStagingDirectory)/*.whl | foreach {pip --disable-pip-version-check install --upgrade $_.fullname tabulate} + Remove-Item -Recurse -Force onnxruntime + if ("$(ExtraParam)".Split() -contains "--use_azure") { + + if( "${{parameters.architecture}}" -eq 'arm64') { + $env:path="$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\_deps\vcpkg-src\installed\arm64-windows\bin;$env:path" + } else { + $env:path="$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\_deps\vcpkg-src\installed\x64-windows\bin;$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\_deps\vcpkg-src\installed\x86-windows\bin;$env:path" + } + python onnxruntime_test_python_azure.py + } + python onnx_backend_test_series.py + workingDirectory: '$(Build.SourcesDirectory)\build\${{ parameters.cmake_build_type }}\${{ parameters.cmake_build_type }}' + displayName: 'Run Python Tests' diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index 8a1c4f8a39316..7e176b67f6685 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -48,7 +48,7 @@ stages: 
variables: OrtPackageId: ${{ parameters.OrtNugetPackageId }} ReleaseVersionSuffix: $[stageDependencies.Setup.Set_Variables.outputs['Set_Release_Version_Suffix.ReleaseVersionSuffix']] - commonBuildArgs: '--skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ${{ parameters.AdditionalBuildArgs}}' + commonBuildArgs: '--skip_submodule_sync --build_shared_lib --client_package_build --cmake_generator "Visual Studio 17 2022" --config ${{ parameters.build_config }} --parallel --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags ${{ parameters.AdditionalBuildArgs}}' steps: - template: set-version-number-variables-step.yml @@ -61,17 +61,10 @@ stages: parameters: QnnSDKVersion: ${{ parameters.QnnSdk }} - - task: PythonScript@0 - displayName: 'Build arm64x project - generate the def & lib file for next build' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: ' --arm64 --buildasx --build_dir $(Build.BinariesDirectory)\arm64x --use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' - - - task: PythonScript@0 - displayName: 'Build arm64ecx project - the real arm64x' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: ' --arm64ec --buildasx --build_dir $(Build.BinariesDirectory) --use_qnn --qnn_home $(QnnSDKRootDir) $(commonBuildArgs)' + - template: build-win-arm64x-steps.yml + parameters: + buildDirectory: '$(Build.BinariesDirectory)' + additionalBuildPyArgs: '$(commonBuildArgs) --use_qnn --qnn_home $(QnnSDKRootDir)' - task: CmdLine@2 displayName: 'Print contents of binaries directory' @@ -87,12 +80,13 @@ stages: Pattern: 'onnxruntime*.dll' - ${{ if eq(parameters.PublishArchive, true) }}: - - template: 
c-api-artifacts-package-and-publish-steps-windows-qnn.yml + - template: c-api-artifacts-package-and-publish-steps-windows.yml parameters: buildConfig: ${{ parameters.build_config }} artifactName: 'onnxruntime-win-arm64x-qnn' artifactNameNoVersionString: 'onnxruntime-win-arm64x-qnn' DoEsrp: ${{ parameters.DoEsrp }} + publishArtifactStagingDirectory: true - task: MSBuild@1 displayName: 'Restore NuGet Packages and create project.assets.json' diff --git a/tools/ci_build/github/azure-pipelines/templates/set-variable.yml b/tools/ci_build/github/azure-pipelines/templates/set-variable.yml new file mode 100644 index 0000000000000..2cf49f2f067c2 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/set-variable.yml @@ -0,0 +1,32 @@ +# Sets an ADO pipeline variable. +# See https://learn.microsoft.com/en-us/azure/devops/pipelines/process/set-variables-scripts + +parameters: +- name: name + type: string + +- name: value + type: string + +steps: +- task: PythonScript@0 + displayName: 'Set variable - ${{ parameters.name }}' + inputs: + scriptSource: inline + script: | + import os + + variable_name = os.getenv("VARIABLE_NAME") + variable_value = os.getenv("VARIABLE_VALUE") + + if not variable_name.isidentifier(): + raise ValueError(f"Variable name is not a valid identifier: '{variable_name}'") + + if "\n" in variable_value: + raise ValueError(f"Variable value should not contain any newlines: '{variable_value}'") + + print(f"Setting variable: {variable_name} = '{variable_value}'") + print(f"##vso[task.setvariable variable={variable_name}]{variable_value}") + env: + VARIABLE_NAME: ${{ parameters.name }} + VARIABLE_VALUE: ${{ parameters.value }} diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 7d6e272533696..8303547a47566 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml 
+++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -22,7 +22,7 @@ stages: buildSettingsFile: "tools/ci_build/github/apple/default_full_apple_framework_build_settings.json" cPodName: onnxruntime-c objcPodName: onnxruntime-objc - timeoutInMinutes: 270 + timeoutInMinutes: 360 templateContext: outputs: - output: pipelineArtifact diff --git a/tools/ci_build/github/azure-pipelines/templates/test-binary-archive-stage.yml b/tools/ci_build/github/azure-pipelines/templates/test-binary-archive-stage.yml new file mode 100644 index 0000000000000..b9b9cdc6b0eb3 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/test-binary-archive-stage.yml @@ -0,0 +1,121 @@ +# Tests an ONNX Runtime binary archive produced by the packaging pipeline. + +parameters: +- name: artifactName + type: string +- name: artifactPipelineResource + type: string +- name: previousStageName + type: string + default: '' +- name: platform + type: string +- name: agentPool + type: object +- name: agentSetupSteps + type: stepList + default: [] + +stages: +- stage: Binary_Archive_Testing_${{ replace(parameters.platform, '-', '_') }} + ${{ if ne(parameters.previousStageName, '') }}: + dependsOn: ${{ parameters.previousStageName }} + + jobs: + - job: Binary_Archive_Testing_${{ replace(parameters.platform, '-', '_') }} + pool: ${{ parameters.agentPool }} + + variables: + - name: buildConfig + value: Release + - name: relativePathFromBuildToOutputDir + ${{ if startsWith(parameters.platform, 'win') }}: + value: "${{ variables['buildConfig'] }}" + ${{ else }}: + value: "." + + steps: + - checkout: self + clean: true + submodules: none + + - ${{ each agentSetupStep in parameters.agentSetupSteps }}: + - ${{ agentSetupStep }} + + - download: ${{ parameters.artifactPipelineResource }} + artifact: ${{ parameters.artifactName }} + patterns: | + *.zip + *.tgz + displayName: Download binary archive for ${{ parameters.platform }} + + # Extract the binary archive. 
+ # The archive contains a top-level directory like onnxruntime--/. + # After extraction, set ORT_PACKAGE_DIR to the extracted directory. + - ${{ if startsWith(parameters.platform, 'win') }}: + - task: PowerShell@2 + displayName: 'Extract binary archive' + inputs: + targetType: 'inline' + script: | + $artifactDir = "$(Pipeline.Workspace)/${{ parameters.artifactPipelineResource }}/${{ parameters.artifactName }}" + $archive = (Get-ChildItem -Path $artifactDir -Filter *.zip)[0].FullName + Write-Host "Extracting $archive" + Expand-Archive -Path $archive -DestinationPath $(Build.BinariesDirectory) + $extractedDir = (Get-ChildItem -Path $(Build.BinariesDirectory) -Directory | Where-Object { $_.Name -like "onnxruntime-*" })[0].FullName + Write-Host "Extracted to $extractedDir" + Write-Host "##vso[task.setvariable variable=ORT_PACKAGE_DIR]$extractedDir" + + - ${{ else }}: + - bash: | + set -ex + artifact_dir="$(Pipeline.Workspace)/${{ parameters.artifactPipelineResource }}/${{ parameters.artifactName }}" + archive=$(find "$artifact_dir" -name '*.tgz' | head -1) + echo "Extracting $archive" + tar -xzf "$archive" -C $(Build.BinariesDirectory) + extracted_dir=$(find $(Build.BinariesDirectory) -maxdepth 1 -type d -name 'onnxruntime-*' | head -1) + echo "Extracted to $extracted_dir" + + # Do not output ##vso[] commands with `set -x` or they may be parsed again and include a trailing quote. + set +x + echo "##vso[task.setvariable variable=ORT_PACKAGE_DIR]$extracted_dir" + displayName: 'Extract binary archive' + + # Build and run the C++ sample using the extracted ONNX Runtime package. 
+ + - script: > + cmake + -S $(Build.SourcesDirectory)/samples/cxx + -B $(Build.BinariesDirectory)/sample_build + -DORT_HEADER_DIR:PATH=$(ORT_PACKAGE_DIR)/include + -DORT_LIBRARY_DIR:PATH=$(ORT_PACKAGE_DIR)/lib + displayName: 'Generate C++ sample build system' + + - script: | + cmake --build $(Build.BinariesDirectory)/sample_build --config $(buildConfig) + displayName: 'Build C++ sample' + + - script: > + $(Build.BinariesDirectory)/sample_build/$(relativePathFromBuildToOutputDir)/onnxruntime_sample_program + $(Build.SourcesDirectory)/samples/cxx/add_model.onnx + displayName: 'Run C++ sample' + + # For win-arm64x, also build and run for ARM64EC. + - ${{ if eq(parameters.platform, 'win-arm64x') }}: + - script: > + cmake + -S $(Build.SourcesDirectory)/samples/cxx + -B $(Build.BinariesDirectory)/sample_build_arm64ec + -DORT_HEADER_DIR:PATH=$(ORT_PACKAGE_DIR)/include + -DORT_LIBRARY_DIR:PATH=$(ORT_PACKAGE_DIR)/lib + -A ARM64EC + displayName: 'Generate C++ sample build system (ARM64EC)' + + - script: | + cmake --build $(Build.BinariesDirectory)/sample_build_arm64ec --config $(buildConfig) + displayName: 'Build C++ sample (ARM64EC)' + + - script: > + $(Build.BinariesDirectory)/sample_build_arm64ec/$(relativePathFromBuildToOutputDir)/onnxruntime_sample_program + $(Build.SourcesDirectory)/samples/cxx/add_model.onnx + displayName: 'Run C++ sample (ARM64EC)' diff --git a/tools/ci_build/github/azure-pipelines/templates/validate-package.yml b/tools/ci_build/github/azure-pipelines/templates/validate-package.yml index 529cca4586ef6..950a5a6a34f4d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/validate-package.yml +++ b/tools/ci_build/github/azure-pipelines/templates/validate-package.yml @@ -4,6 +4,7 @@ parameters: PackageType: '' PackageName: '' PackagePath: '' + IsReleaseBuild: false ScriptPath: '$(Build.SourcesDirectory)/tools/nuget/validate_package.py' workingDirectory: "$(Build.BinariesDirectory)" @@ -17,5 +18,5 @@ steps: displayName: 'Validate Package' inputs: 
scriptPath: '${{parameters.ScriptPath}}' - arguments: '--package_type ${{parameters.PackageType}} --package_name ${{parameters.PackageName}} --package_path ${{parameters.PackagePath}} --platforms_supported ${{parameters.PlatformsSupported}} --verify_nuget_signing ${{parameters.VerifyNugetSigning}}' + arguments: '--package_type ${{parameters.PackageType}} --package_name ${{parameters.PackageName}} --package_path ${{parameters.PackagePath}} --platforms_supported ${{parameters.PlatformsSupported}} --verify_nuget_signing ${{parameters.VerifyNugetSigning}} --is_release_build ${{parameters.IsReleaseBuild}}' workingDirectory: ${{parameters.workingDirectory}} diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index cfb752ddc2b58..8a5584c111525 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -177,12 +177,40 @@ stages: - script: python -m pip install -r $(Build.SourcesDirectory)\tools\ci_build\github\windows\python\requirements.txt - - task: PythonScript@0 - displayName: 'Generate cmake config' - inputs: - scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' - arguments: '--parallel 16 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --config RelWithDebInfo --use_binskim_compliant_compile_flags --enable_lto --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --build --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} $(timeoutParameter) $(buildJavaParameter)' - workingDirectory: '$(Build.BinariesDirectory)' + - template: set-variable.yml + parameters: + name: commonBuildPyArgs + value: >- + --config RelWithDebInfo + --parallel + --use_vcpkg + --use_vcpkg_ms_internal_asset_cache + --use_binskim_compliant_compile_flags + --enable_lto + --disable_rtti + --skip_submodule_sync + 
--build_shared_lib + --update --build + --cmake_generator "$(VSGenerator)" + --enable_onnx_tests + $(TelemetryOption) + ${{ parameters.buildparameter }} + $(timeoutParameter) + $(buildJavaParameter) + + - ${{ if eq(parameters.msbuildPlatform, 'arm64x') }}: + - template: build-win-arm64x-steps.yml + parameters: + buildDirectory: '$(Build.BinariesDirectory)' + additionalBuildPyArgs: '$(commonBuildPyArgs)' + + - ${{ else }}: + - task: PythonScript@0 + displayName: 'Generate build system and build' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: '$(commonBuildPyArgs) --build_dir $(Build.BinariesDirectory)' + workingDirectory: '$(Build.BinariesDirectory)' # For CPU job, tests are run in the same machine as building - ${{ if eq(parameters.buildJava, 'true') }}: diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index 51581814a4b81..f5b5c4cbdb3e4 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -37,7 +37,10 @@ parameters: jobs: - job: 'BUILD_QNN_EP' - pool: 'onnxruntime-qnn-windows-vs-2022-arm64' + pool: + name: 'onnxruntime-qnn-windows-vs-2022-arm64' + demands: + - Agent.Version -equals 4.264.2 variables: DOTNET_SKIP_FIRST_TIME_EXPERIENCE: true buildArch: arm64 @@ -83,7 +86,7 @@ jobs: --config $(BuildConfig) --build_dir $(Build.BinariesDirectory) --cmake_generator "Visual Studio 17 2022" - --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache + --build_shared_lib --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_qnn $(QnnLibKind) --qnn_home $(QnnSDKRootDir) --update --build --parallel $(ExtraQnnBuildArgs) diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index f5b4c38c85d4c..88eff3ebff86a 100755 --- 
a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -27,6 +27,17 @@ if [[ $LIB_NAME == *.dylib ]] then dsymutil $BINARY_DIR/$ARTIFACT_NAME/lib/$LIB_NAME -o $BINARY_DIR/$ARTIFACT_NAME/lib/$LIB_NAME.dSYM strip -S $BINARY_DIR/$ARTIFACT_NAME/lib/$LIB_NAME + + # ORT NuGet packaging expects the unversioned library (libonnxruntime.dylib) to contain the binary content, + # because the versioned library is excluded by the nuspec generation script. + # We explicitly overwrite the symlink with the real file to ensure 'nuget pack' (especially on Windows) + # doesn't pack an empty/broken symlink. + # Only applies to versioned libonnxruntime libraries (e.g. libonnxruntime.1.24.0.dylib). + if [[ "$LIB_NAME" =~ ^libonnxruntime\..*\.dylib$ && -L "$BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.dylib" ]]; then + rm "$BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.dylib" + cp "$BINARY_DIR/$ARTIFACT_NAME/lib/$LIB_NAME" "$BINARY_DIR/$ARTIFACT_NAME/lib/libonnxruntime.dylib" + fi + # copy the CoreML EP header for macOS build (libs with .dylib ext) cp $SOURCE_DIR/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include else diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2404_gpu b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2404_gpu index 766a2c8a8b73b..0c63b7775256a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2404_gpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_2404_gpu @@ -49,7 +49,9 @@ RUN apt-get update && \ libnvonnxparsers-dev=${TRT_VERSION} \ libnvonnxparsers10=${TRT_VERSION} \ tensorrt-dev=${TRT_VERSION} \ - libnvinfer-bin=${TRT_VERSION} && \ + libnvinfer-bin=${TRT_VERSION} \ + libnvinfer-headers-python-plugin-dev=${TRT_VERSION} \ + libnvinfer-win-builder-resource10=${TRT_VERSION} && \ rm -rf /var/lib/apt/lists/* COPY scripts /tmp/scripts diff --git 
a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt index 42bee7a892b11..7e2b6e74cfdde 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/scripts/requirements.txt @@ -1,5 +1,5 @@ -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" mypy pytest setuptools>=68.2.2 @@ -7,4 +7,4 @@ wheel protobuf==4.25.8 sympy==1.14 flatbuffers -onnx==1.20.1; python_version < "3.14" +onnx==1.20.1 diff --git a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt index c5fc16837e093..63a8e96d8c128 100644 --- a/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/lort/requirements.txt @@ -3,13 +3,13 @@ beartype==0.15.0 flatbuffers cerberus h5py -onnx==1.20.1; python_version < "3.14" +onnx==1.20.1 # Python dependencies required for pytorch development astunparse expecttest!=0.2.0 hypothesis -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" psutil pyyaml requests diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 8f5d0776501c0..ffcad5ee67208 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -1,5 +1,5 @@ -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" mypy pytest 
setuptools>=68.2.2 @@ -8,6 +8,5 @@ protobuf==6.33.0 sympy==1.14 flatbuffers neural-compressor>=2.2.1 -triton==3.2.0; python_version < "3.14" -triton==3.5.0; python_version >= "3.14" -onnx==1.20.1; python_version < "3.14" +triton==3.5.0 +onnx==1.20.1 diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 85a9c6391af80..ad57cc715589b 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -1,6 +1,6 @@ cerberus -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" mypy pytest setuptools==78.1.1 @@ -10,6 +10,6 @@ sympy==1.14 flatbuffers protobuf==6.33.0 packaging -onnxscript==0.5.3; python_version < "3.14" -onnx-ir==0.1.10; python_version < "3.14" -onnx==1.20.1; python_version < "3.14" +onnxscript==0.6.2 +onnx-ir==0.1.16 +onnx==1.20.1 diff --git a/tools/ci_build/github/linux/java_linux_final_test.sh b/tools/ci_build/github/linux/java_linux_final_test.sh index cdbfd2bad10a8..67588c08dbf2a 100755 --- a/tools/ci_build/github/linux/java_linux_final_test.sh +++ b/tools/ci_build/github/linux/java_linux_final_test.sh @@ -33,11 +33,13 @@ mkdir tests cd tests jar xf ../onnxruntime-java/testing.jar rm -f ../onnxruntime-java/testing.jar +echo "Contents of tests directory ($BINARY_DIR/tests):" +ls "$BINARY_DIR/tests" echo "Java Version" java -version echo "Directories created" echo "Library path:" "$LD_LIBRARY_PATH" -java -DUSE_CUDA=1 -cp "$BINARY_DIR/tests:$BINARY_DIR/onnxruntime-java/*" org.junit.platform.console.ConsoleLauncher --scan-classpath=$BINARY_DIR/tests \ +java -DUSE_CUDA=1 -cp "$BINARY_DIR/tests:$BINARY_DIR/onnxruntime-java/*" org.junit.platform.console.ConsoleLauncher --scan-classpath="$BINARY_DIR/tests" \ --fail-if-no-tests --disable-banner diff --git 
a/tools/ci_build/github/linux/python/requirements.txt b/tools/ci_build/github/linux/python/requirements.txt index 6a474973d4f0c..d95e44bb3a280 100644 --- a/tools/ci_build/github/linux/python/requirements.txt +++ b/tools/ci_build/github/linux/python/requirements.txt @@ -1,5 +1,5 @@ -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" mypy pytest setuptools>=68.2.2 @@ -8,8 +8,8 @@ protobuf==6.33.0 sympy==1.14 flatbuffers psutil -onnxscript==0.5.3; python_version < "3.14" -onnx-ir==0.1.10; python_version < "3.14" +onnxscript==0.6.2 +onnx-ir==0.1.16 jinja2 markupsafe -onnx==1.20.1; python_version < "3.14" +onnx==1.20.1 diff --git a/tools/ci_build/github/windows/bundle_dml_package.ps1 b/tools/ci_build/github/windows/bundle_dml_package.ps1 index ef7f781096b25..36088e772bf2d 100644 --- a/tools/ci_build/github/windows/bundle_dml_package.ps1 +++ b/tools/ci_build/github/windows/bundle_dml_package.ps1 @@ -27,17 +27,27 @@ $arm64ExtractPath = "win-dml-arm64-unzipped" Write-Host "Extracting $arm64ZipFile to $arm64ExtractPath..." & $sevenZipPath x $arm64ZipFile -o"$arm64ExtractPath" -y +# Debug: List contents of extracted arm64 zip +Write-Host "Contents of $arm64ExtractPath (recursive):" +Get-ChildItem -Path $arm64ExtractPath -Recurse | ForEach-Object { Write-Host " - $($_.FullName)" } + # 2. Find the target NuGet package. # It finds all .nupkg files that do not contain "Managed" in their name. -$nupkgFiles = Get-ChildItem -Path . -Recurse -Filter *.nupkg | Where-Object { $_.Name -notlike "*Managed*" } +$nupkgFiles = Get-ChildItem -Path . -Filter *.nupkg | Where-Object { ($_.Name -notlike "*Managed*") -and ($_.Name -notlike "*.symbols.nupkg") } + +Write-Host "Found $($nupkgFiles.Count) candidate nupkg file(s) for bundling:" +$nupkgFiles | ForEach-Object { Write-Host " - $($_.FullName)" } -# 3. Validate that exactly one package was found. 
-if ($nupkgFiles.Count -ne 1) { - Write-Error "Error: Expected to find exactly one non-managed NuGet package, but found $($nupkgFiles.Count)." +# 3. Select the best package (shortest name prefers Release over Dev, and Main over Symbols) +if ($nupkgFiles.Count -eq 0) { + Write-Error "Error: No matching NuGet packages found to bundle into." exit 1 } -$nupkg = $nupkgFiles[0] -Write-Host "Found package to process: $($nupkg.Name)" +if ($nupkgFiles.Count -gt 1) { + Write-Warning "Found multiple packages. Selecting the one with the shortest filename as the target for bundling." +} +$nupkg = $nupkgFiles | Sort-Object {$_.Name.Length} | Select-Object -First 1 +Write-Host "Selected target package: $($nupkg.Name)" # 4. Validate the package name matches the expected format. if ($nupkg.Name -notlike "Microsoft.ML.OnnxRuntime.DirectML*.nupkg") { @@ -61,14 +71,36 @@ New-Item -ItemType Directory -Path $tempDir | Out-Null Write-Host "Extracting $($nupkg.Name) to $tempDir..." & $sevenZipPath x $nupkg.FullName -o"$tempDir" -y +# Debug: Print the .nuspec content +$nuspecFile = Get-ChildItem -Path $tempDir -Filter *.nuspec | Select-Object -First 1 +if ($nuspecFile) { + Write-Host "Found manifest: $($nuspecFile.FullName)" + Write-Host "--- Manifest Content ---" + Get-Content $nuspecFile.FullName | ForEach-Object { Write-Host $_ } + Write-Host "------------------------" +} + +# Debug: List contents of extracted target nupkg +Write-Host "Contents of $tempDir (recursive):" +Get-ChildItem -Path $tempDir -Recurse | ForEach-Object { Write-Host " - $($_.FullName)" } + # Step B: Create the new runtime directory structure. $newRuntimePath = Join-Path $tempDir "runtimes\win-arm64\native" +Write-Host "Ensuring destination path exists: $newRuntimePath" New-Item -ItemType Directory -Path $newRuntimePath -Force | Out-Null # Step C: Copy the ARM64 binaries into the new structure. $arm64SourcePath = Join-Path . 
"$arm64ExtractPath\runtimes\win-arm64\native" -Write-Host "Copying ARM64 binaries from $arm64SourcePath to $newRuntimePath..." -Copy-Item -Path "$arm64SourcePath\*" -Destination $newRuntimePath -Recurse -Force +if (Test-Path $arm64SourcePath) { + Write-Host "Copying ARM64 binaries from $arm64SourcePath to $newRuntimePath..." + $filesToCopy = Get-ChildItem -Path "$arm64SourcePath\*" + Write-Host "Files found in source: $($filesToCopy.Count)" + $filesToCopy | ForEach-Object { Write-Host " -> $($_.Name)" } + Copy-Item -Path "$arm64SourcePath\*" -Destination $newRuntimePath -Recurse -Force +} else { + Write-Error "Error: ARM64 source path not found: $arm64SourcePath. Bailing out to avoid creating a broken package." + exit 1 +} # Step D: Delete the original nupkg file. Remove-Item -Path $nupkg.FullName -Force @@ -79,6 +111,13 @@ Push-Location $tempDir & $sevenZipPath a -tzip "$($nupkg.FullName)" ".\" -r Pop-Location +# Debug: Check final nupkg existence +if (Test-Path $nupkg.FullName) { + Write-Host "Final package created successfully: $($nupkg.FullName)" + $finalSize = (Get-Item $nupkg.FullName).Length + Write-Host "Final package size: $finalSize bytes" +} + # --- Cleanup and Final Steps --- Write-Host "Cleaning up temporary directory $tempDir..." Remove-Item -Recurse -Force $tempDir @@ -91,4 +130,4 @@ Write-Host "Copying final artifact to $ArtifactStagingDirectory..." Copy-Item -Path ".\Microsoft.ML.OnnxRuntime.DirectML*.nupkg" -Destination $ArtifactStagingDirectory -Force Write-Host "---" -Write-Host "Script completed successfully." \ No newline at end of file +Write-Host "Script completed successfully." 
diff --git a/tools/ci_build/github/windows/jar_packaging.py b/tools/ci_build/github/windows/jar_packaging.py index b399782e9410f..f4bc6899260c1 100644 --- a/tools/ci_build/github/windows/jar_packaging.py +++ b/tools/ci_build/github/windows/jar_packaging.py @@ -232,12 +232,11 @@ def run_packaging(package_type: str, build_dir: str): "platforms": [ {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": True}, {"path": "onnxruntime-java-linux-aarch64", "lib": "libcustom_op_library.so", "archive_lib": False}, + {"path": "onnxruntime-java-osx-arm64", "lib": "libcustom_op_library.dylib", "archive_lib": True}, ] }, "gpu": { - "platforms": [ - {"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": False} - ] + "platforms": [{"path": "onnxruntime-java-linux-x64", "lib": "libcustom_op_library.so", "archive_lib": True}] }, } diff --git a/tools/ci_build/github/windows/jar_packaging_test.py b/tools/ci_build/github/windows/jar_packaging_test.py index 2dd61cf9c3088..e4f7e4945442c 100644 --- a/tools/ci_build/github/windows/jar_packaging_test.py +++ b/tools/ci_build/github/windows/jar_packaging_test.py @@ -52,14 +52,19 @@ def _setup_test_directory(package_type: str, version_string: str): create_empty_file(linux_native_dir / "libonnxruntime_providers_cuda.so") (linux_dir / "_manifest" / "spdx_2.2").mkdir(parents=True, exist_ok=True) - # --- Additional platforms (for CPU test) --- + # --- macOS and other platforms (for CPU test) --- if package_type == "cpu": - # Add linux-aarch64 for CPU test + # Add linux-aarch64 and osx-arm64 for CPU test linux_aarch64_dir = java_artifact_dir / "onnxruntime-java-linux-aarch64" linux_aarch64_native_dir = linux_aarch64_dir / "ai" / "onnxruntime" / "native" / "linux-aarch64" linux_aarch64_native_dir.mkdir(parents=True, exist_ok=True) create_empty_file(linux_aarch64_dir / "libcustom_op_library.so") + osx_arm64_dir = java_artifact_dir / "onnxruntime-java-osx-arm64" + osx_arm64_native_dir = 
osx_arm64_dir / "ai" / "onnxruntime" / "native" / "osx-arm64" + osx_arm64_native_dir.mkdir(parents=True, exist_ok=True) + create_empty_file(osx_arm64_dir / "libcustom_op_library.dylib") + return tmp_path return _setup_test_directory @@ -128,9 +133,12 @@ def test_cpu_packaging(directory_setup_factory, version_string): with zipfile.ZipFile(testing_jar_path, "r") as zf: jar_contents = zf.namelist() assert "libcustom_op_library.so" in jar_contents + assert "libcustom_op_library.dylib" in jar_contents # 3. Verify the custom op libraries were removed from the source directories linux_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-x64" linux_aarch64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-linux-aarch64" + osx_arm64_dir = temp_build_dir / "java-artifact" / "onnxruntime-java-osx-arm64" assert not (linux_dir / "libcustom_op_library.so").exists() assert not (linux_aarch64_dir / "libcustom_op_library.so").exists() + assert not (osx_arm64_dir / "libcustom_op_library.dylib").exists() diff --git a/tools/ci_build/github/windows/python/requirements.txt b/tools/ci_build/github/windows/python/requirements.txt index 4e24bf7cbfa97..83593ff47e453 100644 --- a/tools/ci_build/github/windows/python/requirements.txt +++ b/tools/ci_build/github/windows/python/requirements.txt @@ -1,5 +1,5 @@ -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" mypy pytest setuptools>=68.2.2 @@ -8,11 +8,10 @@ protobuf==6.33.0 sympy==1.14 flatbuffers psutil -onnxscript==0.5.3; python_version < "3.14" -onnx-ir==0.1.10; python_version < "3.14" +onnxscript==0.6.2 +onnx-ir==0.1.16 jinja2 markupsafe semver packaging -coloredlogs -onnx==1.20.1; python_version < "3.14" +onnx==1.20.1 diff --git a/tools/ci_build/github/windows/select_dml_package.ps1 b/tools/ci_build/github/windows/select_dml_package.ps1 new file mode 100644 index 0000000000000..b6a3ed936ae23 --- /dev/null +++ 
b/tools/ci_build/github/windows/select_dml_package.ps1 @@ -0,0 +1,86 @@ +# select_dml_package.ps1 +# Helper script to select the correct DML NuGet package based on build type +# Usage: select_dml_package.ps1 -SourceDir -IsReleaseBuild -Action [-DestinationDir ] [-NewName ] + +param( + [Parameter(Mandatory=$true)] + [string]$SourceDir, + + [Parameter(Mandatory=$true)] + [string]$IsReleaseBuild, + + [Parameter(Mandatory=$true)] + [ValidateSet("copy", "rename")] + [string]$Action, + + [Parameter(Mandatory=$false)] + [string]$DestinationDir, + + [Parameter(Mandatory=$false)] + [string]$NewName +) + +$ErrorActionPreference = "Stop" + +Write-Host "Searching for packages in: $SourceDir" +Write-Host "IsReleaseBuild: $IsReleaseBuild" +Write-Host "Action: $Action" + +# Convert string to boolean +$isRelease = [System.Convert]::ToBoolean($IsReleaseBuild) + +# Find all matching packages +$allPackages = Get-ChildItem -Path $SourceDir -Filter "Microsoft.ML.OnnxRuntime.DirectML.*.nupkg" +Write-Host "Found $($allPackages.Count) total package(s):" +$allPackages | ForEach-Object { Write-Host " - $($_.Name)" } + +# Filter packages based on build type +$filteredPackages = $allPackages | Where-Object { + $name = $_.Name + $isSymbols = $name -like "*symbols*" + $isDev = $name -like "*-dev*" + + if ($isSymbols) { + return $false + } + + if ($isRelease) { + return -not $isDev + } else { + return $isDev + } +} + +Write-Host "After filtering (isRelease=$isRelease), found $($filteredPackages.Count) matching package(s):" +$filteredPackages | ForEach-Object { Write-Host " - $($_.Name)" } + +if ($filteredPackages.Count -eq 0) { + Write-Error "No matching package found!" 
+ exit 1 +} + +# Select the first matching package (sorted by name length for consistency) +$selectedPackage = $filteredPackages | Sort-Object { $_.Name.Length } | Select-Object -First 1 +Write-Host "Selected package: $($selectedPackage.FullName)" + +# Perform the action +if ($Action -eq "copy") { + if (-not $DestinationDir) { + Write-Error "DestinationDir is required for copy action" + exit 1 + } + Write-Host "Copying to: $DestinationDir" + Copy-Item -Path $selectedPackage.FullName -Destination $DestinationDir -Force + Write-Host "Copy successful." +} +elseif ($Action -eq "rename") { + if (-not $NewName) { + Write-Error "NewName is required for rename action" + exit 1 + } + Write-Host "Renaming to: $NewName" + Rename-Item -Path $selectedPackage.FullName -NewName $NewName -Force + Write-Host "Rename successful." +} + +exit 0 diff --git a/tools/ci_build/requirements/transformers-test/requirements.txt b/tools/ci_build/requirements/transformers-test/requirements.txt index e95509c7ddec3..c764225dbc98d 100644 --- a/tools/ci_build/requirements/transformers-test/requirements.txt +++ b/tools/ci_build/requirements/transformers-test/requirements.txt @@ -2,15 +2,14 @@ packaging # protobuf and numpy is same as tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt protobuf==6.33.0 -numpy==2.2.6; python_version < "3.14" -numpy==2.3.2; python_version >= "3.14" -torch==2.8.0 -torchvision==0.23.0 -coloredlogs==15.0 +numpy==2.2.6; python_version < "3.11" +numpy==2.4.2; python_version >= "3.11" +torch==2.10.0 +torchvision==0.25.0 transformers==4.52.1 parameterized>=0.8.1 sentencepiece psutil einops -onnxscript==0.5.3; python_version < "3.14" -onnx-ir==0.1.10; python_version < "3.14" +onnxscript==0.6.2 +onnx-ir==0.1.16 diff --git a/tools/nuget/generate_nuspec_for_custom_nuget.py b/tools/nuget/generate_nuspec_for_custom_nuget.py index 3abd03119cbc5..6e51c51895191 100644 --- a/tools/nuget/generate_nuspec_for_custom_nuget.py +++ 
b/tools/nuget/generate_nuspec_for_custom_nuget.py @@ -14,7 +14,6 @@ def generate_files(lines, args): platform_map = { "win-arm64": args.win_arm64, "win-x64": args.win_x64, - "osx-x64": args.osx_x64, "osx-arm64": args.osx_arm64, } @@ -116,7 +115,6 @@ def parse_arguments(): parser.add_argument("--win_arm64", required=True, help="Ort win-arm64 directory") parser.add_argument("--win_x64", required=True, help="Ort win-x64 directory") parser.add_argument("--osx_arm64", required=True, help="Ort osx-arm64 directory") - parser.add_argument("--osx_x64", required=True, help="Ort osx-x64 directory") parser.add_argument("--package_version", required=True, help="Version of the package") parser.add_argument("--package_name", required=True, help="Name of the package") diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 9884cbf5793df..1f882c847c707 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -238,6 +238,9 @@ def add_common_dependencies(xml_text, package_name, version): xml_text.append('') xml_text.append('') + if package_name == "Microsoft.ML.OnnxRuntime.Foundry": + xml_text.append('') + def generate_dependencies(xml_text, package_name, version): dml_dependency = '' diff --git a/tools/nuget/validate_package.py b/tools/nuget/validate_package.py index 961109c595ed5..0ad1fc07eafd7 100644 --- a/tools/nuget/validate_package.py +++ b/tools/nuget/validate_package.py @@ -67,6 +67,10 @@ def parse_arguments(): "--verify_nuget_signing", help="Flag indicating if Nuget package signing is to be verified. Only accepts 'true' or 'false'", ) + parser.add_argument( + "--is_release_build", + help="Flag indicating if validating a release build or dev build. 
Only accepts 'true' or 'false'", + ) return parser.parse_args() @@ -285,7 +289,14 @@ def validate_zip(args): def validate_nuget(args): files = glob.glob(os.path.join(args.package_path, args.package_name)) - nuget_packages_found_in_path = [i for i in files if i.endswith(".nupkg") and "Managed" not in i] + is_release_build = args.is_release_build and args.is_release_build.lower() == "true" + nuget_packages_found_in_path = [ + i + for i in files + if i.endswith(".nupkg") + and "Managed" not in i + and ((is_release_build and "-dev" not in i) or (not is_release_build and "-dev" in i)) + ] if len(nuget_packages_found_in_path) != 1: print("Nuget packages found in path: ") print(nuget_packages_found_in_path)