diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 045a841e31ef5..d33e4d923a0bc 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -61,7 +61,7 @@ jobs: distribution: 'microsoft' - if: ${{ matrix.language == 'javascript' }} - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: 20 diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index d7f3288036ecd..5aaab5f8e1a10 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -117,7 +117,7 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] steps: - uses: actions/checkout@v5 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 - uses: reviewdog/action-eslint@v1 diff --git a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml index 37a5279411895..cf3bee49f1971 100644 --- a/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml +++ b/.github/workflows/linux-wasm-ci-build-and-test-workflow.yml @@ -37,7 +37,15 @@ jobs: runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-Ubuntu2204-AMD-CPU"] env: buildArch: x64 - common_build_args: --parallel ${{ inputs.use_vcpkg == true && '--use_vcpkg --use_vcpkg_ms_internal_asset_cache' || '' }} --config ${{ inputs.build_config }} --skip_submodule_sync --build_wasm --enable_wasm_simd ${{ inputs.enable_wasm_threads == true && '--enable_wasm_threads' || '' }} ${{ inputs.extra_build_args }} + common_build_args: >- + --parallel + ${{ inputs.use_vcpkg == true && '--use_vcpkg --use_vcpkg_ms_internal_asset_cache' || '' }} + --config ${{ inputs.build_config }} + --skip_submodule_sync + --build_wasm + --enable_wasm_simd + ${{ inputs.enable_wasm_threads == true && '--enable_wasm_threads' || '' }} + ${{ inputs.extra_build_args }} steps: - name: Checkout code @@ -46,7 +54,7 @@ jobs: submodules: recursive - name: Set up Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: "22" @@ -70,6 +78,7 @@ jobs: python ./tools/ci_build/build.py \ ${{ env.common_build_args }} \ --build_dir ${{ github.workspace }}/build/wasm_inferencing \ + ${{ inputs.build_config == 'Release' && '--enable_wasm_api_exception_catching' || '' }} \ --skip_tests working-directory: ${{ github.workspace }} @@ -82,6 +91,7 @@ jobs: --use_jsep \ --use_webnn \ --target onnxruntime_webassembly \ + ${{ inputs.build_config == 'Release' && '--enable_wasm_api_exception_catching' || '' }} \ --skip_tests working-directory: ${{ github.workspace }} @@ -94,6 +104,20 @@ jobs: --use_webgpu \ --use_webnn \ --target onnxruntime_webassembly \ + ${{ inputs.build_config == 'Release' && '--enable_wasm_api_exception_catching' || '' }} \ + --skip_tests + working-directory: ${{ github.workspace }} + + - name: Build (simd + threads + WebGPU experimental, JSPI) + if: ${{ inputs.build_webgpu == true }} + run: | + python ./tools/ci_build/build.py \ + ${{ env.common_build_args }} \ + --build_dir ${{ github.workspace }}/build/wasm_inferencing_webgpu_jspi \ + --use_webgpu \ + --use_webnn \ + --enable_wasm_jspi \ + --target onnxruntime_webassembly \ --skip_tests working-directory: ${{ github.workspace }} @@ -111,6 +135,10 @@ jobs: cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config }}/ort-wasm-simd-threaded.asyncify.wasm ${{ github.workspace }}/artifacts/wasm/ cp ${{ github.workspace }}/build/wasm_inferencing_webgpu/${{ inputs.build_config 
}}/ort-wasm-simd-threaded.asyncify.mjs ${{ github.workspace }}/artifacts/wasm/ fi + if [ -d ${{ github.workspace }}/build/wasm_inferencing_webgpu_jspi ]; then + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu_jspi/${{ inputs.build_config }}/ort-wasm-simd-threaded.jspi.wasm ${{ github.workspace }}/artifacts/wasm/ + cp ${{ github.workspace }}/build/wasm_inferencing_webgpu_jspi/${{ inputs.build_config }}/ort-wasm-simd-threaded.jspi.mjs ${{ github.workspace }}/artifacts/wasm/ + fi - name: Upload WASM artifacts if: ${{ inputs.skip_publish != true }} diff --git a/.github/workflows/linux_cuda_ci.yml b/.github/workflows/linux_cuda_ci.yml index 61f994fcebd0e..9ec5ea47fbf0c 100644 --- a/.github/workflows/linux_cuda_ci.yml +++ b/.github/workflows/linux_cuda_ci.yml @@ -27,9 +27,9 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1' + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' docker_image_repo: onnxruntimecuda12manylinuxbuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -55,7 +55,7 @@ jobs: with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda image-name: ghcr.io/microsoft/onnxruntime/onnxruntimecuda12manylinuxbuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1' + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -99,5 +99,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda' - extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/linux_minimal_build.yml b/.github/workflows/linux_minimal_build.yml index 
1870db522bab8..d7b6303d3cdc9 100644 --- a/.github/workflows/linux_minimal_build.yml +++ b/.github/workflows/linux_minimal_build.yml @@ -32,7 +32,7 @@ jobs: with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -68,7 +68,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -125,7 +125,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -159,7 +159,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -191,7 +191,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -225,7 +225,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 @@ -508,7 +508,7 @@ jobs: uses: actions/checkout@v5 with: submodules: false - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: 20 - name: Download Test Data Artifact diff --git a/.github/workflows/linux_tensorrt_ci.yml b/.github/workflows/linux_tensorrt_ci.yml index 9fb6625466c72..064ad87794cdd 100644 --- a/.github/workflows/linux_tensorrt_ci.yml +++ b/.github/workflows/linux_tensorrt_ci.yml @@ -27,9 +27,9 @@ jobs: build_config: Release architecture: x64 dockerfile_path: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda - docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 --build-arg TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' + docker_build_args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 --build-arg TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' docker_image_repo: onnxruntimetensorrt86gpubuild - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' run_tests: false # <<< Do not run tests in this job upload_build_output: true # <<< Upload the build/Release directory @@ -57,7 +57,7 @@ jobs: with: dockerfile: ${{ github.workspace }}/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda image-name: ghcr.io/microsoft/onnxruntime/onnxruntimetensorrt86gpubuild - build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_ubi8_gcc12:20250124.1 --build-arg TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' + build-args: '--build-arg BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 --build-arg 
TRT_VERSION=10.9.0.34-1.cuda12.8 --network=host' push: true azure-container-registry-name: onnxruntimebuildcache env: @@ -101,5 +101,5 @@ jobs: build_config: Release mode: 'test' # Set mode to test execution_providers: 'cuda tensorrt' - extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.2 --cuda_home=/usr/local/cuda-12.2 --cudnn_home=/usr/local/cuda-12.2 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' + extra_build_flags: '--use_binskim_compliant_compile_flags --build_wheel --cuda_version=12.8 --cuda_home=/usr/local/cuda-12.8 --cudnn_home=/usr/local/cuda-12.8 --use_tensorrt --tensorrt_home /usr --build_java --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=90 onnxruntime_BUILD_UNIT_TESTS=ON onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS=ON' python_path_prefix: 'PATH=/opt/python/cp310-cp310/bin:$PATH' diff --git a/.github/workflows/pr_checks.yml b/.github/workflows/pr_checks.yml index 1d76a9ba413ed..abe627f4ff7bc 100644 --- a/.github/workflows/pr_checks.yml +++ b/.github/workflows/pr_checks.yml @@ -47,6 +47,6 @@ jobs: set +e lintrunner f --all-files -v exit 0 - - uses: parkerbxyz/suggest-changes@v2 + - uses: parkerbxyz/suggest-changes@v3 with: comment: 'You can commit the suggested changes from lintrunner.' diff --git a/.github/workflows/publish-js-apidocs.yml b/.github/workflows/publish-js-apidocs.yml index 37908b8506c93..cada1ceecd8e0 100644 --- a/.github/workflows/publish-js-apidocs.yml +++ b/.github/workflows/publish-js-apidocs.yml @@ -25,7 +25,7 @@ jobs: steps: - uses: actions/checkout@v5 - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: 18 - name: Generate JS docs diff --git a/.github/workflows/react_native.yml b/.github/workflows/react_native.yml index 08426f3d0ccb5..04f40c58868ca 100644 --- a/.github/workflows/react_native.yml +++ b/.github/workflows/react_native.yml @@ -94,7 +94,7 @@ jobs: architecture: x64 - name: Use Node.js 22.x - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '22.x' @@ -230,7 +230,7 @@ jobs: run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer - name: Use Node.js 22.x - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '22.x' @@ -286,7 +286,7 @@ jobs: run: sudo xcode-select --switch /Applications/Xcode_15.3.0.app/Contents/Developer - name: Use Node.js 22.x - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '22.x' diff --git a/.github/workflows/web.yml b/.github/workflows/web.yml index 616c2c6db8a8d..016feab5e0d94 100644 --- a/.github/workflows/web.yml +++ b/.github/workflows/web.yml @@ -48,7 +48,7 @@ jobs: uses: ./.github/workflows/linux-wasm-ci-build-and-test-workflow.yml with: build_config: Release - extra_build_args: "--target onnxruntime_webassembly --skip_tests --enable_wasm_api_exception_catching --disable_rtti" + extra_build_args: "--target onnxruntime_webassembly --skip_tests --disable_rtti" build_jsep: true build_webgpu: true @@ -57,7 +57,7 @@ jobs: uses: ./.github/workflows/linux-wasm-ci-build-and-test-workflow.yml with: build_config: Release - extra_build_args: "--skip_tests --enable_wasm_api_exception_catching --disable_rtti --build_wasm_static_lib" + extra_build_args: "--skip_tests --disable_rtti --build_wasm_static_lib" use_vcpkg: false enable_wasm_threads: false skip_publish: true diff --git 
a/.github/workflows/windows-web-ci-workflow.yml b/.github/workflows/windows-web-ci-workflow.yml index f837056fedb41..6ebc6bd7fea43 100644 --- a/.github/workflows/windows-web-ci-workflow.yml +++ b/.github/workflows/windows-web-ci-workflow.yml @@ -62,7 +62,7 @@ jobs: git checkout -- .gitattributes - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: "20.x" diff --git a/.github/workflows/windows_cuda.yml b/.github/workflows/windows_cuda.yml index 3d24d4b6b75b6..ae23902a015ab 100644 --- a/.github/workflows/windows_cuda.yml +++ b/.github/workflows/windows_cuda.yml @@ -56,7 +56,7 @@ jobs: Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.8\bin" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.8\extras\CUPTI\lib64" - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: '20.x' @@ -168,7 +168,7 @@ jobs: python-version: '3.12' architecture: x64 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_dml.yml b/.github/workflows/windows_dml.yml index b398e69900210..e8ee7751348b4 100644 --- a/.github/workflows/windows_dml.yml +++ b/.github/workflows/windows_dml.yml @@ -47,7 +47,7 @@ jobs: working-directory: ${{ github.workspace }} shell: cmd - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_tensorrt.yml b/.github/workflows/windows_tensorrt.yml index 2a1fe97d9b7b7..f8c6471b412c0 100644 --- a/.github/workflows/windows_tensorrt.yml +++ b/.github/workflows/windows_tensorrt.yml @@ -61,7 +61,7 @@ jobs: Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\v12.8\extras\CUPTI\lib64" Add-Content -Path $env:GITHUB_PATH -Value "$env:RUNNER_TEMP\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: '20.x' @@ -173,7 +173,7 @@ jobs: python-version: '3.12' architecture: x64 - - uses: actions/setup-node@v5 + - uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_webgpu.yml b/.github/workflows/windows_webgpu.yml index f849bdda0dff3..899a8b66eac7a 100644 --- a/.github/workflows/windows_webgpu.yml +++ b/.github/workflows/windows_webgpu.yml @@ -56,7 +56,7 @@ jobs: working-directory: ${{ github.workspace }} - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: "20.x" @@ -231,7 +231,7 @@ jobs: working-directory: ${{ github.workspace }} - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: "20.x" diff --git a/.github/workflows/windows_x64_debug_build_x64_debug.yml b/.github/workflows/windows_x64_debug_build_x64_debug.yml index 6a1b43e54ed89..e4d1477b5619c 100644 --- a/.github/workflows/windows_x64_debug_build_x64_debug.yml +++ b/.github/workflows/windows_x64_debug_build_x64_debug.yml @@ -38,7 +38,7 @@ jobs: run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_x64_release_build_x64_release.yml b/.github/workflows/windows_x64_release_build_x64_release.yml index 0bcd282e8dc50..46b667ac22b02 100644 --- a/.github/workflows/windows_x64_release_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_build_x64_release.yml @@ -38,7 +38,7 @@ jobs: 
run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml index 3934047266f59..4026869c6e4f2 100644 --- a/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml +++ b/.github/workflows/windows_x64_release_ep_generic_interface_build_x64_release_ep_generic_interface.yml @@ -38,7 +38,7 @@ jobs: run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml index 1c38d8e58970c..4378231338673 100644 --- a/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml +++ b/.github/workflows/windows_x64_release_vitisai_build_x64_release.yml @@ -38,7 +38,7 @@ jobs: run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_x64_release_xnnpack.yml b/.github/workflows/windows_x64_release_xnnpack.yml index 6eb9f00d3997d..dd8c251ea23d3 100644 --- a/.github/workflows/windows_x64_release_xnnpack.yml +++ b/.github/workflows/windows_x64_release_xnnpack.yml @@ -38,7 +38,7 @@ jobs: run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' diff --git a/.github/workflows/windows_x86.yml b/.github/workflows/windows_x86.yml index 597c1c7f4b6cf..e8f04a955e32e 100644 --- a/.github/workflows/windows_x86.yml +++ b/.github/workflows/windows_x86.yml @@ -38,7 +38,7 @@ jobs: run: python -m pip install -r "${{ github.workspace }}\tools\ci_build\github\windows\python\requirements.txt" - name: Setup Node.js - uses: actions/setup-node@v5 + uses: actions/setup-node@v6 with: node-version: '20.x' architecture: x86 #Add architecture diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8186da507a442..5ea812622b9b6 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -194,6 +194,7 @@ option(onnxruntime_USE_NCCL "Build with NCCL support" OFF) # WebAssembly options option(onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB "Enable this option to create WebAssembly static library" OFF) +option(onnxruntime_ENABLE_WEBASSEMBLY_JSPI "Enable WebAssembly JavaScript Promise Integration" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_THREADS "Enable this option to create WebAssembly byte codes with multi-threads support" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING "Enable this option to turn on exception catching" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_API_EXCEPTION_CATCHING "Enable this option to turn on api exception catching" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 502a60ec8d7b8..3975361d5928c 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ 
b/cmake/adjust_global_compile_flags.cmake @@ -45,7 +45,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_CXX_FLAGS " -msimd128") endif() - if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) + # Enable WebAssembly exception catching. + if (onnxruntime_ENABLE_WEBASSEMBLY_JSPI) + string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions -s WASM_LEGACY_EXCEPTIONS=0") + string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions -s WASM_LEGACY_EXCEPTIONS=0") + elseif (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 44b794d9e2f78..df554269dfc7f 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -4,7 +4,7 @@ onnxruntime_fetchcontent_declare( URL ${DEP_URL_cutlass} URL_HASH SHA1=${DEP_SHA1_cutlass} EXCLUDE_FROM_ALL -PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_4.2.1_maybe_unused.patch +PATCH_COMMAND ${Patch_EXECUTABLE} --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass_4.2.1.patch ) FetchContent_GetProperties(cutlass) diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake index 81fde14d3dda0..2f55860f34c70 100644 --- a/cmake/external/eigen.cmake +++ b/cmake/external/eigen.cmake @@ -5,12 +5,14 @@ set(EIGEN_BUILD_PKGCONFIG OFF CACHE BOOL "" FORCE) set(EIGEN_BUILD_CMAKE_PACKAGE ON CACHE BOOL "" FORCE) set(PATCH_EIGEN_S390X ${PROJECT_SOURCE_DIR}/patches/eigen/s390x-build.patch) +set(PATCH_EIGEN_S390X_WERROR ${PROJECT_SOURCE_DIR}/patches/eigen/s390x-build-werror.patch) onnxruntime_fetchcontent_declare( Eigen3 URL ${DEP_URL_eigen} URL_HASH SHA1=${DEP_SHA1_eigen} - PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_EIGEN_S390X} + PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_EIGEN_S390X} && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_EIGEN_S390X_WERROR} EXCLUDE_FROM_ALL ) onnxruntime_fetchcontent_makeavailable(Eigen3) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index b6a741d8b0fe7..603d578e696ef 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -186,7 +186,8 @@ endif() #2. if ONNX_CUSTOM_PROTOC_EXECUTABLE is not set, Compile everything(including protoc) from source code. if(Patch_FOUND) set(ONNXRUNTIME_PROTOBUF_PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_cmake.patch && - ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_android_log.patch) + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_android_log.patch && + ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/protobuf/protobuf_s390x.patch) else() set(ONNXRUNTIME_PROTOBUF_PATCH_COMMAND "") endif() @@ -314,18 +315,18 @@ if (onnxruntime_ENABLE_CPUINFO) # Adding pytorch CPU info library # TODO!! need a better way to find out the supported architectures set(CPUINFO_SUPPORTED FALSE) - if (APPLE) + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo. 
+ if (onnxruntime_USE_XNNPACK) + set(CPUINFO_SUPPORTED TRUE) + endif() + elseif (APPLE) list(LENGTH CMAKE_OSX_ARCHITECTURES CMAKE_OSX_ARCHITECTURES_LEN) if (CMAKE_OSX_ARCHITECTURES_LEN LESS_EQUAL 1) set(CPUINFO_SUPPORTED TRUE) else() message(WARNING "cpuinfo is not supported when CMAKE_OSX_ARCHITECTURES has more than one value.") endif() - elseif (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - # if xnnpack is enabled in a wasm build it needs clog from cpuinfo, but we won't internally use cpuinfo. - if (onnxruntime_USE_XNNPACK) - set(CPUINFO_SUPPORTED TRUE) - endif() elseif (WIN32) set(CPUINFO_SUPPORTED TRUE) else() @@ -634,7 +635,6 @@ if (onnxruntime_USE_WEBGPU) # set(DAWN_BUILD_SAMPLES OFF CACHE BOOL "" FORCE) set(DAWN_ENABLE_NULL OFF CACHE BOOL "" FORCE) - set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) set(DAWN_BUILD_PROTOBUF OFF CACHE BOOL "" FORCE) set(DAWN_BUILD_TESTS OFF CACHE BOOL "" FORCE) if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") @@ -715,6 +715,7 @@ if (onnxruntime_USE_WEBGPU) endif() if (onnxruntime_CUSTOM_DAWN_SRC_PATH) + set(DAWN_FETCH_DEPENDENCIES OFF CACHE BOOL "" FORCE) # use the custom dawn source path if provided # # specified as: @@ -725,6 +726,7 @@ if (onnxruntime_USE_WEBGPU) EXCLUDE_FROM_ALL ) else() + set(DAWN_FETCH_DEPENDENCIES ON CACHE BOOL "" FORCE) set(ONNXRUNTIME_Dawn_PATCH_COMMAND # The dawn_destroy_buffer_on_destructor.patch contains the following changes: # diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index bb33f5ec0b554..7b631895d8d95 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -777,6 +777,7 @@ endif() if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) set(mlas_platform_srcs ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_lasx.cpp ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index b762b149d9d6f..62f594d194543 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -47,17 +47,28 @@ list(REMOVE_ITEM EM_DAWN_WEBGPU_C_COMPILE_OPTIONS "-fno-exceptions") set_property(TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS ${EM_DAWN_WEBGPU_C_COMPILE_OPTIONS}) endif() + if (CMAKE_CXX_FLAGS MATCHES "-fwasm-exceptions") + get_property(EM_DAWN_WEBGPU_C_COMPILE_OPTIONS TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS) + list(REMOVE_ITEM EM_DAWN_WEBGPU_C_COMPILE_OPTIONS "-fno-exceptions") + set_property(TARGET emdawnwebgpu_c PROPERTY COMPILE_OPTIONS ${EM_DAWN_WEBGPU_C_COMPILE_OPTIONS}) + endif() # target "emdawnwebgpu_cpp" is created by Dawn. When it is linked to onnxruntime_providers_webgpu as "PUBLIC" # dependency, a few build/link flags will be set automatically to make sure emscripten can generate correct # WebAssembly/JavaScript code for WebGPU support. target_link_libraries(onnxruntime_providers_webgpu PUBLIC emdawnwebgpu_cpp) - # ASYNCIFY is required for WGPUFuture support (ie. async functions in WebGPU API) - target_link_options(onnxruntime_providers_webgpu PUBLIC - "SHELL:-s ASYNCIFY=1" - "SHELL:-s ASYNCIFY_STACK_SIZE=65536" - ) + if (onnxruntime_ENABLE_WEBASSEMBLY_JSPI) + target_link_options(onnxruntime_providers_webgpu PUBLIC + "SHELL:-s JSPI=1" + ) + else() + # ASYNCIFY is required for WGPUFuture support (ie. 
async functions in WebGPU API) + target_link_options(onnxruntime_providers_webgpu PUBLIC + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + ) + endif() else() onnxruntime_add_include_to_target(onnxruntime_providers_webgpu dawn::dawncpp_headers dawn::dawn_headers) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index eae4433baa20e..fa93113c76160 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -244,10 +244,6 @@ else() ) if (onnxruntime_USE_JSEP) - # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU - # This flag allows async functions to be called from sync functions, in the cost of binary size and - # build time. See https://emscripten.org/docs/porting/asyncify.html for more details. - target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" @@ -275,13 +271,24 @@ else() endif() if (onnxruntime_USE_JSEP OR onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) - # if any of the above is enabled, we need to use the asyncify library - target_link_options(onnxruntime_webassembly PRIVATE - "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\"" - "SHELL:-s ASYNCIFY=1" - "SHELL:-s ASYNCIFY_STACK_SIZE=65536" - ) - list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js") + if (onnxruntime_ENABLE_WEBASSEMBLY_JSPI) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s JSPI=1" + "SHELL:-s JSPI_EXPORTS=[OrtAppendExecutionProvider,OrtCreateSession,OrtRun,OrtRunWithBinding,OrtBindInput]" + ) + else() + # NOTE: "-s ASYNCIFY=1" is required for JSEP to work with WebGPU + # This flag allows async functions to be called from sync functions, at the cost of binary size and + # build time. See https://emscripten.org/docs/porting/asyncify.html for more details. + # + # if any of the above is enabled, we need to use the asyncify library + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-async.js\"" + "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + ) + list(APPEND onnxruntime_webassembly_script_deps "${ONNXRUNTIME_ROOT}/wasm/pre-async.js") + endif() endif() if (onnxruntime_EMSCRIPTEN_SETTINGS) @@ -322,8 +329,12 @@ else() endif() endif() - # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. - target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_JSPI) + # Set the link flag to enable exception support; this overrides the default of disabling exception throwing when exceptions are disabled.
+ target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s DISABLE_EXCEPTION_THROWING=0" + ) + endif() if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) @@ -379,8 +390,11 @@ else() if (onnxruntime_USE_JSEP) string(APPEND target_name ".jsep") elseif (onnxruntime_USE_WEBGPU OR onnxruntime_USE_WEBNN) - string(APPEND target_name ".asyncify") - # TODO: support JSPI and add ".jspi" once JSPI build is supported + if (onnxruntime_ENABLE_WEBASSEMBLY_JSPI) + string(APPEND target_name ".jspi") + else() + string(APPEND target_name ".asyncify") + endif() endif() set_target_properties(onnxruntime_webassembly PROPERTIES OUTPUT_NAME ${target_name} SUFFIX ".mjs") diff --git a/cmake/patches/cutlass/cutlass_4.2.1.patch b/cmake/patches/cutlass/cutlass_4.2.1.patch new file mode 100644 index 0000000000000..3a3ec5ba103ef --- /dev/null +++ b/cmake/patches/cutlass/cutlass_4.2.1.patch @@ -0,0 +1,39 @@ +diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp +index cb161369..2fdff179 100644 +--- a/include/cute/layout.hpp ++++ b/include/cute/layout.hpp +@@ -1487,7 +1487,7 @@ nullspace(Layout const& layout) + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); + + // Select all indices corresponding to stride-0s +- auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, ++ [[maybe_unused]] auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } +diff --git a/include/cutlass/exmy_base.h b/include/cutlass/exmy_base.h +index be207a49..6028e01d 100644 +--- a/include/cutlass/exmy_base.h ++++ b/include/cutlass/exmy_base.h +@@ -1021,18 +1021,18 @@ struct float_exmy_base + + /// Floating point conversion + CUTLASS_HOST_DEVICE +- explicit float_exmy_base(float x) { ++ explicit float_exmy_base(float x) { + storage = static_cast(this)->convert_from_float(x).storage; + } + + // Integer conversion + CUTLASS_HOST_DEVICE +- explicit float_exmy_base(int x) { ++ explicit float_exmy_base(int x) { + storage = static_cast(this)->convert_from_float(float(x)).storage; + } + + CUTLASS_HOST_DEVICE +- explicit float_exmy_base(unsigned x) { ++ explicit float_exmy_base(unsigned x) { + storage = static_cast(this)->convert_from_float(float(x)).storage; + } + diff --git a/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch b/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch deleted file mode 100644 index 03d5972823839..0000000000000 --- a/cmake/patches/cutlass/cutlass_4.2.1_maybe_unused.patch +++ /dev/null @@ -1,13 +0,0 @@ -diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp -index cb161369..2fdff179 100644 ---- a/include/cute/layout.hpp -+++ b/include/cute/layout.hpp -@@ -1487,7 +1487,7 @@ nullspace(Layout const& layout) - [[maybe_unused]] auto flat_stride = flatten(layout.stride()); - - // Select all indices corresponding to stride-0s -- auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, -+ [[maybe_unused]] auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, - [&](auto init, auto i){ - if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } - else { return init; } diff --git a/cmake/patches/eigen/s390x-build-werror.patch b/cmake/patches/eigen/s390x-build-werror.patch new file mode 100644 index 0000000000000..d6aab355e2f2f --- /dev/null +++ b/cmake/patches/eigen/s390x-build-werror.patch @@ -0,0 +1,13 @@ +Comment out variable unused in onnxruntime + +--- 
a/Eigen/src/Core/arch/ZVector/PacketMath.h.orig 2025-10-21 10:24:49.410176124 +0000 ++++ b/Eigen/src/Core/arch/ZVector/PacketMath.h 2025-10-21 10:25:06.010176124 +0000 +@@ -101,7 +101,7 @@ + + static EIGEN_DECLARE_CONST_FAST_Packet4f(ZERO, 0); //{ 0.0, 0.0, 0.0, 0.0} + static EIGEN_DECLARE_CONST_FAST_Packet4i(MINUS1, -1); //{ -1, -1, -1, -1} +-static Packet4f p4f_MZERO = {0x80000000, 0x80000000, 0x80000000, 0x80000000}; ++//static Packet4f p4f_MZERO = {0x80000000, 0x80000000, 0x80000000, 0x80000000}; + #endif + + static Packet4i p4i_COUNTDOWN = {0, 1, 2, 3}; diff --git a/cmake/patches/protobuf/protobuf_s390x.patch b/cmake/patches/protobuf/protobuf_s390x.patch new file mode 100644 index 0000000000000..039eb8931339a --- /dev/null +++ b/cmake/patches/protobuf/protobuf_s390x.patch @@ -0,0 +1,16 @@ +s390x compatibility changes based on + +https://github.com/protocolbuffers/protobuf/commit/a2859cc2ce25711613002104022186c0c37d9f1f + +diff --git a/src/google/protobuf/port_def.inc b/src/google/protobuf/port_def.inc +index edd6d5122598e..a0a296a85da3d 100644 +--- a/src/google/protobuf/port_def.inc ++++ b/src/google/protobuf/port_def.inc +@@ -255,6 +255,7 @@ + #error PROTOBUF_TAILCALL was previously defined + #endif + #if __has_cpp_attribute(clang::musttail) && !defined(__arm__) && \ ++ !defined(__s390x__) && \ + !defined(_ARCH_PPC) && !defined(__wasm__) && \ + !(defined(_MSC_VER) && defined(_M_IX86)) && \ + !(defined(__NDK_MAJOR__) && __NDK_MAJOR <= 24) diff --git a/js/build_webgpu.bat b/js/build_webgpu.bat index 47478be74654b..d32a7ab1abb81 100644 --- a/js/build_webgpu.bat +++ b/js/build_webgpu.bat @@ -69,11 +69,11 @@ popd set PATH=C:\Program Files\Git\usr\bin;%PATH% call %ROOT%build.bat --config %CONFIG% %CONFIG_EXTRA_FLAG% --skip_submodule_sync --build_wasm --target onnxruntime_webassembly --skip_tests^ - --enable_wasm_simd --enable_wasm_threads --use_webnn --use_webgpu --build_dir %BUILD_DIR% + --enable_wasm_simd --enable_wasm_threads --use_webnn --use_webgpu --enable_wasm_jspi --build_dir %BUILD_DIR% IF NOT "%ERRORLEVEL%" == "0" ( exit /b %ERRORLEVEL% ) -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.asyncify.wasm %ROOT%js\web\dist\ -copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.asyncify.mjs %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jspi.wasm %ROOT%js\web\dist\ +copy /Y %BUILD_DIR%\%CONFIG%\ort-wasm-simd-threaded.jspi.mjs %ROOT%js\web\dist\ diff --git a/js/build_webgpu.sh b/js/build_webgpu.sh index 5fbcee7885e39..ea12093c37cf7 100755 --- a/js/build_webgpu.sh +++ b/js/build_webgpu.sh @@ -24,7 +24,8 @@ if [ "$1" = "d" ]; then CONFIG_EXTRA_FLAG="--enable_wasm_profiling --wasm_run_tests_in_browser --cmake_extra_defines onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL=1 --enable_wasm_debug_info" elif [ "$1" = "r" ]; then CONFIG="Release" - CONFIG_EXTRA_FLAG="--enable_wasm_api_exception_catching --disable_rtti" + # CONFIG_EXTRA_FLAG="--enable_wasm_api_exception_catching --disable_rtti" + CONFIG_EXTRA_FLAG="--disable_rtti" else echo "Error: Invalid configuration \"$1\"." echo "Configuration must be 'd' (Debug) or 'r' (Release)." @@ -99,6 +100,7 @@ echo "Calling $ROOT_DIR/build.sh to build WebAssembly..." 
--enable_wasm_threads \ --use_webnn \ --use_webgpu \ + --enable_wasm_jspi \ --build_dir "$BUILD_DIR" # The 'set -e' command at the beginning of the script ensures that the script will exit @@ -108,10 +110,10 @@ echo "--- Copying build artifacts ---" # Ensure the dist directory exists before copying files mkdir -p "$ROOT_DIR/js/web/dist" -echo "Copying ort-wasm-simd-threaded.asyncify.wasm to $ROOT_DIR/js/web/dist/" -cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.asyncify.wasm" "$ROOT_DIR/js/web/dist/" +echo "Copying ort-wasm-simd-threaded.jspi.wasm to $ROOT_DIR/js/web/dist/" +cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.jspi.wasm" "$ROOT_DIR/js/web/dist/" -echo "Copying ort-wasm-simd-threaded.asyncify.mjs to $ROOT_DIR/js/web/dist/" -cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.asyncify.mjs" "$ROOT_DIR/js/web/dist/" +echo "Copying ort-wasm-simd-threaded.jspi.mjs to $ROOT_DIR/js/web/dist/" +cp -f "$BUILD_DIR/$CONFIG/ort-wasm-simd-threaded.jspi.mjs" "$ROOT_DIR/js/web/dist/" echo "--- WebGPU build process completed successfully ---" diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 98b74a6474331..d41e0936f7fac 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -16,6 +16,7 @@ export declare namespace Env { * - `ort-wasm-simd-threaded.wasm` for default build * - `ort-wasm-simd-threaded.jsep.wasm` for JSEP build (with WebGPU and WebNN) * - `ort-wasm-simd-threaded.asyncify.wasm` for WebGPU build with Asyncify (with WebNN) + * - `ort-wasm-simd-threaded.jspi.wasm` for WebGPU build with JSPI support (with WebNN) */ wasm?: URL | string; /** @@ -27,6 +28,7 @@ export declare namespace Env { * - `ort-wasm-simd-threaded.mjs` for default build * - `ort-wasm-simd-threaded.jsep.mjs` for JSEP build (with WebGPU and WebNN) * - `ort-wasm-simd-threaded.asyncify.mjs` for WebGPU build with Asyncify (with WebNN) + * - `ort-wasm-simd-threaded.jspi.mjs` for WebGPU build with JSPI support (with WebNN) */ mjs?: URL | string; } diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index 89a2b4a6ff1be..1dde572b6fc87 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -38,6 +38,10 @@ interface BuildDefinitions { * defines whether to disable proxy feature in WebAssembly backend in the build. */ readonly DISABLE_WASM_PROXY: boolean; + /** + * defines whether to enable JSPI (JavaScript Promise Integration) support in the build. + */ + readonly ENABLE_JSPI: boolean; /** * defines whether to enable bundling the wasm JS in the build. * @@ -46,6 +50,7 @@ interface BuildDefinitions { * - `ort-wasm-simd-threaded.mjs` * - `ort-wasm-simd-threaded.jsep.mjs` * - `ort-wasm-simd-threaded.asyncify.mjs` + * - `ort-wasm-simd-threaded.jspi.mjs` * * The value is valid only when it's an ESM build. */ diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index b725f5c8e80b7..6c90d8b4b94eb 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -133,9 +133,11 @@ export const initializeWebAssemblyAndOrtRuntime = async (): Promise => { message.in!.wasm.wasmPaths = { wasm: !BUILD_DEFS.DISABLE_JSEP ? new URL('ort-wasm-simd-threaded.jsep.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href - : !BUILD_DEFS.DISABLE_WEBGPU - ? new URL('ort-wasm-simd-threaded.asyncify.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href - : new URL('ort-wasm-simd-threaded.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href, + : BUILD_DEFS.ENABLE_JSPI + ? new URL('ort-wasm-simd-threaded.jspi.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href + : !BUILD_DEFS.DISABLE_WEBGPU + ?
new URL('ort-wasm-simd-threaded.asyncify.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href + : new URL('ort-wasm-simd-threaded.wasm', BUILD_DEFS.ESM_IMPORT_META_URL).href, }; } proxyWorker.postMessage(message); diff --git a/js/web/lib/wasm/wasm-factory.ts b/js/web/lib/wasm/wasm-factory.ts index f2a28396d7486..13bb9f9fbbc93 100644 --- a/js/web/lib/wasm/wasm-factory.ts +++ b/js/web/lib/wasm/wasm-factory.ts @@ -121,6 +121,12 @@ export const initializeWebAssembly = async (flags: Env.WebAssemblyFlags): Promis throw new Error('WebAssembly SIMD is not supported in the current environment.'); } + if (BUILD_DEFS.ENABLE_JSPI) { + if (!('Suspending' in WebAssembly)) { + throw new Error('WebAssembly JSPI is not supported in the current environment.'); + } + } + // check if multi-threading is supported const multiThreadSupported = isMultiThreadSupported(); if (numThreads > 1 && !multiThreadSupported) { diff --git a/js/web/lib/wasm/wasm-utils-import.ts b/js/web/lib/wasm/wasm-utils-import.ts index fa7efa9910c59..e2e46bb37dcfc 100644 --- a/js/web/lib/wasm/wasm-utils-import.ts +++ b/js/web/lib/wasm/wasm-utils-import.ts @@ -166,7 +166,7 @@ const preload = async (absoluteUrl: string): Promise => { * @returns - A promise that resolves to the default export of the module. */ const dynamicImportDefault = async (url: string): Promise => - (await import(/* webpackIgnore: true */ url)).default; + (await import(/* webpackIgnore: true */ /* @vite-ignore */ url)).default; /** * The proxy worker factory imported from the proxy worker module. @@ -214,9 +214,11 @@ const embeddedWasmModule: EmscriptenModuleFactory | undefined = require( !BUILD_DEFS.DISABLE_JSEP ? '../../dist/ort-wasm-simd-threaded.jsep.mjs' - : !BUILD_DEFS.DISABLE_WEBGPU - ? '../../dist/ort-wasm-simd-threaded.asyncify.mjs' - : '../../dist/ort-wasm-simd-threaded.mjs', + : BUILD_DEFS.ENABLE_JSPI + ? '../../dist/ort-wasm-simd-threaded.jspi.mjs' + : !BUILD_DEFS.DISABLE_WEBGPU + ? '../../dist/ort-wasm-simd-threaded.asyncify.mjs' + : '../../dist/ort-wasm-simd-threaded.mjs', ).default : undefined; @@ -278,9 +280,11 @@ export const importWasmModule = async ( } else { const wasmModuleFilename = !BUILD_DEFS.DISABLE_JSEP ? 'ort-wasm-simd-threaded.jsep.mjs' - : !BUILD_DEFS.DISABLE_WEBGPU - ? 'ort-wasm-simd-threaded.asyncify.mjs' - : 'ort-wasm-simd-threaded.mjs'; + : BUILD_DEFS.ENABLE_JSPI + ? 'ort-wasm-simd-threaded.jspi.mjs' + : !BUILD_DEFS.DISABLE_WEBGPU + ? 'ort-wasm-simd-threaded.asyncify.mjs' + : 'ort-wasm-simd-threaded.mjs'; const wasmModuleUrl = urlOverride ?? normalizeUrl(wasmModuleFilename, prefixOverride); // need to preload if all of the following conditions are met: // 1. not in Node.js. diff --git a/js/web/package.json b/js/web/package.json index ecd87fab4302b..d86bd9b383dd8 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -107,6 +107,14 @@ "default": "./dist/ort.webgpu.bundle.min.mjs" }, "require": "./dist/ort.webgpu.min.js" + }, + "./jspi": { + "types": "./types.d.ts", + "import": { + "onnxruntime-web-use-extern-wasm": "./dist/ort.jspi.min.mjs", + "default": "./dist/ort.jspi.bundle.min.mjs" + }, + "require": "./dist/ort.jspi.min.js" } }, "types": "./types.d.ts", diff --git a/js/web/script/build.ts b/js/web/script/build.ts index 22f10b0b90a8f..620fddd8323d0 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -14,6 +14,7 @@ console.time('BUILD'); */ const args = minimist(process.argv.slice(2)); + /** * --bundle-mode=prod (default) * Build multiple ort-web bundles for production. 
@@ -28,7 +29,7 @@ const args = minimist(process.argv.slice(2)); * Build a single ort-web bundle for nodejs. */ const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = - process.env.npm_config_bundle_mode || args['bundle-mode'] || 'prod'; + args['bundle-mode'] || process.env.npm_config_bundle_mode || 'prod'; /** * --debug @@ -42,7 +43,7 @@ const BUNDLE_MODE: 'prod' | 'dev' | 'perf' | 'node' = * Enable debug mode. In this mode, esbuild metafile feature will be enabled. Full bundle analysis will be saved to a * file as JSON. */ -const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|'save' +const DEBUG = args.debug || process.env.npm_config_debug; // boolean|'verbose'|'save' /** * --webgpu-ep @@ -53,7 +54,17 @@ const DEBUG = process.env.npm_config_debug || args.debug; // boolean|'verbose'|' * * (temporary) This flag is used to test the WebGPU EP integration. It will be removed in the future. */ -const USE_WEBGPU_EP = process.env.npm_config_webgpu_ep ?? args['webgpu-ep'] ?? false; +const USE_WEBGPU_EP = args['webgpu-ep'] ?? process.env.npm_config_webgpu_ep ?? false; + +/** + * --jspi + * --no-jspi (default) + * + * Enable or disable the use of JSPI. If enabled, JSPI will be used instead of ASYNCIFY. + * + * (temporary) This flag is used to test the JSPI integration. It will be removed in the future. + */ +const USE_JSPI = args.jspi ?? process.env.npm_config_jspi ?? false; /** * Root folder of the source code: `/js/` @@ -68,6 +79,7 @@ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_JSEP': JSON.stringify(!!USE_WEBGPU_EP), 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', + 'BUILD_DEFS.ENABLE_JSPI': JSON.stringify(!!USE_JSPI), 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': JSON.stringify(!USE_WEBGPU_EP), 'BUILD_DEFS.DISABLE_WEBNN': 'false', @@ -267,7 +279,7 @@ async function buildBundle(options: esbuild.BuildOptions) { * * The distribution code is split into multiple files: * - [output-name][.min].[m]js - * - ort-wasm-simd-threaded[.jsep|.asyncify].mjs + * - ort-wasm-simd-threaded[.jsep|.asyncify|.jspi].mjs */ async function buildOrt({ isProduction = false, @@ -359,7 +371,7 @@ async function buildTest() { * ``` * to: * ``` - * ... await import(/* webpackIgnore: true *\/... + * ... await import(/* webpackIgnore: true *\/ /* @vite-ignore *\/... * ``` * * Why we need this? @@ -375,15 +387,18 @@ async function buildTest() { * - There are multiple entry points that use dynamic import to load the ort-*.mjs and ort-*.wasm. If the content of the * dynamic import is resolved by Webpack, it will be duplicated in the final bundle. This will increase the bundle size. * + * Additionally, Vite is unable to analyze the dynamic import calls, which triggers a warning. These dynamic imports are + * intentional, so the warning should be ignored. Aside from suppressing the warning, this does not change any behavior. + * * What about other bundlers? 
* * TBD * */ async function postProcess() { - const IMPORT_MAGIC_COMMENT = '/*webpackIgnore:true*/'; + const IMPORT_MAGIC_COMMENTS = ['/*webpackIgnore:true*/', '/*@vite-ignore*/'].join(' '); const IMPORT_ORIGINAL = 'await import('; - const IMPORT_NEW = `await import(${IMPORT_MAGIC_COMMENT}`; + const IMPORT_NEW = `await import(${IMPORT_MAGIC_COMMENTS}`; const files = await fs.readdir(path.join(SOURCE_ROOT_FOLDER, 'web/dist')); for (const file of files) { @@ -437,7 +452,7 @@ async function postProcess() { consumer.eachMapping((mapping) => { if (mapping.generatedLine === line && mapping.generatedColumn >= column) { - mapping.generatedColumn += IMPORT_MAGIC_COMMENT.length; + mapping.generatedColumn += IMPORT_MAGIC_COMMENTS.length; } updatedSourceMap.addMapping({ @@ -472,9 +487,9 @@ async function postProcess() { await fs.writeFile(jsFilePath, jsFileLines.join('\n')); const newJsFileSize = (await fs.stat(jsFilePath)).size; - if (newJsFileSize - originalJsFileSize !== IMPORT_MAGIC_COMMENT.length) { + if (newJsFileSize - originalJsFileSize !== IMPORT_MAGIC_COMMENTS.length) { throw new Error( - `Failed to insert magic comment to file "${file}". Original size: ${ + `Failed to insert magic comments to file "${file}". Original size: ${ originalJsFileSize }, New size: ${newJsFileSize}`, ); @@ -485,32 +500,45 @@ async function postProcess() { async function validate() { const files = await fs.readdir(path.join(SOURCE_ROOT_FOLDER, 'web/dist')); - for (const file of files) { - // validate on all "ort.*.min.js" and "ort.*.min.mjs" files. - if ((file.endsWith('.js') || file.endsWith('.mjs')) && file.startsWith('ort.')) { - const isMinified = file.endsWith('.min.js') || file.endsWith('.min.mjs'); - const content = await fs.readFile(path.join(SOURCE_ROOT_FOLDER, 'web/dist', file), 'utf-8'); - - if (isMinified) { - // all files should not contain BUILD_DEFS definition. BUILD_DEFS should be defined in the build script only. - // - // If the final bundle contains BUILD_DEFS definition, it means the build script is not working correctly. In - // this case, we should fix the build script (this file). - // - if (content.includes('BUILD_DEFS')) { - throw new Error(`Validation failed: "${file}" contains BUILD_DEFS definition.`); - } - } + // validate on all "ort.*.min.js" and "ort.*.min.mjs" files. + const validateFiles = files.filter( + (file) => (file.endsWith('.js') || file.endsWith('.mjs')) && file.startsWith('ort.'), + ); + + for (const file of validateFiles) { + const isMinified = file.endsWith('.min.js') || file.endsWith('.min.mjs'); + const content = await fs.readFile(path.join(SOURCE_ROOT_FOLDER, 'web/dist', file), 'utf-8'); - // all files should contain the magic comment to ignore dynamic import calls. + if (isMinified) { + // all files should not contain BUILD_DEFS definition. BUILD_DEFS should be defined in the build script only. // - if (!file.includes('.webgl.') && !file.includes('.bundle.')) { - const contentToSearch = isMinified ? '/*webpackIgnore:true*/' : '/* webpackIgnore: true */'; - if (!content.includes(contentToSearch)) { - throw new Error(`Validation failed: "${file}" does not contain magic comment.`); - } + // If the final bundle contains BUILD_DEFS definition, it means the build script is not working correctly. In + // this case, we should fix the build script (this file). 
+ // + if (content.includes('BUILD_DEFS')) { + throw new Error(`Validation failed: "${file}" contains BUILD_DEFS definition.`); } } + + if (file.includes('.webgl.') || file.includes('.bundle.')) { + // no further validation required. + // + continue; + } + + // all files should contain the webpack magic comment to ignore dynamic import calls. + // + const webpackContentToSearch = isMinified ? '/*webpackIgnore:true*/' : '/* webpackIgnore: true */'; + if (!content.includes(webpackContentToSearch)) { + throw new Error(`Validation failed: "${file}" does not contain webpack magic comment.`); + } + + // all files should contain the vite magic comment to suppress dynamic import call warnings. + // + const viteContentToSearch = isMinified ? '/*@vite-ignore*/' : '/* @vite-ignore */'; + if (!content.includes(viteContentToSearch)) { + throw new Error(`Validation failed: "${file}" does not contain vite magic comment.`); + } } } @@ -645,6 +673,32 @@ async function main() { }, }); + // ort.jspi[.min].[m]js + await addAllWebBuildTasks({ + outputName: 'ort.jspi', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', + 'BUILD_DEFS.ENABLE_JSPI': 'true', + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + }, + }); + // ort.jspi.bundle.min.mjs + await buildOrt({ + isProduction: true, + outputName: 'ort.jspi.bundle', + format: 'esm', + define: { + ...DEFAULT_DEFINE, + 'BUILD_DEFS.DISABLE_WEBGPU': 'false', + 'BUILD_DEFS.ENABLE_JSPI': 'true', + 'BUILD_DEFS.DISABLE_JSEP': 'true', + 'BUILD_DEFS.DISABLE_WEBGL': 'true', + 'BUILD_DEFS.ENABLE_BUNDLE_WASM_JS': 'true', + }, + }); + // ort.wasm[.min].[m]js await addAllWebBuildTasks({ outputName: 'ort.wasm', diff --git a/js/web/script/pull-prebuilt-wasm-artifacts.ts b/js/web/script/pull-prebuilt-wasm-artifacts.ts index 87008f51ff4b9..6dce99147abab 100644 --- a/js/web/script/pull-prebuilt-wasm-artifacts.ts +++ b/js/web/script/pull-prebuilt-wasm-artifacts.ts @@ -152,10 +152,12 @@ async function downloadArtifactsForRun(run: any): Promise { fs.readdirSync(WASM_FOLDER).forEach((file) => { if ( [ + 'ort-wasm-simd-threaded.asyncify.mjs', + 'ort-wasm-simd-threaded.asyncify.wasm', 'ort-wasm-simd-threaded.jsep.mjs', 'ort-wasm-simd-threaded.jsep.wasm', - 'ort-wasm-simd-threaded.jsep.mjs', - 'ort-wasm-simd-threaded.jsep.wasm', + 'ort-wasm-simd-threaded.jspi.mjs', + 'ort-wasm-simd-threaded.jspi.wasm', 'ort-wasm-simd-threaded.mjs', 'ort-wasm-simd-threaded.wasm', ].includes(file) diff --git a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json index e880f6bca2ac4..62b4df5806eda 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package-lock.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package-lock.json @@ -12,7 +12,7 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.3.6" + "vite": "^6.4.1" } }, "node_modules/@babel/helper-string-parser": { @@ -1114,9 +1114,9 @@ } }, "node_modules/vite": { - "version": "6.3.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-6.3.6.tgz", - "integrity": "sha512-0msEVHJEScQbhkbVTb/4iHZdJ6SXp/AvxL2sjwYQFfBqleHtnCqv1J3sa9zbWz/6kW1m9Tfzn92vW+kZ1WV6QA==", + "version": "6.4.1", + "resolved": "https://registry.npmjs.org/vite/-/vite-6.4.1.tgz", + "integrity": "sha512-+Oxm7q9hDoLMyJOYfUYBuHQo+dkAloi33apOPP56pzj+vsdJDzr+j1NISE5pyaAuKL4A3UD34qd0lx5+kfKp2g==", "dev": true, "license": "MIT", "dependencies": { diff --git a/js/web/test/e2e/exports/testcases/vite-default/package.json 
b/js/web/test/e2e/exports/testcases/vite-default/package.json index 84013e2aecb88..3b81cef61c31f 100644 --- a/js/web/test/e2e/exports/testcases/vite-default/package.json +++ b/js/web/test/e2e/exports/testcases/vite-default/package.json @@ -13,6 +13,6 @@ }, "devDependencies": { "@vitejs/plugin-vue": "^5.2.1", - "vite": "^6.3.6" + "vite": "^6.4.1" } } diff --git a/js/web/types.d.ts b/js/web/types.d.ts index b82248c0c83b8..354ce7e927d20 100644 --- a/js/web/types.d.ts +++ b/js/web/types.d.ts @@ -20,3 +20,7 @@ declare module 'onnxruntime-web/webgl' { declare module 'onnxruntime-web/webgpu' { export * from 'onnxruntime-web'; } + +declare module 'onnxruntime-web/jspi' { + export * from 'onnxruntime-web'; +} diff --git a/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h index 3d21b12f9b55c..6459c20790d5b 100644 --- a/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h +++ b/onnxruntime/contrib_ops/cuda/math/cufft_plan_cache.h @@ -64,6 +64,9 @@ struct ParamsEqual { class CuFFTPlanCache { public: + ~CuFFTPlanCache() { + Clear(); + } CufftPlanInfo TryEmplaceValue(FFTState& key) { std::lock_guard lock(mutex); @@ -81,6 +84,18 @@ class CuFFTPlanCache { std::mutex mutex; + void Clear() { + std::lock_guard lk(mutex); + for (auto& kv : map) { + auto& info = kv.second; + if (info.plan != 0) { + cufftDestroy(info.plan); + info.plan = 0; + } + } + map.clear(); + } + private: CufftPlanInfo CreatePlanInfo(FFTState& key) { cufftHandle plan; diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc index a9bd4afc5cd09..1e69928f2a7ce 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.cc @@ -48,6 +48,7 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { } else { shader.MainFunctionBody() << " let total_seq_length = uniforms.total_sequence_length;\n"; } + shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; // Add indirect dispatch logic for thread 0 if (prepare_indirect_dispatch_) { @@ -62,31 +63,23 @@ Status CopyKVCacheProgram::GenerateShaderCode(ShaderHelper& shader) const { } if (has_past_) { - shader.MainFunctionBody() << "let past_sequence_length = total_seq_length - uniforms.kv_sequence_length;\n"; - if (past_present_share_buffer_) { - shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n" - << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? 
"key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" - << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" - << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"; - } else { - const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); - shader.AddInput("past_value", ShaderUsage::UseUniform); - shader.MainFunctionBody() << "let present_offset = global_idx;" - << "if (sequence_id < past_sequence_length) {\n" - << " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n" - << " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n" - << " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n" - << "} else {\n" - << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id - past_sequence_length, head_size_id)" : "key_indices_t(batch, sequence_id - past_sequence_length, num_head_id, head_size_id)") << ";\n" - << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" - << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n" - << "}"; - } + const auto& past_key = shader.AddInput("past_key", ShaderUsage::UseUniform | ShaderUsage::UseValueTypeAlias | ShaderUsage::UseElementTypeAlias | ShaderUsage::UseIndicesTypeAlias); + shader.AddInput("past_value", ShaderUsage::UseUniform); + shader.MainFunctionBody() << "let present_offset = global_idx;" + << "if (sequence_id < past_sequence_length) {\n" + << " let pastOffset = " << past_key.IndicesToOffset("past_key_indices_t(batch, num_head_id, sequence_id, head_size_id)") << ";\n" + << " " << present_key.SetByOffset("present_offset", "past_key[pastOffset]") << ";\n" + << " " << present_value.SetByOffset("present_offset", "past_value[pastOffset]") << ";\n" + << "} else {\n" + << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id - past_sequence_length, head_size_id)" : "key_indices_t(batch, sequence_id - past_sequence_length, num_head_id, head_size_id)") << ";\n" + << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" + << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n" + << "}"; } else { - shader.MainFunctionBody() << " let present_offset = " << (past_present_share_buffer_ ? present_key.IndicesToOffset("output_indices") : "global_idx") << ";\n" - << "let offset = " << key.IndicesToOffset(kv_BNSH_ ? "key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" - << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" - << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"; + shader.MainFunctionBody() << " let present_offset = " << present_key.IndicesToOffset("present_key_indices_t(batch, num_head_id, past_sequence_length + sequence_id, head_size_id)") << ";\n" + << " let offset = " << key.IndicesToOffset(kv_BNSH_ ? 
"key_indices_t(batch, num_head_id, sequence_id, head_size_id)" : "key_indices_t(batch, sequence_id, num_head_id, head_size_id)") << ";\n" + << " " << present_key.SetByOffset("present_offset", "key[offset]") << ";\n" + << " " << present_value.SetByOffset("present_offset", "value[offset]") << ";\n"; } return Status::OK(); } @@ -100,19 +93,20 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt // number of input buffers in the shader, which we run out of (<=8) without this optimization. // If indirect_buffer is provided, also prepare indirect dispatch buffer for flash attention. const int components = parameters.head_size_ % 4 == 0 ? 4 : (parameters.head_size_ % 2 == 0 ? 2 : 1); - bool has_past = (parameters.total_sequence_length_ - parameters.kv_sequence_length_) > 0; + // has_past means non-static kv cache with valid past data + bool has_past = !parameters.past_present_share_buffer_ && past_key != nullptr && past_value != nullptr && past_key->SizeInBytes() > 0; // parameters.total_sequence_length_ is past_sequence_length + kv_sequence_length. // parameters.kv_num_heads_ may be smaller than parameters.num_heads_ when parameters.is_gqa_ is true. int num_heads = parameters.is_gqa_ ? parameters.kv_num_heads_ : parameters.num_heads_; // Only copy the new kv data for static kv cache - int copy_sequence_length = has_past && parameters.past_present_share_buffer_ ? parameters.kv_sequence_length_ : parameters.total_sequence_length_; + int copy_sequence_length = parameters.past_present_share_buffer_ ? parameters.kv_sequence_length_ : parameters.total_sequence_length_; TensorShape copy_kv_shape{parameters.batch_size_, num_heads, copy_sequence_length, parameters.head_size_ / components}; int64_t copy_size = copy_kv_shape.Size(); // Determine if we need to prepare indirect dispatch bool prepare_indirect_dispatch = (indirect_buffer != nullptr); - CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, parameters.past_present_share_buffer_, + CopyKVCacheProgram program{"CopyKVCache", has_past, parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH, prepare_indirect_dispatch}; if (parameters.qkv_format_ == Q_K_V_BSNH_BNSH_BNSH) { program.AddInputs({{K, ProgramTensorMetadataDependency::TypeAndRank, components}, @@ -129,7 +123,7 @@ Status CopyKVCache(onnxruntime::webgpu::ComputeContext& context, const WebgpuAtt program.AddInput({seqlen_k, ProgramTensorMetadataDependency::None}); } - if (has_past && !parameters.past_present_share_buffer_) { + if (has_past) { program.AddInputs({{past_key, ProgramTensorMetadataDependency::TypeAndRank, components}, {past_value, ProgramTensorMetadataDependency::TypeAndRank, components}}); } diff --git a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h index 7d71dc0f4d42d..f372aeed0e563 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h +++ b/onnxruntime/contrib_ops/webgpu/bert/flash_attention.h @@ -17,9 +17,9 @@ using namespace onnxruntime::webgpu; class CopyKVCacheProgram final : public Program { public: - CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool past_present_share_buffer, + CopyKVCacheProgram(const std::string& kernel_name, bool has_past, bool kv_BNSH, bool prepare_indirect_dispatch = false) - : Program{kernel_name}, has_past_(has_past), kv_BNSH_(kv_BNSH), past_present_share_buffer_(past_present_share_buffer), prepare_indirect_dispatch_(prepare_indirect_dispatch) { + : Program{kernel_name}, 
has_past_(has_past), kv_BNSH_(kv_BNSH), prepare_indirect_dispatch_(prepare_indirect_dispatch) { } Status GenerateShaderCode(ShaderHelper& sh) const override; @@ -33,7 +33,6 @@ class CopyKVCacheProgram final : public Program { private: bool has_past_; bool kv_BNSH_; - bool past_present_share_buffer_; bool prepare_indirect_dispatch_; }; diff --git a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc index 49cc0209785c5..ce9fc988ac351 100644 --- a/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/webgpu/bert/group_query_attention.cc @@ -287,7 +287,7 @@ Status GroupQueryAttention::ComputeInternal(onnxruntime::webgpu::ComputeContext& !use_sliding_window && CanApplyFlashAttention(attention_bias, present_key, present_value, parameters, context)) { return ApplyFlashAttention(query, key, value, attention_bias, output, past_key, present_key, past_value, - present_value, parameters, context); + present_value, parameters, context, seqlen_k); } TensorShapeVector q_new_dims({parameters.batch_size_, parameters.num_heads_, diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index d9a9a80175a6b..a44c5e688ceab 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -1234,6 +1234,8 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; + // // Rotary embedding dispatch structure. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index e0bc72a408c3f..796bfd13b47bf 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -742,6 +742,9 @@ Return Value: this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + // add new sqn-lasx kernel + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchLasx; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; }else if( cap_lsx ){ @@ -824,4 +827,4 @@ thread_local size_t ThreadedBufSize = 0; thread_local std::unique_ptr ThreadedBufHolder(nullptr, &_aligned_free); #else thread_local std::unique_ptr ThreadedBufHolder(nullptr, &free); -#endif \ No newline at end of file +#endif diff --git a/onnxruntime/core/mlas/lib/s390x/Quantize.cpp b/onnxruntime/core/mlas/lib/s390x/Quantize.cpp index 6bb4475fc0ef1..7842cd32b73e5 100644 --- a/onnxruntime/core/mlas/lib/s390x/Quantize.cpp +++ b/onnxruntime/core/mlas/lib/s390x/Quantize.cpp @@ -198,6 +198,9 @@ Return Value: auto CharVector = vec_pack(ShortVector0, ShortVector1); vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); + // Workaround for bad GCC warning that variable is set but not used. 
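// MLAS_UNREFERENCED_PARAMETER is assumed here to be the usual void-cast idiom, roughly ((void)(CharVector)):
// it marks the value as consumed without generating any code, which is enough to quiet GCC's
// -Wunused-but-set-variable report for the vector produced by vec_pack above.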
+ MLAS_UNREFERENCED_PARAMETER(CharVector); + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp new file mode 100644 index 0000000000000..04c6540d1783b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx.cpp @@ -0,0 +1,1089 @@ +/*++ + +Module Name: + + sqnbitgemm_kernel_lasx.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for loongarch64. Accelerate inference + optimization using lasx/lsx vector instruction sets. + +--*/ + +#include +#include + +#include +#include +#include +#include +#include "core/common/safeint.h" + +#include "qnbitgemm.h" +#include "sqnbitgemm_kernel_lasx_common.h" + +// 1. qnbitgemm.h->Q4BitGemmPackQuantBDataSize +template +static size_t +QNBitGemmPackQuantBDataSize_Lasx( + size_t N, + size_t K, + size_t BlkLen, + bool /* HasZeroPoint */, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ComputeType == SQNBIT_CompInt8) { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt ScaleSize = SafeInt(N) * BlockCountK * sizeof(float); + SafeInt BlkSumSize = SafeInt(BlockCountK) * MlasDivRoundup(N, 16) * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += SafeInt(PackedQuantBDataAlignment) - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += SafeInt(BlkSumAlignment) - 1; + + PackedQuantBDataSize += ScaleSize + BlkSumSize; + return PackedQuantBDataSize.Value(); + } else { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize.Value(); + } +} + +// 2. qnbitgemm.h->SQ4BitGemmPackQuantBData +static void +SQ4BitGemmPackQuantBData_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt Iterations = SafeInt(N) * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations.Value(), + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset.Value(); + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value(); + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +// 3. qnbitgemm.h->SQ4BitGemmPackQuantBDataAndBlkSum +static void +SQ4BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + SafeInt offset = SafeInt(N) * BlockCountK; + std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum_Lasx( + BlkLen, SubBlkLen, N, + packed_quant_b.PackedQuantBScale, + QuantBZPBegin, + packed_quant_b.QuantBBlkSum, + ThreadPool, + BlockCountK + ); + } +} + +// 3. qnbitgemm.h->SQ8BitGemmPackQuantBDataAndBlkSum +static void +SQ8BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + Q8PackQuantBDataAndBlkSum_lasx(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +MLAS_FORCEINLINE +__m256 +load_float_n_lasx(const float* data, int n) +{ + if (n <= 0) { + alignas(32) float zero_array[8] = {0}; + return (__m256)__lasx_xvld((void*)&zero_array, 0); + } + alignas(32) float buf[8] = {0}; + if (n > 0 && n <= 8) { + for (int i = 0; i < n; ++i) { + buf[i] = data[i]; + } + } + return (__m256)__lasx_xvld((void*)&buf, 0); +} + +// ComputeDotProducts_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen32 = 32; + constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); + static_assert(SubBlkStep16 == 16); + + __m256 acc[NCols]; + + alignas(32) static const float zero_array[8] = {0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] int count_half_4 = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + // TODO: Improve Memory Access Performance with Prefetching Matrix Operations + // alignas(32) float a_buf[2][32] = {0.0}; + //__m256 a_buf[8]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset = SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + // not ready for "Manual conversion to float" in neon yet. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { + size_t kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); + + __m256 av0_8_ps = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av1_8_ps = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + __m256 av2_8_ps = load_float_n_lasx(ARowPtr + k + kk + 16, std::min(kklen > 16 ? kklen - 16 : 0, 8)); + __m256 av3_8_ps = load_float_n_lasx(ARowPtr + k + kk + 24, std::min(kklen > 24 ? 
kklen - 24 : 0, 8)); + + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); + } + + UnrolledLoop([&](size_t i) { + __m256i bv_0_32; + + if constexpr (IsBlkLen64Layout) { + __m256i bv_32_4bit_tmp = __lasx_xvld(b_blk_data_col_ptr[i], 0); + if (!count_half_4) + bv_0_32 = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + bv_0_32 = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; + } else { + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_blk_data_col_ptr[i], 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + bv_0_32 = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + b_blk_data_col_ptr[i] += SubBlkStep16; + } + + __m256i zp = HasZeroPoint ? __lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + bv_0_32 = __lasx_xvsub_b(bv_0_32, zp); + + // (1)8bit -> 16bit + __m256i bv_0_15 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0x72)); + __m256i bv_16_31 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0xD8)); + + // (2)16bit -> int32 + __m256i bv_0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0x72)); + __m256i bv_8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0xD8)); + __m256i bv_16_23 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0x72)); + __m256i bv_24_31 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0xD8)); + + // (3)int32 -> fp32 + __m256 fbv_0_7 = __lasx_xvffint_s_w(bv_0_7); + __m256 fbv_8_15 = __lasx_xvffint_s_w(bv_8_15); + __m256 fbv_16_23 = __lasx_xvffint_s_w(bv_16_23); + __m256 fbv_24_31 = __lasx_xvffint_s_w(bv_24_31); + + __m256 scale_ps = (__m256)__lasx_xvldrepl_w(&scale_v[i], 0); + + fbv_0_7 = __lasx_xvfmul_s(fbv_0_7, scale_ps); + fbv_8_15 = __lasx_xvfmul_s(fbv_8_15, scale_ps); + fbv_16_23 = __lasx_xvfmul_s(fbv_16_23, scale_ps); + fbv_24_31 = __lasx_xvfmul_s(fbv_24_31, scale_ps); + + acc[i] = __lasx_xvfmadd_s(fbv_0_7, av0_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_8_15, av1_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_16_23, av2_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_24_31, av3_8_ps, acc[i]); + }); + } + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = hsum_float_8_lasx(acc[i]); + float bias_tmp = bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// ComputeDotProducts_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen16_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen16 = 16; + constexpr size_t SubBlkStep8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); + static_assert(SubBlkStep8 == 8); + + __m256 acc[NCols]; + alignas(32) int zero_array[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset = SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { + size_t kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); + + __m256 av_lo = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av_hi = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + + UnrolledLoop([&](size_t i) { + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_blk_data_col_ptr[i], 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + + /* + low->high + | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + */ + + b_blk_data_col_ptr[i] += SubBlkStep8; + __m256i lower = __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + __m256i zp = HasZeroPoint ? 
__lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zp); + __m256i bv0_15 = __lasx_xvexth_h_b(packb); + + __m256i bv0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0x72)); + __m256i bv8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0xD8)); + + __m256 fbv0_7 = __lasx_xvffint_s_w(bv0_7); + __m256 fbv8_15 = __lasx_xvffint_s_w(bv8_15); + __m256 scale = (__m256)__lasx_xvldrepl_w((void*)&scale_v[i], 0); + fbv0_7 = __lasx_xvfmul_s(fbv0_7, scale); + fbv8_15 = __lasx_xvfmul_s(fbv8_15, scale); + + acc[i] = __lasx_xvfmadd_s(av_lo, fbv0_7, acc[i]); + acc[i] = __lasx_xvfmadd_s(av_hi, fbv8_15, acc[i]); + }); + } + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = 0.0f; + alignas(32) float acc_buf[8]; + __lasx_xvst(acc[i], (void*)&acc_buf, 0); + UnrolledLoop<8>([&](size_t j) { sum += acc_buf[j]; }); + float bias_tmp = bias_ptr == nullptr ? 0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = (CountN) - NCols4; + while (nblk >= 0) { + ComputeDotProducts_BlkLen16_CompFp32_lasx( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkLen16_CompFp32_lasx<1, HasZeroPoint>( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +// SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than NCols + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, true>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, false>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_Lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + + constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. 
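   As an illustration only (the helper and names below are hypothetical, not part of this change), the
   follow-up could hoist the flag into a template parameter and branch once at the call site, the same
   way SQ4BitGemmM1Kernel_CompFp32_Lasx branches on QuantBZeroPoint before entering its kernels:

       #include <cstddef>

       template <bool HasZeroPoint>
       static void DequantizeBlock(const unsigned char* zp, float* out, std::size_t n) {
           for (std::size_t i = 0; i < n; ++i) {
               int zero_point = 8;                                        // default when no zero points are stored
               if constexpr (HasZeroPoint) {
                   zero_point = (zp[i / 2] >> ((i & 1) ? 4 : 0)) & 0x0F;  // unpack the packed 4-bit zero point
               }
               out[i] -= static_cast<float>(zero_point);                  // branch is resolved at compile time
           }
       }

       static void DequantizeBlockDispatch(const unsigned char* zp, float* out, std::size_t n) {
           if (zp != nullptr) {
               DequantizeBlock<true>(zp, out, n);   // one runtime branch, outside the hot loop
           } else {
               DequantizeBlock<false>(zp, out, n);
           }
       }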
+ */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + for (size_t col = 0; col < CountN; col += NCols8) { + const int cols = std::min((int)NCols8, (int)CountN - (int)col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + __m256i weight_16_epi16[NCols8]; + __m256 scale_8_ps[NCols8]; + UnrolledLoop([&](size_t col_) { + if ((int)col_ < cols) { + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_data_ptr + col_ * b_data_col_stride_in_bytes, 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + // low->high + // | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + // | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + + __m256i lower = __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + packb = __lasx_xvsub_b(packb, zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zero_point); + } + weight_16_epi16[col_] = __lasx_xvexth_h_b(packb); + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_16_epi16[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if ((int)col_ < cols) { + if (i_of_2 == 0) { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } else { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr 
+ ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + } + } +} + +template +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t GemmFloatKernelWidth16 = 16; + constexpr size_t SubblkLen32 = 32; + + const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. 
+ */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] int count_half_4 = 0; + + for (size_t col = 0; col < CountN; col += NCols8) { + // TODO: handle last tile with cols < NCols8 + const size_t cols = std::min(NCols8, CountN - col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { + __m256i weight_32_epi8[NCols8]; + __m256 scale_8_ps[NCols8]; + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (subblk % 2); + } + UnrolledLoop([&](size_t col_) { + // 1. load 32 4-bit data + if (col_ < cols) { + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 + // so that all v0-64 of the pack are dequantized. + __m256i bv_32_4bit_tmp = __lasx_xvld(b_data_ptr + col_ * b_data_col_stride_in_bytes, 0); + if (!count_half_4) + weight_32_epi8[col_] = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + weight_32_epi8[col_] = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + } else { + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_data_ptr + col_ * b_data_col_stride_in_bytes, 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + weight_32_epi8[col_] = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + } + + // 2. load zeropoint and scale + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } + + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_32_epi8[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if (col_ < cols) { + if (i_of_4 == 0) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xE1)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 1) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 2) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xD8)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 3) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 
= (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; + if constexpr (IsBlkLen64Layout) { + b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; + } else { + b_data_ptr += subblk_data_size_in_bytes; + } + } // subblk + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemm_CompFp32_Lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +) +{ + if (BlkLen == 16) { + Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else if (BlkLen == 32) { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } +} + +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx = []() { + MLAS_QNBIT_GEMM_DISPATCH d; + + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize_Lasx<4>; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData_Lasx; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum_Lasx; + d.SQ8BitGemmPackQuantBDataAndBlkSum = 
SQ8BitGemmPackQuantBDataAndBlkSum_Lasx; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_Lasx; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_Lasx; + + return d; +}(); diff --git a/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h new file mode 100644 index 0000000000000..508bcba8a2de7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/sqnbitgemm_kernel_lasx_common.h @@ -0,0 +1,514 @@ +/*++ + Abstract: + + Lasx/Lsx tool function, Auxiliary functions for inference required by + 4-bit/8-bit quantization models. +--*/ +#pragma once +#include "qnbitgemm.h" +#include "core/common/safeint.h" +#include + +template +struct MlasAlignedAllocator { + using value_type = T; + + MlasAlignedAllocator() = default; + + template + MlasAlignedAllocator(const MlasAlignedAllocator&) {} + + T* allocate(size_t n) { + // If RequiredAlignment > 0, use the required value + // Otherwise, use the value of MlasGetPreferredBufferAlignment() + size_t alignment = RequiredAlignment > 0 ? + RequiredAlignment : + MlasGetPreferredBufferAlignment(); + + size_t size = n * sizeof(T); + if (size % alignment != 0) // check the size + size = ((size + alignment - 1) / alignment) * alignment; + #if defined(_MSC_VER) + void* ptr = _aligned_malloc(size, alignment); + #else + void* ptr = aligned_alloc(alignment, size); + #endif + if (!ptr) throw std::bad_alloc(); + return static_cast(ptr); + } + + void deallocate(T* ptr, size_t) { + #if defined(_MSC_VER) + _aligned_free(ptr); + #else + free(ptr); + #endif + } + + template + struct rebind { + using other = MlasAlignedAllocator; + }; +}; + +static MLAS_FORCEINLINE __m256 +__lasx_xvzero() +{ + return (__m256)__lasx_xvldi(0); +} + +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + SafeInt scale_dst_offset = SafeInt(T) * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += SafeInt(t) * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += SafeInt(k_sub_or_blk) * 4 + t; + } + return scale_dst_offset.Value(); +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + SafeInt scale_dst_offset = SafeInt(T) * 4 * BlockCountK; + if (te) { + scale_dst_offset += SafeInt(t) * BlockCountK + k_blk; + } else { + scale_dst_offset += SafeInt(k_subblk) * blks_per_sub * 4; + if (be) { + scale_dst_offset += SafeInt(b) * 4 + t; + } else { + scale_dst_offset += SafeInt(t) * blks_per_sub + b; + } + } + return scale_dst_offset.Value(); +} + +static void +ComputePackBlkSum_Lasx( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + float QuantBScale = QuantBScaleBegin[src_blk_offset.Value()]; + uint8_t zp = 8; + + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + SafeInt src_zp_offset = SafeInt(ZPCountK) * n + k_blk / 2; + 
bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset.Value(); + const std::byte low_mask{0X0f}; + zp = (uint8_t)(low_zp ? ((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + float result = -QuantBScale * zp; + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + BlockSumBegin[dst_offset.Value()] = result; + + if (BlkLen == 16) { + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } + }); +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const SafeInt src_data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset.Value(); + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size + ) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0f}) | ((src1 & std::byte{0x0f}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. 
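/*
   For reference, a scalar model of the nibble interleave that pack_subblk above performs
   (illustrative only; the fixed 16-value block corresponds to the SubBlkLen == 16 case):

       #include <cstdint>
       #include <cstdio>

       int main() {
           // src byte i holds v(2i) in its low nibble and v(2i+1) in its high nibble:
           //   | v0 v1 | v2 v3 | ... | vE vF |
           uint8_t src[8], dst[8];
           for (int i = 0; i < 8; ++i) {
               src[i] = static_cast<uint8_t>((2 * i) | ((2 * i + 1) << 4));
           }
           // pair byte p with byte p + 4 (= pack_data_size / 2), exactly as pack_subblk does,
           // producing | v0 v8 | v1 v9 | ... | v7 vF |
           for (int p = 0; p < 4; ++p) {
               uint8_t lo = src[p];
               uint8_t hi = src[p + 4];
               dst[2 * p]     = static_cast<uint8_t>((lo & 0x0F) | ((hi & 0x0F) << 4));
               dst[2 * p + 1] = static_cast<uint8_t>((lo >> 4) | ((hi >> 4) << 4));
           }
           for (int i = 0; i < 8; ++i) {
               std::printf("dst[%d] holds v%d (low) and v%d (high)\n", i, dst[i] & 0x0F, dst[i] >> 4);
           }
           return 0;
       }
*/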
+ PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const SafeInt k_blk = SafeInt(k_subblk) * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const SafeInt k_blk = SafeInt(k_subblk) * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +static MLAS_FORCEINLINE __m128 +FoldAccumulators_Lasx(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) +{ + /* + acc0 = [A0, A1, A2, A3, A4, A5, A6, A7] + acc1 = [B0, B1, B2, B3, B4, B5, B6, B7] + */ + + __m256 tmpAB_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0x44), 0xD8); // a1,a2,a5,a6,b1,b2,b5,b6 + __m256 tmpAB_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0xEE), 0xD8); // a3,a4,a7,a8,b3,b4,b7,b8 + __m256 tmpCD_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0x44), 0xD8); // c1,c2,c5,c6,d1,d2,d5,d6 + __m256 tmpCD_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0xEE), 0xD8); // c3,c4,c7,c8,d3,d4,d7,d8 + + __m256 tmpABCD_lo1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0x44); // a1,a2,c1,c2,b1,b2,d1,d2 + __m256 tmpABCD_lo2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0x44); // a3,a4,c3,c4,b3,b4,d3,d4 + __m256 tmpABCD_hi1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0xEE); // a5,a6,c5,c6,b5,b6,d5,d6 + __m256 tmpABCD_hi2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0xEE); // a7,a8,c7,c8,b7,b8,d7,d8 + + __m256 sumABCD = __lasx_xvfadd_s(__lasx_xvfadd_s(tmpABCD_lo1, tmpABCD_lo2), __lasx_xvfadd_s(tmpABCD_hi1, tmpABCD_hi2)); + + __m256 sum0 = (__m256)__lasx_xvpermi_w(sumABCD, sumABCD, 0xB1); + sumABCD = 
(__m256)__lasx_xvpermi_d(__lasx_xvfadd_s(sumABCD, sum0), 0xD8); + + sumABCD = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(sumABCD, sumABCD, 0x88), 0xD8); + + alignas(32) float tmp[8]; + __lasx_xvst(sumABCD, (void*)&tmp, 0); + __m128 result = (__m128)__lsx_vld((void*)&tmp, 0); + return result; +} + +__m256 +permutevar_ps_lasx(__m256 vec, __m256i idx_mask) +{ + __m256i veci = (__m256i)vec; + __m256i shuffled = __lasx_xvshuf_w(veci, veci, idx_mask); + return (__m256)shuffled; +} + +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. + // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const SafeInt data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkLen; + const std::byte* src = QuantBDataBegin + data_offset.Value(); + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + const SafeInt remain_data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + remain_data_offset.Value(); + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + SafeInt size = SafeInt(N) * BlockCountK; + std::vector> QuantBScaleBeginCopy(size.Value()); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset.Value()]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const 
std::byte* QuantBZP = QuantBZPBegin + src_blk_offset.Value(); + zp = (uint8_t)(*QuantBZP); + } + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset.Value()) = -QuantBScale * zp; + + if (n_4 + 4 > N) { + SafeInt ptr_offset = SafeInt(n) * BlockCountK + k_blk; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + SafeInt ptr_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + SafeInt dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = SafeInt(n_4) * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset.Value()) = QuantBScale; + } + }); +} + +static void +Q8PackQuantBDataAndBlkSum_lasx( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8_lasx(__m256 v0, __m256 v1) +{ + // fp32->int32 + __m256i v0_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v0)); + __m256i v1_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v1)); + + alignas(32) int val_0_15_i32[16] = {0}; + alignas(32) int8_t val_0_15_i8[16] = {0}; + + __lasx_xvst(v0_8_epi32, (void*)&val_0_15_i32, 0); + __lasx_xvst(v1_8_epi32, (void*)&val_0_15_i32, 32); + + UnrolledLoop<16>([&](size_t i) { + if (val_0_15_i32[i] > 127) + val_0_15_i8[i] = 127; + else if (val_0_15_i32[i] < -128) + val_0_15_i8[i] = -128; + else + val_0_15_i8[i] = static_cast(val_0_15_i32[i]); + }); + + __m128i result = __lsx_vld((void*)&val_0_15_i8, 0); + return result; +} + +static inline __m256i +lasx_maddubs_epi16_sat(__m256i a, __m256i b) +{ + // a: bytes interpreted as unsigned + // b: bytes interpreted as signed + __m256i zero_h = __lasx_xvldi(0); // 256-bit zeros + + __m256i even_prod16 = __lasx_xvmaddwev_h_bu_b(zero_h, a, b); + __m256i odd_prod16 = __lasx_xvmaddwod_h_bu_b(zero_h, a, b); + + __m256i sum16_sat = __lasx_xvsadd_h(even_prod16, odd_prod16); + + return sum16_sat; // 16-bit signed saturated results (16 lanes) +} + +static inline __m256i +lasx_madd_epi16(__m256i a, __m256i b) +{ + __m256i zero = __lasx_xvldi(0); + __m256i even_acc = __lasx_xvmaddwev_w_h(zero, a, b); + __m256i result = __lasx_xvmaddwod_w_h(even_acc, a, b); + + return result; // 32-bit signed sums, matches _mm256_madd_epi16 semantics (no saturation) +} + +static inline __m256i +lasx_hadd_epi32(__m256i a, __m256i b) +{ 
+ __m256i a_swapped = __lasx_xvshuf4i_w(a, 0xB1); // 0xB1 = binary 10110001 + __m256i b_swapped = __lasx_xvshuf4i_w(b, 0xB1); + + __m256i a_sum = __lasx_xvadd_w(a, a_swapped); + __m256i b_sum = __lasx_xvadd_w(b, b_swapped); + + __m256i a_even = __lasx_xvpermi_w(a_sum, a_sum, 0x88); + __m256i b_even = __lasx_xvpermi_w(b_sum, b_sum, 0x88); + + __m256i result = __lasx_xvpermi_q(a_even, b_even, 0x20); + + return result; +} + +static inline __m256i +lasx_cvtepu8_epi16_emul_from_m128(const __m128i v128) +{ + alignas(32) int8_t num[32] = {0}; + __lsx_vst(v128, (void*)&num, 0); + __m256i result = __lasx_xvld((void*)&num, 0); + result = __lasx_xvexth_hu_bu(__lasx_xvpermi_d(result, 0x72)); + return result; +} + +static MLAS_FORCEINLINE float +hsum_float_8_lasx(__m256 v) +{ + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0xB1)); + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0x4E)); + alignas(32) float num[8] = {0.0f}; + __lasx_xvst(v, (void*)num, 0); + + return num[0] + num[1]; +} diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index 0116dec5170f0..04b89a5be061c 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -64,6 +64,7 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Relu", {}}, {"Gelu", {}}, {"Elu", {}}, + {"Erf", {}}, {"HardSigmoid", {}}, {"HardSwish", {}}, {"Sigmoid", {}}, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 5bcb8ca394346..8b536a01245f8 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -376,31 +376,34 @@ Status QnnBackendManager::LoadQnnSerializerBackend() { } Status QnnBackendManager::LoadQnnSystemLib() { + if (!system_lib_loaded_) { #ifdef _WIN32 - std::string system_lib_file = "QnnSystem.dll"; + std::string system_lib_file = "QnnSystem.dll"; #else - std::string system_lib_file = "libQnnSystem.so"; + std::string system_lib_file = "libQnnSystem.so"; #endif // #ifdef _WIN32 - LOGS_DEFAULT(INFO) << "Loading QnnSystem lib"; - std::filesystem::path lib_file_path(backend_path_.c_str()); - std::string sys_file_path(lib_file_path.remove_filename().string() + system_lib_file); - QnnSystemInterface_t* system_interface_provider{nullptr}; - auto rt = GetQnnInterfaceProvider(sys_file_path.c_str(), - "QnnSystemInterface_getProviders", - &system_lib_handle_, - {QNN_SYSTEM_API_VERSION_MAJOR, - QNN_SYSTEM_API_VERSION_MINOR, - QNN_SYSTEM_API_VERSION_PATCH}, - &system_interface_provider); - ORT_RETURN_IF_ERROR(rt); - Qnn_Version_t system_interface_version = GetQnnInterfaceApiVersion(system_interface_provider); - qnn_sys_interface_ = system_interface_provider->QNN_SYSTEM_INTERFACE_VER_NAME; - - LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << system_interface_version.major - << "." 
<< system_interface_version.minor - << " backend provider name: " << system_interface_provider->providerName; - + LOGS_DEFAULT(INFO) << "Loading QnnSystem lib"; + std::filesystem::path lib_file_path(backend_path_.c_str()); + std::string sys_file_path(lib_file_path.remove_filename().string() + system_lib_file); + QnnSystemInterface_t* system_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(sys_file_path.c_str(), + "QnnSystemInterface_getProviders", + &system_lib_handle_, + {QNN_SYSTEM_API_VERSION_MAJOR, + QNN_SYSTEM_API_VERSION_MINOR, + QNN_SYSTEM_API_VERSION_PATCH}, + &system_interface_provider); + ORT_RETURN_IF_ERROR(rt); + Qnn_Version_t system_interface_version = GetQnnInterfaceApiVersion(system_interface_provider); + qnn_sys_interface_ = system_interface_provider->QNN_SYSTEM_INTERFACE_VER_NAME; + + LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << system_interface_version.major + << "." << system_interface_version.minor + << " backend provider name: " << system_interface_provider->providerName; + + system_lib_loaded_ = true; + } return Status::OK(); } @@ -639,6 +642,7 @@ Status QnnBackendManager::InitializeProfiling() { return Status::OK(); } + bool enable_optrace = false; QnnProfile_Level_t qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; if (ProfilingLevel::BASIC == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_BASIC; @@ -646,10 +650,36 @@ Status QnnBackendManager::InitializeProfiling() { } else if (ProfilingLevel::DETAILED == profiling_level_merge_) { qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; LOGS_DEFAULT(VERBOSE) << "Profiling level set to detailed."; + } else if (ProfilingLevel::OPTRACE == profiling_level_merge_) { + qnn_profile_level = QNN_PROFILE_LEVEL_DETAILED; + enable_optrace = true; + LOGS_DEFAULT(VERBOSE) << "Profiling level set to optrace."; } + Qnn_ErrorHandle_t result = qnn_interface_.profileCreate(backend_handle_, qnn_profile_level, &profile_backend_handle_); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to create QNN profile! Error: ", QnnErrorHandleToString(result)); +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + profiling_enabled_ = true; + ORT_RETURN_IF_ERROR(LoadQnnSystemLib()); + + if (enable_optrace) { + QnnProfile_Config_t optrace_config = QNN_PROFILE_CONFIG_INIT; + optrace_config.option = QNN_PROFILE_CONFIG_OPTION_ENABLE_OPTRACE; + optrace_config.enableOptrace = enable_optrace; + + const QnnProfile_Config_t* profile_configs[] = {&optrace_config, nullptr}; + result = qnn_interface_.profileSetConfig(profile_backend_handle_, profile_configs); + + ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to enable op trace! Error: ", QnnErrorHandleToString(result)); + } +#else + if (enable_optrace) { + LOGS_DEFAULT(WARNING) << "Profiling level set to optrace, but QNN SDK Version is older than 2.29.0. 
" + << "Profiling level will be set to detailed instead."; + } +#endif + return Status::OK(); } @@ -1128,6 +1158,13 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, "Invalid function pointer for contextCreateFromBinary."); + qnn::profile::ProfilingInfo profiling_info; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (ProfilingEnabled()) { + profiling_info.start_time = qnn::utils::GetTimeStampInUs(); + } +#endif + rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, context_configs, @@ -1135,9 +1172,20 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, &context, profile_backend_handle_); + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (ProfilingEnabled()) { + profiling_info.stop_time = qnn::utils::GetTimeStampInUs(); + profiling_info.method_type = ProfilingMethodType::CREATE_FROM_BINARY; + profiling_info.graph_name = node_name; + } +#endif + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary. Error code: ", rt); ORT_RETURN_IF_ERROR(AddQnnContextHandle(context)); + ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo(profiling_info)); + #if QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 26) } #endif @@ -1158,8 +1206,6 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t qnn_sys_interface_.systemContextFree(sys_ctx_handle); sys_ctx_handle = nullptr; - - ORT_RETURN_IF_ERROR(ExtractBackendProfilingInfo()); context_created_ = true; LOGS(*logger_, VERBOSE) << "Load from cached QNN Context completed."; @@ -1569,7 +1615,7 @@ void QnnBackendManager::ReleaseResources() { return; } -Status QnnBackendManager::ExtractBackendProfilingInfo() { +Status QnnBackendManager::ExtractBackendProfilingInfo(qnn::profile::ProfilingInfo& profiling_info) { if (ProfilingLevel::OFF == profiling_level_merge_ || ProfilingLevel::INVALID == profiling_level_merge_) { return Status::OK(); } @@ -1603,6 +1649,7 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { ORT_RETURN_IF(nullptr == profile_backend_handle_, "Backend profile handle not valid."); + LOGS(*logger_, VERBOSE) << "Extracting profiling events for graph " << profiling_info.graph_name; const QnnProfile_EventId_t* profile_events{nullptr}; uint32_t num_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEvents(profile_backend_handle_, &profile_events, &num_events); @@ -1629,34 +1676,35 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { LOGS(*logger_, VERBOSE) << "The QNN backend does not support extended event data."; } - std::ofstream outfile; + profiling_info.csv_output_filepath = profiling_file_path_; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + profiling_info.num_events = num_events; +#endif + + profile::Serializer profile_writer(profiling_info, + qnn_sys_interface_, + tracelogging_provider_ep_enabled); if (!profiling_file_path_.empty()) { - // Write to CSV in append mode - std::ifstream infile(profiling_file_path_.c_str()); - bool exists = infile.good(); - infile.close(); - - outfile.open(profiling_file_path_, std::ios_base::app); - ORT_RETURN_IF(!outfile.is_open(), "Failed to open profiling file: ", profiling_file_path_); - // If file didn't exist before, write the header - if (!exists) { - outfile << "Msg Timestamp,Message,Time,Unit of Measurement,Timing Source,Event Level,Event Identifier\n"; - } + ORT_RETURN_IF_ERROR(profile_writer.InitCsvFile()); } for (size_t event_idx = 0; event_idx < num_events; event_idx++) { 
ORT_RETURN_IF_ERROR( - ExtractProfilingEvent(*(profile_events + event_idx), "ROOT", outfile, backendSupportsExtendedEventData, - tracelogging_provider_ep_enabled)); + ExtractProfilingEvent(*(profile_events + event_idx), "ROOT", profile_writer, + backendSupportsExtendedEventData)); ORT_RETURN_IF_ERROR( - ExtractProfilingSubEvents(*(profile_events + event_idx), outfile, backendSupportsExtendedEventData, - tracelogging_provider_ep_enabled)); + ExtractProfilingSubEvents(*(profile_events + event_idx), profile_writer, + backendSupportsExtendedEventData)); } +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + ORT_RETURN_IF_ERROR(profile_writer.SerializeEventsToQnnLog()); +#endif - if (outfile) { + if (!profiling_file_path_.empty()) { LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to file (" << profiling_file_path_ << ")"; } + if (tracelogging_provider_ep_enabled) { LOGS(*logger_, VERBOSE) << "Wrote QNN profiling events (" << num_events << ") to ETW"; } @@ -1667,9 +1715,8 @@ Status QnnBackendManager::ExtractBackendProfilingInfo() { Status QnnBackendManager::ExtractProfilingSubEvents( QnnProfile_EventId_t profile_event_id, - std::ofstream& outfile, - bool useExtendedEventData, - bool tracelogging_provider_ep_enabled) { + profile::Serializer& profile_writer, + bool useExtendedEventData) { const QnnProfile_EventId_t* profile_sub_events{nullptr}; uint32_t num_sub_events{0}; Qnn_ErrorHandle_t result = qnn_interface_.profileGetSubEvents(profile_event_id, &profile_sub_events, &num_sub_events); @@ -1678,13 +1725,28 @@ Status QnnBackendManager::ExtractProfilingSubEvents( if (num_sub_events > 0) { LOGS(*logger_, VERBOSE) << "profile_sub_events: " << profile_sub_events << " num_sub_events: " << num_sub_events; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + QnnSystemProfile_ProfileEventV1_t* parent_system_event = nullptr; + parent_system_event = profile_writer.GetParentSystemEvent(profile_event_id); + if (parent_system_event == nullptr) { + parent_system_event = profile_writer.GetSystemEventPointer(profile_event_id); + profile_writer.AddSubEventList(num_sub_events, parent_system_event); + } +#endif + for (size_t sub_event_idx = 0; sub_event_idx < num_sub_events; sub_event_idx++) { + QnnProfile_EventId_t subevent_id = *(profile_sub_events + sub_event_idx); + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + + ORT_RETURN_IF_ERROR(profile_writer.SetParentSystemEvent(subevent_id, parent_system_event)); + +#endif + ORT_RETURN_IF_ERROR( - ExtractProfilingEvent(*(profile_sub_events + sub_event_idx), "SUB-EVENT", outfile, useExtendedEventData, - tracelogging_provider_ep_enabled)); + ExtractProfilingEvent(subevent_id, "SUB-EVENT", profile_writer, useExtendedEventData)); ORT_RETURN_IF_ERROR( - ExtractProfilingSubEvents(*(profile_sub_events + sub_event_idx), outfile, useExtendedEventData, - tracelogging_provider_ep_enabled)); + ExtractProfilingSubEvents(subevent_id, profile_writer, useExtendedEventData)); } LOGS(*logger_, VERBOSE) << "Wrote QNN profiling sub events (" << num_sub_events << ")"; @@ -1695,167 +1757,44 @@ Status QnnBackendManager::ExtractProfilingSubEvents( Status QnnBackendManager::ExtractProfilingEvent( QnnProfile_EventId_t profile_event_id, - const std::string& eventLevel, - std::ofstream& outfile, - bool useExtendedEventData, - bool tracelogging_provider_ep_enabled) { + const std::string& event_level, + profile::Serializer& profile_writer, + bool useExtendedEventData) { if (useExtendedEventData) { - return ExtractProfilingEventExtended(profile_event_id, eventLevel, outfile, 
tracelogging_provider_ep_enabled); + return ExtractProfilingEventExtended(profile_event_id, event_level, profile_writer); } else { - return ExtractProfilingEventBasic(profile_event_id, eventLevel, outfile, tracelogging_provider_ep_enabled); + return ExtractProfilingEventBasic(profile_event_id, event_level, profile_writer); } } Status QnnBackendManager::ExtractProfilingEventBasic( QnnProfile_EventId_t profile_event_id, - const std::string& eventLevel, - std::ofstream& outfile, - bool tracelogging_provider_ep_enabled) { + const std::string& event_level, + profile::Serializer& profile_writer) { QnnProfile_EventData_t event_data; Qnn_ErrorHandle_t result = qnn_interface_.profileGetEventData(profile_event_id, &event_data); QnnProfile_Error_t errorCode = static_cast(result & 0xFFFF); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != result, "Failed to get profile event data: " + std::string(QnnProfileErrorToString(errorCode))); - std::string message = GetEventTypeString(event_data.type); - std::string unit = GetUnitString(event_data.unit); - - if (outfile) { - outfile << "UNKNOWN" - << "," - << message << "," - << event_data.value << "," - << unit << "," - << "BACKEND" - << "," - << eventLevel << "," - << (event_data.identifier ? event_data.identifier : "NULL") << "\n"; - } - - if (tracelogging_provider_ep_enabled) { -#ifdef _WIN32 - LogQnnProfileEventAsTraceLogging( - (uint64_t)0, - message, - std::to_string(event_data.value), - unit, - "BACKEND", - eventLevel, - (event_data.identifier ? event_data.identifier : "NULL")); -#endif - } + ORT_RETURN_IF_ERROR(profile_writer.ProcessEvent(profile_event_id, event_level, event_data)); return Status::OK(); } Status QnnBackendManager::ExtractProfilingEventExtended( QnnProfile_EventId_t profile_event_id, - const std::string& eventLevel, - std::ofstream& outfile, - bool tracelogging_provider_ep_enabled) { + const std::string& event_level, + profile::Serializer& profile_writer) { QnnProfile_ExtendedEventData_t event_data_extended; auto resultGetExtendedEventData = qnn_interface_.profileGetExtendedEventData(profile_event_id, &event_data_extended); QnnProfile_Error_t errorCode = static_cast(resultGetExtendedEventData & 0xFFFF); ORT_RETURN_IF(QNN_PROFILE_NO_ERROR != errorCode, "Failed to get profile event data: " + std::string(QnnProfileErrorToString(errorCode))); - // need to check the version first - std::string message = GetEventTypeString(event_data_extended.v1.type); - std::string unit = GetUnitString(event_data_extended.v1.unit); - - if (outfile) { - if (event_data_extended.version == QNN_PROFILE_DATA_VERSION_1) { - outfile << event_data_extended.v1.timestamp << "," - << message << "," - << ExtractQnnScalarValue(event_data_extended.v1.value) << "," - << unit << "," - << "BACKEND" - << "," - << eventLevel << "," - << (event_data_extended.v1.identifier ? event_data_extended.v1.identifier : "NULL") << "\n"; - } - } - - if (tracelogging_provider_ep_enabled) { -#ifdef _WIN32 - LogQnnProfileEventAsTraceLogging( - event_data_extended.v1.timestamp, - message, - ExtractQnnScalarValue(event_data_extended.v1.value), - unit, - "BACKEND", - eventLevel, - (event_data_extended.v1.identifier ? 
event_data_extended.v1.identifier : "NULL")); -#endif - } + ORT_RETURN_IF_ERROR(profile_writer.ProcessExtendedEvent(profile_event_id, event_level, event_data_extended)); return Status::OK(); } -#ifdef _WIN32 -void QnnBackendManager::LogQnnProfileEventAsTraceLogging( - uint64_t timestamp, - const std::string& message, - const std::string& qnnScalarValue, - const std::string& unit, - const std::string& timingSource, - const std::string& eventLevel, - const char* eventIdentifier) { - QnnTelemetry& qnn_telemetry = QnnTelemetry::Instance(); - qnn_telemetry.LogQnnProfileEvent(timestamp, message, qnnScalarValue, unit, timingSource, eventLevel, eventIdentifier); -} -#endif - -const std::string& QnnBackendManager::GetUnitString(QnnProfile_EventUnit_t unitType) { - const auto& unitStringMap = GetUnitStringMap(); - auto it = unitStringMap.find(unitType); - if (it != unitStringMap.end()) { - return it->second; - } - static const std::string unknown = "UNKNOWN"; - return unknown; -} - -const std::unordered_map& QnnBackendManager::GetUnitStringMap() { - static const std::unordered_map unitStringMap = { - {QNN_PROFILE_EVENTUNIT_MICROSEC, "US"}, - {QNN_PROFILE_EVENTUNIT_BYTES, "BYTES"}, - {QNN_PROFILE_EVENTUNIT_CYCLES, "CYCLES"}, - {QNN_PROFILE_EVENTUNIT_COUNT, "COUNT"}, - {QNN_PROFILE_EVENTUNIT_OBJECT, "OBJECT"}, - {QNN_PROFILE_EVENTUNIT_BACKEND, "BACKEND"}}; - return unitStringMap; -} - -const std::string QnnBackendManager::GetEventTypeString(QnnProfile_EventType_t eventType) { - // Interpret the event type - switch (eventType) { - case QNN_PROFILE_EVENTTYPE_INIT: - return "INIT"; - case QNN_PROFILE_EVENTTYPE_FINALIZE: - return "FINALIZE"; - case QNN_PROFILE_EVENTTYPE_EXECUTE: - return "EXECUTE"; - case QNN_PROFILE_EVENTTYPE_NODE: - return "NODE"; - case QNN_PROFILE_EVENTTYPE_EXECUTE_QUEUE_WAIT: - return "EXECUTE QUEUE WAIT"; - case QNN_PROFILE_EVENTTYPE_EXECUTE_PREPROCESS: - return "EXECUTE PREPROCESS"; - case QNN_PROFILE_EVENTTYPE_EXECUTE_DEVICE: - return "EXECUTE DEVICE"; - case QNN_PROFILE_EVENTTYPE_EXECUTE_POSTPROCESS: - return "EXECUTE POSTPROCESS"; - case QNN_PROFILE_EVENTTYPE_DEINIT: - return "DE-INIT"; - case QNN_PROFILE_EVENTTYPE_BACKEND: - return "BACKEND"; - default: - if (eventType > QNN_PROFILE_EVENTTYPE_BACKEND) { - return "BACKEND"; - } - return "UNKNOWN"; - } -} - const char* QnnBackendManager::QnnProfileErrorToString(QnnProfile_Error_t error) { switch (error) { case QNN_PROFILE_NO_ERROR: @@ -1881,45 +1820,6 @@ std::string QnnBackendManager::QnnErrorHandleToString(Qnn_ErrorHandle_t error) { return utils::GetQnnErrorMessage(qnn_interface_, error); } -const std::string QnnBackendManager::ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) { - switch (scalar.dataType) { - case QNN_DATATYPE_INT_8: - return std::to_string(static_cast(scalar.int8Value)); - case QNN_DATATYPE_INT_16: - return std::to_string(scalar.int16Value); - case QNN_DATATYPE_INT_32: - return std::to_string(scalar.int32Value); - case QNN_DATATYPE_INT_64: - return std::to_string(scalar.int64Value); - case QNN_DATATYPE_UINT_8: - return std::to_string(static_cast(scalar.uint8Value)); - case QNN_DATATYPE_UINT_16: - return std::to_string(scalar.uint16Value); - case QNN_DATATYPE_UINT_32: - return std::to_string(scalar.uint32Value); - case QNN_DATATYPE_UINT_64: - return std::to_string(scalar.uint64Value); - case QNN_DATATYPE_FLOAT_16: - return std::to_string(scalar.floatValue); - case QNN_DATATYPE_FLOAT_32: - return std::to_string(scalar.floatValue); - case QNN_DATATYPE_SFIXED_POINT_8: - case QNN_DATATYPE_SFIXED_POINT_16: - case 
QNN_DATATYPE_SFIXED_POINT_32: - return std::to_string(scalar.int32Value); // Assume using int types for signed fixed points. - case QNN_DATATYPE_UFIXED_POINT_8: - case QNN_DATATYPE_UFIXED_POINT_16: - case QNN_DATATYPE_UFIXED_POINT_32: - return std::to_string(scalar.uint32Value); // Assume using unsigned int types for unsigned fixed points. - case QNN_DATATYPE_BOOL_8: - return scalar.bool8Value ? "true" : "false"; - case QNN_DATATYPE_STRING: - return scalar.stringValue ? scalar.stringValue : "NULL"; - default: - return "UNKNOWN"; - } -} - QnnBackendManager::~QnnBackendManager() { ReleaseResources(); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 2a71c7391b180..22d5993b68291 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -27,6 +27,7 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_context_mem_handle_manager.h" #include "core/providers/qnn/builder/qnn_def.h" +#include "core/providers/qnn/builder/qnn_profile_serializer.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" namespace onnxruntime { @@ -61,6 +62,13 @@ class QnnSerializerConfig { */ void SetGraphName(std::string graph_name); + /** + * Gets the name of the graph being serialized. + * + * \return graph_name The name of the graph being serialized. + */ + const std::string& GetGraphName() const; + /** * Get any QNN Graph configs required to configure this serializer and perform any * preparation, such as creating output directories. @@ -83,7 +91,6 @@ class QnnSerializerConfig { protected: QnnSerializerConfig(std::string backend_path); - const std::string& GetGraphName() const; private: std::string backend_path_; @@ -183,12 +190,13 @@ class QnnBackendManager : public std::enable_shared_from_this // NOTE: This function locks the internal `logger_recursive_mutex_`. 
Status ResetQnnLogLevel(std::optional ort_log_level = std::nullopt); - Status ExtractBackendProfilingInfo(); - Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, std::ofstream& outfile, - bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled); + Status ExtractBackendProfilingInfo(qnn::profile::ProfilingInfo& profiling_info); + + Status ExtractProfilingSubEvents(QnnProfile_EventId_t profile_event_id, profile::Serializer& profile_writer, + bool backendSupportsExtendedEventData); + Status ExtractProfilingEvent(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, - std::ofstream& outfile, bool backendSupportsExtendedEventData, - bool tracelogging_provider_ep_enabled); + profile::Serializer& profile_writer, bool backendSupportsExtendedEventData); Status SetProfilingLevelETW(ProfilingLevel profiling_level_etw_param); @@ -225,6 +233,10 @@ class QnnBackendManager : public std::enable_shared_from_this // Resets the context priority to the session default as defined by context_priority_ Status ResetContextPriority(); +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + bool ProfilingEnabled() { return profiling_enabled_; } +#endif + private: Status LoadBackend(); @@ -307,26 +319,14 @@ class QnnBackendManager : public std::enable_shared_from_this } Status ExtractProfilingEventBasic(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, - std::ofstream& outfile, bool tracelogging_provider_ep_enabled); + profile::Serializer& profile_writer); + Status ExtractProfilingEventExtended(QnnProfile_EventId_t profile_event_id, const std::string& eventLevel, - std::ofstream& outfile, bool tracelogging_provider_ep_enabled); - static const std::string& GetUnitString(QnnProfile_EventUnit_t unitType); - static const std::unordered_map& GetUnitStringMap(); - static const std::string GetEventTypeString(QnnProfile_EventType_t eventType); - static const std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar); + profile::Serializer& profile_writer); + const char* QnnProfileErrorToString(QnnProfile_Error_t error); std::string QnnErrorHandleToString(Qnn_ErrorHandle_t error); QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); -#ifdef _WIN32 - void LogQnnProfileEventAsTraceLogging( - uint64_t timestamp, - const std::string& message, - const std::string& qnnScalarValue, - const std::string& unit, - const std::string& timingSource, - const std::string& eventLevel, - const char* eventIdentifier); -#endif // Adds a new QNN context. 
// Transfers ownership of `context_handle` (i.e., responsibility of freeing it) to this instance @@ -437,6 +437,12 @@ class QnnBackendManager : public std::enable_shared_from_this ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; const std::string profiling_file_path_; + bool system_lib_loaded_ = false; + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + bool profiling_enabled_ = false; +#endif + bool backend_initialized_ = false; bool device_created_ = false; bool context_created_ = false; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_def.h b/onnxruntime/core/providers/qnn/builder/qnn_def.h index 6fba6d847cb74..42f4d7bb60f34 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_def.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_def.h @@ -14,6 +14,11 @@ namespace onnxruntime { namespace qnn { +#if QNN_API_VERSION_MAJOR > 2 || \ + (QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 29)) +#define QNN_SYSTEM_PROFILE_API_ENABLED +#endif + // QNN only support subset of POSIX of dlopen/dlsym/dladdr/dlerror/dlclose // except the following flags for dlopen, others should be done only // when we really need them @@ -32,9 +37,26 @@ enum class ProfilingLevel : uint8_t { OFF = 0, BASIC, DETAILED, + OPTRACE, INVALID }; +enum class ProfilingMethodType : uint8_t { + UNKNOWN = 0, + EXECUTE, + FINALIZE, + EXECUTE_ASYNC, + CREATE_FROM_BINARY, + DEINIT, + CONTEXT_CREATE, + COMPOSE_GRAPHS, + EXECUTE_IPS, + GRAPH_COMPONENT, + LIB_LOAD, + APPLY_BINARY_SECTION, + CONTEXT_FINALIZE +}; + // Defines performance modes available for HTP backend. enum class HtpPerformanceMode : uint8_t { kHtpDefault = 0, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index 175a76b590895..85901ab6fdfec 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -10,6 +10,7 @@ #include "core/providers/qnn/builder/op_builder_factory.h" #include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/builder/qnn_profile_serializer.h" #include "core/providers/qnn/builder/qnn_utils.h" #include "core/providers/qnn/ort_api.h" #include "core/providers/qnn/qnn_allocator.h" @@ -105,7 +106,6 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, // This name must be same with the EPContext node name const auto& graph_name = fused_node.Name(); ORT_RETURN_IF_ERROR(SetGraphInputOutputInfo(graph_viewer, fused_node, logger)); - QnnModelWrapper qnn_model_wrapper = QnnModelWrapper(graph_viewer, logger, qnn_backend_manager_->GetQnnInterface(), qnn_backend_manager_->GetQnnBackendHandle(), @@ -114,11 +114,33 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, qnn_backend_manager_->GetQnnBackendType(), model_settings); bool rt = true; + + qnn::profile::ProfilingInfo profiling_info; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.graph_name = graph_name; + profiling_info.start_time = qnn::utils::GetTimeStampInUs(); + } +#endif + rt = qnn_model_wrapper.CreateQnnGraph(qnn_backend_manager_->GetQnnContext(), graph_name, graph_configs); + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.stop_time = qnn::utils::GetTimeStampInUs(); + profiling_info.method_type = ProfilingMethodType::COMPOSE_GRAPHS; + } +#endif + if (!rt) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to initialize qnn_model_wrapper."); } + // NOTE: This 
function returns immediately when profiling is disabled. + // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes + // and not in production. We can improve synchronization for event profiling if it becomes an issue. + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo(profiling_info)); + std::vector> qnn_node_groups; qnn_node_groups.reserve(node_unit_holder.size()); @@ -160,15 +182,35 @@ Status QnnModel::ComposeGraph(const GraphViewer& graph_viewer, Status QnnModel::FinalizeGraphs(const logging::Logger& logger) { LOGS(logger, VERBOSE) << "FinalizeGraphs started."; + + qnn::profile::ProfilingInfo profiling_info; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.start_time = qnn::utils::GetTimeStampInUs(); + } +#endif + Qnn_ErrorHandle_t status = qnn_backend_manager_->GetQnnInterface().graphFinalize(graph_info_->Graph(), qnn_backend_manager_->GetQnnProfileHandle(), nullptr); + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.stop_time = qnn::utils::GetTimeStampInUs(); + profiling_info.method_type = ProfilingMethodType::FINALIZE; + profiling_info.graph_name = graph_info_->Name(); + } +#endif + if (QNN_GRAPH_NO_ERROR != status) { LOGS(logger, ERROR) << "Failed to finalize QNN graph. Error code: " << status; return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to finalize QNN graph."); } - ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + // NOTE: This function returns immediately when profiling is disabled. + // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes + // and not in production. We can improve synchronization for event profiling if it becomes an issue. + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo(profiling_info)); LOGS(logger, VERBOSE) << "FinalizeGraphs completed."; return Status::OK(); @@ -297,6 +339,14 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, std::lock_guard lock(graph_exec_mutex_); LOGS(logger, VERBOSE) << "Start execute QNN graph:" << graph_info_->Name(); + + qnn::profile::ProfilingInfo profiling_info; +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.start_time = qnn::utils::GetTimeStampInUs(); + } +#endif + auto profile_backend_handle = qnn_backend_manager_->GetQnnProfileHandle(); execute_status = qnn_interface.graphExecute(graph_info_->Graph(), qnn_inputs.data(), @@ -305,11 +355,18 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context, static_cast(qnn_outputs.size()), profile_backend_handle, nullptr); +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + if (qnn_backend_manager_->ProfilingEnabled()) { + profiling_info.stop_time = qnn::utils::GetTimeStampInUs(); + profiling_info.method_type = ProfilingMethodType::EXECUTE; + profiling_info.graph_name = graph_info_->Name(); + } +#endif // NOTE: This function returns immediately when profiling is disabled. // Extracting profiling data can be expensive, but it is typically only enabled for debugging purposes // and not in production. We can improve synchronization for event profiling if it becomes an issue. 
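The timing added above around CreateQnnGraph, graphFinalize, and graphExecute follows one pattern: record a start timestamp, run the QNN call, record a stop timestamp together with the method type and graph name, then pass the filled ProfilingInfo to ExtractBackendProfilingInfo. A compact sketch of the same bookkeeping with a scoped guard (hypothetical types and names, illustrative only, not code from this change):

#include <chrono>
#include <cstdint>
#include <string>

// Stand-in for the fields populated in the patch: start/stop in microseconds plus a graph name.
struct ProfilingInfoSketch {
  uint64_t start_time = 0;
  uint64_t stop_time = 0;
  std::string graph_name;
};

static uint64_t NowInUs() {
  using namespace std::chrono;
  return static_cast<uint64_t>(
      duration_cast<microseconds>(steady_clock::now().time_since_epoch()).count());
}

// Records the start timestamp on construction and the stop timestamp on destruction,
// so the timed region is simply the scope of the guard.
class ScopedProfileTimer {
 public:
  explicit ScopedProfileTimer(ProfilingInfoSketch& info) : info_(info) { info_.start_time = NowInUs(); }
  ~ScopedProfileTimer() { info_.stop_time = NowInUs(); }

 private:
  ProfilingInfoSketch& info_;
};

The change itself writes the timestamps inline under #ifdef QNN_SYSTEM_PROFILE_API_ENABLED rather than through a guard, which keeps the instrumentation compiled out entirely when the QNN SDK is older than 2.29.0.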
- ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo()); + ORT_RETURN_IF_ERROR(qnn_backend_manager_->ExtractBackendProfilingInfo(profiling_info)); } if (QNN_COMMON_ERROR_SYSTEM_COMMUNICATION == execute_status) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc new file mode 100644 index 0000000000000..0a3a592a44906 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.cc @@ -0,0 +1,426 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_utils.h" +#include "core/providers/qnn/builder/op_builder_factory.h" +#include "core/providers/qnn/builder/qnn_model_wrapper.h" +#include "core/providers/qnn/builder/qnn_node_group/utils.h" + +namespace onnxruntime { +namespace qnn { + +// Helper function to extract value from raw data based on QNN data type +static Status GetValueOnQnnDataType(const Qnn_DataType_t qnn_data_type, + const uint8_t* raw_ptr, + double& value) { + switch (qnn_data_type) { + case QNN_DATATYPE_INT_8: + case QNN_DATATYPE_SFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_16: + case QNN_DATATYPE_SFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_INT_32: + case QNN_DATATYPE_SFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_8: + case QNN_DATATYPE_UFIXED_POINT_8: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_16: + case QNN_DATATYPE_UFIXED_POINT_16: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_UINT_32: + case QNN_DATATYPE_UFIXED_POINT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_32: { + value = static_cast(*reinterpret_cast(raw_ptr)); + break; + } + case QNN_DATATYPE_FLOAT_16: { + value = static_cast(reinterpret_cast(raw_ptr)->ToFloat()); + break; + } + default: + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Qnn Data Type: ", qnn_data_type, " not supported."); + } + return Status::OK(); +} + +// Helper function to extract a scalar float value from a constant initializer +// Handles both float and quantized (INT type) constant inputs +static std::optional GetConstantInitializerFloatScalar(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def) { + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + const auto& name = io_def.node_arg.Name(); + + if (!graph_viewer.IsConstantInitializer(name, true)) { + return std::nullopt; + } + + // Get tensor info to check if it's quantized + TensorInfo tensor_info = {}; + if (!qnn_model_wrapper.GetTensorInfo(io_def, tensor_info).IsOK()) { + return std::nullopt; + } + + // Must be an initializer + if (!tensor_info.is_initializer || !tensor_info.initializer_tensor) { + return std::nullopt; + } + + // Unpack the initializer data + std::vector unpacked_tensor; + if (!qnn_model_wrapper.UnpackInitializerData(*tensor_info.initializer_tensor, unpacked_tensor).IsOK()) { + return std::nullopt; + } + + if (unpacked_tensor.empty()) { + return std::nullopt; + } + + // Extract the value using GetValueOnQnnDataType + 
double extracted_value = 0.0; + if (!GetValueOnQnnDataType(tensor_info.qnn_data_type, unpacked_tensor.data(), extracted_value).IsOK()) { + return std::nullopt; + } + + // Check if quantized and dequantize if needed + const bool is_quantized = tensor_info.quant_param.IsQuantized(); + if (is_quantized) { + // For quantized tensors, dequantize the value + if (!tensor_info.quant_param.IsPerTensor()) { + return std::nullopt; // Only support per-tensor quantization + } + + const Qnn_QuantizeParams_t& quant_param = tensor_info.quant_param.Get(); + double dequantized_value = utils::Dequantize(quant_param.scaleOffsetEncoding.offset, + quant_param.scaleOffsetEncoding.scale, + extracted_value); + return static_cast(dequantized_value); + } + + // For non-quantized tensors, return the extracted value directly + return static_cast(extracted_value); +} + +// Helper function to check if a constant initializer has the expected float value +static bool IsInitializerWithExpectedValue(QnnModelWrapper& qnn_model_wrapper, + const NodeUnitIODef& io_def, + float expected_value, + float tolerance = 1e-5f) { + std::optional actual_value = GetConstantInitializerFloatScalar(qnn_model_wrapper, io_def); + if (!actual_value.has_value()) { + return false; + } + + // Compare with expected value within tolerance + return std::abs(actual_value.value() - expected_value) <= tolerance; +} + +// Forward declaration. +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate); + +// Helper function to validate on QNN +static Status ValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, true); +} + +// Helper function to create on QNN +static Status CreateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output) { + return CreateOrValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output, false); +} + +std::unique_ptr GeluFusion::TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger) { + ORT_UNUSED_PARAMETER(logger); + if (erf_node_unit.OpType() != "Erf") { + return nullptr; + } + + const GraphViewer& graph_viewer = qnn_model_wrapper.GetGraphViewer(); + + const auto& erf_inputs = erf_node_unit.Inputs(); + if (erf_inputs.empty()) { + return nullptr; + } + + const NodeUnit* div_node_unit = GetParentOfInput(graph_viewer, erf_node_unit, erf_inputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (div_node_unit == nullptr || div_node_unit->OpType() != "Div") { + return nullptr; + } + + // Div must have 2 inputs + const auto& div_inputs = div_node_unit->Inputs(); + if (div_inputs.size() < 2) { + return nullptr; + } + + // Check second input of Div is sqrt(2) ≈ 1.4142 + // Use a larger tolerance to handle approximations used in some models + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, div_inputs[1], static_cast(M_SQRT2), 1e-4f)) { + return nullptr; + } + + // Erf must have an Add child consuming its output + const auto& erf_outputs = erf_node_unit.Outputs(); + if (erf_outputs.empty()) { + return nullptr; + } + + const NodeUnit* add_node_unit = 
GetChildOfOutput(graph_viewer, erf_node_unit, erf_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (add_node_unit == nullptr || add_node_unit->OpType() != "Add") { + return nullptr; + } + + // Add must have 2 inputs + const auto& add_inputs = add_node_unit->Inputs(); + if (add_inputs.size() < 2) { + return nullptr; + } + + // Check the other input node (e.g. not the Erf) is 1.0f + bool is_erf_first_input = (add_inputs[0].node_arg.Name() == erf_outputs[0].node_arg.Name()); + const auto& add_const_input = add_inputs[is_erf_first_input ? 1 : 0]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, add_const_input, 1.0f)) { + return nullptr; + } + + // Add must have a Mul child consuming its output + const auto& add_outputs = add_node_unit->Outputs(); + if (add_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul_node_unit = GetChildOfOutput(graph_viewer, *add_node_unit, add_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul_node_unit == nullptr || mul_node_unit->OpType() != "Mul") { + return nullptr; + } + + // Now check which pattern we have + const auto& root_input_name = div_inputs[0].node_arg.Name(); + const auto& mul_inputs = mul_node_unit->Inputs(); + + if (mul_inputs.size() < 2) { + return nullptr; + } + + // Try to match Pattern 1: root -> Mul(0.5) -> ... -> Mul + // In this case, one input to the final Mul should be from a Mul node + const NodeUnit* mul2_node_unit = nullptr; + + // Check if either input to mul_node_unit comes from a Mul node + for (size_t i = 0; i < 2; ++i) { + const auto& mul_input_name = mul_inputs[i].node_arg.Name(); + + // Find the node that produces this input + for (const auto& node_index : graph_viewer.GetNodesInTopologicalOrder()) { + const Node* node = graph_viewer.GetNode(node_index); + if (node == nullptr) continue; + + // Check if this node's output matches our input + for (const auto* output_def : node->OutputDefs()) { + if (output_def && output_def->Name() == mul_input_name) { + // Found the producer node, check if it's a Mul + auto it = node_to_node_unit.find(node); + if (it != node_to_node_unit.end()) { + const NodeUnit* producer_unit = it->second; + if (producer_unit->OpType() == "Mul" && + node_unit_to_qnn_node_group.find(producer_unit) == node_unit_to_qnn_node_group.end()) { + // Check if this Mul has root as one input + const auto& mul2_inputs = producer_unit->Inputs(); + if (mul2_inputs.size() >= 2) { + bool has_root_input = (mul2_inputs[0].node_arg.Name() == root_input_name || + mul2_inputs[1].node_arg.Name() == root_input_name); + + if (has_root_input) { + // Check the other input is 0.5f + int root_index = (mul2_inputs[0].node_arg.Name() == root_input_name) ? 0 : 1; + const auto& mul_const_input = mul2_inputs[1 - root_index]; + + if (IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + mul2_node_unit = producer_unit; + break; + } + } + } + } + } + } + } + if (mul2_node_unit != nullptr) break; + } + if (mul2_node_unit != nullptr) break; + } + + std::vector node_units; + const NodeUnit* final_mul_node_unit = nullptr; + + if (mul2_node_unit != nullptr) { + // Pattern 1: root -> Mul(0.5) -> ... -> Mul + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul2_node_unit, mul_node_unit}; + final_mul_node_unit = mul_node_unit; + } else { + // Try Pattern 2: root -> ... 
-> Mul -> Mul(0.5) + // Check if one input to mul_node_unit is root + bool has_root_input = (mul_inputs[0].node_arg.Name() == root_input_name || + mul_inputs[1].node_arg.Name() == root_input_name); + + if (!has_root_input) { + return nullptr; + } + + // mul_node_unit must have a Mul child consuming its output + const auto& mul_outputs = mul_node_unit->Outputs(); + if (mul_outputs.empty()) { + return nullptr; + } + + const NodeUnit* mul2_node_unit_pattern2 = GetChildOfOutput(graph_viewer, *mul_node_unit, mul_outputs[0], + node_to_node_unit, node_unit_to_qnn_node_group); + if (mul2_node_unit_pattern2 == nullptr || mul2_node_unit_pattern2->OpType() != "Mul") { + return nullptr; + } + + // Verify this final Mul has 2 inputs + const auto& mul2_inputs = mul2_node_unit_pattern2->Inputs(); + if (mul2_inputs.size() < 2) { + return nullptr; + } + + // Check the constant input is 0.5f + int mul_const_input_index = 0; + if (mul2_inputs[0].node_arg.Name() == mul_outputs[0].node_arg.Name()) { + mul_const_input_index = 1; + } + const auto& mul_const_input = mul2_inputs[mul_const_input_index]; + if (!IsInitializerWithExpectedValue(qnn_model_wrapper, mul_const_input, 0.5f)) { + return nullptr; + } + + // Pattern 2 + node_units = {div_node_unit, &erf_node_unit, add_node_unit, mul_node_unit, mul2_node_unit_pattern2}; + final_mul_node_unit = mul2_node_unit_pattern2; + } + + // Validate on QNN + const NodeUnitIODef& root_input = div_inputs[0]; + const NodeUnitIODef& final_output = final_mul_node_unit->Outputs()[0]; + + if (Status status = ValidateOnQnn(qnn_model_wrapper, node_units, root_input, final_output); + !status.IsOK()) { + return nullptr; + } + + return std::make_unique(std::move(node_units), &erf_node_unit); +} + +GeluFusion::GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit) + : node_units_(std::move(node_units)), target_node_unit_(target_node_unit) { +} + +Status GeluFusion::IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return ValidateOnQnn(qmw, node_units_, root_input, final_output); +} + +Status GeluFusion::AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const { + ORT_UNUSED_PARAMETER(logger); + const NodeUnitIODef& root_input = node_units_[0]->Inputs()[0]; + const NodeUnitIODef& final_output = node_units_.back()->Outputs()[0]; + return CreateOnQnn(qmw, node_units_, root_input, final_output); +} + +gsl::span GeluFusion::GetNodeUnits() const { + return gsl::span(node_units_.data(), node_units_.size()); +} + +const NodeUnit* GeluFusion::GetTargetNodeUnit() const { + return target_node_unit_; +} + +static Status CreateOrValidateOnQnn(QnnModelWrapper& qnn_model_wrapper, + gsl::span node_units, + const NodeUnitIODef& root_input, + const NodeUnitIODef& final_output, + bool validate) { + assert(node_units.size() >= 4); + const auto& node_name = utils::GetUniqueName(*node_units[0]); + + QnnTensorWrapper input_tensor; + QnnTensorWrapper output_tensor; + + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(root_input, input_tensor)); + ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(final_output, output_tensor)); + + if (validate) { + ORT_RETURN_IF_ERROR(qnn_model_wrapper.ValidateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {input_tensor.GetQnnTensor()}, + {output_tensor.GetQnnTensor()}, + {})); + } else { + // Only add tensor wrappers 
if they don't already exist + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(root_input.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)), "Failed to add input"); + } + if (!qnn_model_wrapper.IsQnnTensorWrapperExist(final_output.node_arg.Name())) { + ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)), "Failed to add output"); + } + ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(node_name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + QNN_OP_GELU, + {root_input.node_arg.Name()}, + {final_output.node_arg.Name()}, + {}, + validate), + "Failed to add fused Gelu node."); + } + + return Status::OK(); +} + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h new file mode 100644 index 0000000000000..58089089c3444 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/gelu_fusion.h @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/providers/qnn/builder/qnn_node_group/qnn_node_group.h" +#include "core/providers/qnn/ort_api.h" + +namespace onnxruntime { +namespace qnn { + +class QnnModelWrapper; + +/// +/// Represents a fusion of the Gelu pattern expanded into ONNX operators. +/// This fusion handles two patterns: +/// Pattern 1: root -> Div -> Erf -> Add -> Mul (with Mul from root) +/// Pattern 2: root -> Div -> Erf -> Add -> Mul -> Mul +/// Both patterns are translated into a QNN Gelu operator. +/// The contained NodeUnits can be of type SingleNode or QDQGroup (with Q-DQ nodes). +/// The second inputs to Div, Add, and Mul operations can be either constant or non-constant tensors. +/// +class GeluFusion : public IQnnNodeGroup { + public: + GeluFusion(std::vector&& node_units, const NodeUnit* target_node_unit); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(GeluFusion); + + Status IsSupported(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + Status AddToModelBuilder(QnnModelWrapper& qmw, const logging::Logger& logger) const override; + gsl::span GetNodeUnits() const override; + const NodeUnit* GetTargetNodeUnit() const override; + std::string_view Type() const override { return "GeluFusion"; } + + /// + /// Traverses graph to check if the given starting NodeUnit is part of a valid Gelu pattern. + /// If so, returns a IQnnNodeGroup that contains all the NodeUnits in the pattern. + /// + /// Used for validation and traverse/query the graph + /// Erf node unit that could be part of the sequence + /// Maps a Node to a NodeUnit. + /// Maps a NodeUnit to a IQnnNodeGroup. 
+ /// + /// A valid IQnnNodeGroup on success or an empty std::unique_ptr otherwise + static std::unique_ptr TryFusion( + QnnModelWrapper& qnn_model_wrapper, + const NodeUnit& erf_node_unit, + const std::unordered_map& node_to_node_unit, + const std::unordered_map& node_unit_to_qnn_node_group, + const logging::Logger& logger); + + private: + std::vector node_units_; + const NodeUnit* target_node_unit_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc index 368caa518b7ba..135a3fed1e577 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc @@ -21,6 +21,7 @@ #include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h" +#include "core/providers/qnn/builder/qnn_node_group/gelu_fusion.h" #include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h" #include "core/providers/qnn/builder/qnn_utils.h" @@ -83,7 +84,8 @@ static std::unordered_map> fusions = { {"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}}, {"Mul", {ScaleSoftmaxFusion::TryFusion}}, {"Cast", {CastLoneQFusion::TryFusion}}, + {"Erf", {GeluFusion::TryFusion}}, {"Reshape", {Rank6ToRank5Fusion::TryFusion}}, {"Transpose", {ChannelShuffleFusion::TryFusion}}}; void registerUDO(const std::string& node_type, const std::string& op_package) { @@ -119,9 +121,11 @@ static std::unique_ptr TryQnnFusions( const std::unordered_map& node_to_node_unit, const std::unordered_map& node_unit_to_qnn_node_group, const logging::Logger& logger) { - // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape + // For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings, + // Erf and Reshape if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul" && + starting_node_unit.OpType() != "Erf" && starting_node_unit.OpType() != "Reshape") { return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc index 10e1633e4b57d..c4462e1fc0a42 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.cc @@ -226,12 +226,76 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, return nullptr; } - // parent must not already be part of a QDQ NodeUnit (i.e., be standalone).
- if (p_parent_node_unit->UnitType() != NodeUnit::Type::SingleNode) { + return p_parent_node_unit; + } + return nullptr; +} + +const NodeUnit* GetChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map) { + const Node* p_parent_node = nullptr; + + for (auto node : node_unit.GetAllNodesInGroup()) { + for (auto node_output : node->OutputDefs()) { + if (node_output->Name() == output.node_arg.Name()) { + p_parent_node = node; + break; + } + + if (p_parent_node != nullptr) { + break; + } + } + } + + if (p_parent_node == nullptr) { + return nullptr; + } + + const Node& parent_node = *p_parent_node; + + if (graph_viewer.NodeProducesGraphOutput(parent_node)) { + // Node is producing a graph output + return nullptr; + } + + for (auto edge = parent_node.OutputEdgesBegin(); edge != parent_node.OutputEdgesEnd(); ++edge) { + const Node& child_node = edge->GetNode(); + + // Check if this edge corresponds to the output we're looking for + bool is_matching_output = false; + for (auto child_input : child_node.InputDefs()) { + if (child_input->Name() == output.node_arg.Name()) { + is_matching_output = true; + break; + } + } + + if (!is_matching_output) { + continue; + } + + if (graph_viewer.GetNode(child_node.Index()) == nullptr) { + // Node is not in this GraphViewer return nullptr; } - return p_parent_node_unit; + const auto child_node_unit_it = node_unit_map.find(&child_node); + if (child_node_unit_it == node_unit_map.end()) { + return nullptr; + } + const NodeUnit* p_child_node_unit = child_node_unit_it->second; + + // Check if child node has already been handled. Should not be the case if the calling + // fusion function has been called in topological order, but check to be safe. + if (qnn_node_group_map.count(p_child_node_unit) != 0) { + return nullptr; + } + + return p_child_node_unit; } return nullptr; } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h index 14e2a3f25e7db..685b974f0a55a 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_node_group/utils.h @@ -51,5 +51,11 @@ const NodeUnit* GetParentOfInput(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const std::unordered_map& qnn_node_group_map); +const NodeUnit* GetChildOfOutput(const GraphViewer& graph_viewer, + const NodeUnit& node_unit, + const NodeUnitIODef& output, + const std::unordered_map& node_unit_map, + const std::unordered_map& qnn_node_group_map); + } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc new file mode 100644 index 0000000000000..fb76f2110cbc8 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.cc @@ -0,0 +1,465 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
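The GetChildOfOutput helper added above mirrors GetParentOfInput in the downstream direction: it resolves the NodeUnit that consumes a given output and returns nullptr when the producer feeds a graph output, when the consumer is outside the GraphViewer or has no NodeUnit mapping, or when the consumer has already been claimed by another IQnnNodeGroup. A sketch of the typical call pattern, assuming it lives in the onnxruntime::qnn namespace; the map template arguments are reconstructed here since angle-bracket contents are elided throughout this patch text:

```cpp
#include "core/providers/qnn/builder/qnn_node_group/utils.h"

// Illustrative only: hop from `producer` to its "Mul" consumer, the way
// GeluFusion::TryFusion walks the Div -> Erf -> Add -> Mul chain above.
const NodeUnit* NextMulChild(
    const GraphViewer& graph_viewer, const NodeUnit& producer,
    const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
    const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group) {
  const auto& outputs = producer.Outputs();
  if (outputs.empty()) {
    return nullptr;
  }
  const NodeUnit* child = GetChildOfOutput(graph_viewer, producer, outputs[0],
                                           node_to_node_unit, node_unit_to_qnn_node_group);
  return (child != nullptr && child->OpType() == "Mul") ? child : nullptr;
}
```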
+ +#include "qnn_profile_serializer.h" +#include "core/providers/qnn/qnn_telemetry.h" + +namespace onnxruntime { +namespace qnn { +namespace profile { + +const std::unordered_map& GetUnitStringMap() { + static const std::unordered_map unitStringMap = { + {QNN_PROFILE_EVENTUNIT_MICROSEC, "US"}, + {QNN_PROFILE_EVENTUNIT_BYTES, "BYTES"}, + {QNN_PROFILE_EVENTUNIT_CYCLES, "CYCLES"}, + {QNN_PROFILE_EVENTUNIT_COUNT, "COUNT"}, + {QNN_PROFILE_EVENTUNIT_OBJECT, "OBJECT"}, + {QNN_PROFILE_EVENTUNIT_BACKEND, "BACKEND"}}; + return unitStringMap; +} + +const std::string& GetUnitString(QnnProfile_EventUnit_t unitType) { + const auto& unitStringMap = GetUnitStringMap(); + auto it = unitStringMap.find(unitType); + if (it != unitStringMap.end()) { + return it->second; + } + static const std::string unknown = "UNKNOWN"; + return unknown; +} + +std::string GetEventTypeString(QnnProfile_EventType_t event_type) { + // Interpret the event type + switch (event_type) { + case QNN_PROFILE_EVENTTYPE_INIT: + return "INIT"; + case QNN_PROFILE_EVENTTYPE_FINALIZE: + return "FINALIZE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE: + return "EXECUTE"; + case QNN_PROFILE_EVENTTYPE_NODE: + return "NODE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_QUEUE_WAIT: + return "EXECUTE QUEUE WAIT"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_PREPROCESS: + return "EXECUTE PREPROCESS"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_DEVICE: + return "EXECUTE DEVICE"; + case QNN_PROFILE_EVENTTYPE_EXECUTE_POSTPROCESS: + return "EXECUTE POSTPROCESS"; + case QNN_PROFILE_EVENTTYPE_DEINIT: + return "DE-INIT"; + case QNN_PROFILE_EVENTTYPE_BACKEND: + return "BACKEND"; + default: + if (event_type > QNN_PROFILE_EVENTTYPE_BACKEND) { + return "BACKEND"; + } + return "UNKNOWN"; + } +} + +std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar) { + switch (scalar.dataType) { + case QNN_DATATYPE_INT_8: + return std::to_string(static_cast(scalar.int8Value)); + case QNN_DATATYPE_INT_16: + return std::to_string(scalar.int16Value); + case QNN_DATATYPE_INT_32: + return std::to_string(scalar.int32Value); + case QNN_DATATYPE_INT_64: + return std::to_string(scalar.int64Value); + case QNN_DATATYPE_UINT_8: + return std::to_string(static_cast(scalar.uint8Value)); + case QNN_DATATYPE_UINT_16: + return std::to_string(scalar.uint16Value); + case QNN_DATATYPE_UINT_32: + return std::to_string(scalar.uint32Value); + case QNN_DATATYPE_UINT_64: + return std::to_string(scalar.uint64Value); + case QNN_DATATYPE_FLOAT_16: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_FLOAT_32: + return std::to_string(scalar.floatValue); + case QNN_DATATYPE_SFIXED_POINT_8: + case QNN_DATATYPE_SFIXED_POINT_16: + case QNN_DATATYPE_SFIXED_POINT_32: + return std::to_string(scalar.int32Value); // Assume using int types for signed fixed points. + case QNN_DATATYPE_UFIXED_POINT_8: + case QNN_DATATYPE_UFIXED_POINT_16: + case QNN_DATATYPE_UFIXED_POINT_32: + return std::to_string(scalar.uint32Value); // Assume using unsigned int types for unsigned fixed points. + case QNN_DATATYPE_BOOL_8: + return scalar.bool8Value ? "true" : "false"; + case QNN_DATATYPE_STRING: + return scalar.stringValue ? 
scalar.stringValue : "NULL"; + default: + return "UNKNOWN"; + } +} + +#ifdef _WIN32 +void Serializer::LogQnnProfileEventAsTraceLogging( + uint64_t timestamp, + const std::string& message, + const std::string& qnnScalarValue, + const std::string& unit, + const std::string& timingSource, + const std::string& event_level, + const char* eventIdentifier) { + QnnTelemetry& qnn_telemetry = QnnTelemetry::Instance(); + qnn_telemetry.LogQnnProfileEvent(timestamp, message, qnnScalarValue, unit, timingSource, event_level, eventIdentifier); +} +#endif + +Status Serializer::ProcessEvent(const QnnProfile_EventId_t event_id, const std::string& event_level, + const QnnProfile_EventData_t& event_data) { + const std::string& message = GetEventTypeString(event_data.type); + const std::string& unit = GetUnitString(event_data.unit); + + ORT_UNUSED_PARAMETER(event_id); + + if (outfile_) { + outfile_ << "UNKNOWN" + << "," + << message << "," + << event_data.value << "," + << unit << "," + << "BACKEND" + << "," + << event_level << "," + << (event_data.identifier ? event_data.identifier : "NULL") << "\n"; + } +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + QnnSystemProfile_ProfileEventV1_t* created_event = nullptr; + if (event_level == "SUB-EVENT") { + auto parent_system_event = GetParentSystemEvent(event_id); + ORT_RETURN_IF(parent_system_event == nullptr, "Serialization of subevent failed: parent event pointer is null"); + created_event = AddSubEvent(event_id, event_data, parent_system_event); + } else { + created_event = AddEvent(event_id, event_data); + } + + ORT_RETURN_IF(created_event == nullptr, "Serialization of event failed: Unable to create system profile event"); +#endif + + if (tracelogging_provider_ep_enabled_) { +#ifdef _WIN32 + LogQnnProfileEventAsTraceLogging( + (uint64_t)0, + message, + std::to_string(event_data.value), + unit, + "BACKEND", + event_level, + (event_data.identifier ? event_data.identifier : "NULL")); +#endif + } + + return Status::OK(); +} + +Status Serializer::ProcessExtendedEvent(const QnnProfile_EventId_t event_id, const std::string& event_level, + const QnnProfile_ExtendedEventData_t& event_data) { + // need to check the version first + const std::string& message = GetEventTypeString(event_data.v1.type); + const std::string& unit = GetUnitString(event_data.v1.unit); + + ORT_UNUSED_PARAMETER(event_id); + + if (outfile_) { + if (event_data.version == QNN_PROFILE_DATA_VERSION_1) { + outfile_ << event_data.v1.timestamp << "," + << message << "," + << ExtractQnnScalarValue(event_data.v1.value) << "," + << unit << "," + << "BACKEND" + << "," + << event_level << "," + << (event_data.v1.identifier ? event_data.v1.identifier : "NULL") + << "\n"; + } + } +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + QnnSystemProfile_ProfileEventV1_t* created_event = nullptr; + if (event_level == "SUB-EVENT") { + auto parent_system_event = GetParentSystemEvent(event_id); + ORT_RETURN_IF(parent_system_event == nullptr, "Serialization of subevent failed: parent event pointer is null"); + created_event = AddExtendedSubEvent(event_id, event_data, parent_system_event); + } else { + created_event = AddExtendedEvent(event_id, event_data); + } + + ORT_RETURN_IF(created_event == nullptr, "Serialization of event failed: Unable to create system profile event"); +#endif + + if (tracelogging_provider_ep_enabled_) { +#ifdef _WIN32 + LogQnnProfileEventAsTraceLogging( + event_data.v1.timestamp, + message, + ExtractQnnScalarValue(event_data.v1.value), + unit, + "BACKEND", + event_level, + (event_data.v1.identifier ? 
event_data.v1.identifier : "NULL")); +#endif + } + + return Status::OK(); +} + +Status Serializer::InitCsvFile() { + auto output_filepath = profiling_info_.csv_output_filepath; + // Write to CSV in append mode + std::ifstream infile(output_filepath.c_str()); + bool exists = infile.good(); + if (infile.is_open()) { + infile.close(); + } + + outfile_.open(output_filepath.c_str(), std::ios_base::app); + ORT_RETURN_IF(!outfile_.is_open(), "Failed to open profiling file: ", output_filepath); + // If file didn't exist before, write the header + if (!exists) { + outfile_ << "Msg Timestamp,Message,Time,Unit of Measurement,Timing Source,Event Level,Event Identifier\n"; + } + + return Status::OK(); +} + +Serializer::Serializer(const ProfilingInfo& profiling_info, + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface, + bool tracelogging_provider_ep_enabled) + : profiling_info_(profiling_info), + qnn_system_interface_(qnn_system_interface), + tracelogging_provider_ep_enabled_(tracelogging_provider_ep_enabled) { +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + std::filesystem::path output_fs_filepath(profiling_info.csv_output_filepath); + qnn_log_filename_ = output_fs_filepath.filename().string(); + // Remove extension (assumed to be ".csv") then add "_qnn.log" + size_t extension_start_idx = qnn_log_filename_.rfind("."); + qnn_log_filename_ = qnn_log_filename_.substr(0, extension_start_idx); + qnn_log_filename_.append("_qnn.log"); + + std::filesystem::path abs_output_path; + if (output_fs_filepath.has_root_path()) { + abs_output_path = output_fs_filepath.parent_path(); + } else { + abs_output_path = std::filesystem::current_path() / output_fs_filepath.parent_path(); + } + output_directory_ = abs_output_path.string(); + + event_data_list_.reserve(profiling_info.num_events); +#endif +} + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED +QnnSystemProfile_MethodType_t ParseMethodType(ProfilingMethodType method_type) { + switch (method_type) { + case ProfilingMethodType::EXECUTE: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_EXECUTE; + case ProfilingMethodType::FINALIZE: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_FINALIZE; + case ProfilingMethodType::EXECUTE_ASYNC: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_EXECUTE_ASYNC; + case ProfilingMethodType::CREATE_FROM_BINARY: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_CREATE_FROM_BINARY; + case ProfilingMethodType::DEINIT: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_DEINIT; + case ProfilingMethodType::CONTEXT_CREATE: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_APP_CONTEXT_CREATE; + case ProfilingMethodType::COMPOSE_GRAPHS: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_APP_COMPOSE_GRAPHS; + case ProfilingMethodType::EXECUTE_IPS: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_APP_EXECUTE_IPS; + case ProfilingMethodType::GRAPH_COMPONENT: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_GRAPH_COMPONENT; + case ProfilingMethodType::LIB_LOAD: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_APP_BACKEND_LIB_LOAD; + case ProfilingMethodType::APPLY_BINARY_SECTION: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_BACKEND_APPLY_BINARY_SECTION; + case ProfilingMethodType::CONTEXT_FINALIZE: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_CONTEXT_FINALIZE; + default: + return QNN_SYSTEM_PROFILE_METHOD_TYPE_NONE; + } +} + +std::string GetSystemProfileErrorString(Qnn_ErrorHandle_t error) { + switch (error) { + case QNN_SYSTEM_PROFILE_ERROR_UNSUPPORTED_FEATURE: + return "Unsupported Feature"; + case QNN_SYSTEM_PROFILE_ERROR_INVALID_HANDLE: + return "Invalid Handle"; + case QNN_SYSTEM_PROFILE_ERROR_INVALID_ARGUMENT: + 
return "Invalid Argument"; + case QNN_SYSTEM_PROFILE_ERROR_MEM_ALLOC: + return "Memory Allocation Error"; + default: + return "Unknown"; + } +} + +QnnSystemProfile_ProfileEventV1_t* Serializer::AddEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_EventData_t event) { + return CreateSystemEvent(event_data_list_, event_id, event); +} +QnnSystemProfile_ProfileEventV1_t* Serializer::AddExtendedEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_ExtendedEventData_t event) { + return CreateSystemExtendedEvent(event_data_list_, event_id, event); +} + +QnnSystemProfile_ProfileEventV1_t* Serializer::AddSubEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_EventData_t& sub_event, + QnnSystemProfile_ProfileEventV1_t* const parent_system_event) { + if (sub_event_data_map_.find(parent_system_event) == sub_event_data_map_.end()) { + return nullptr; + } + + auto& sub_event_list = sub_event_data_map_.at(parent_system_event); + return CreateSystemEvent(sub_event_list, event_id, sub_event); +} +QnnSystemProfile_ProfileEventV1_t* Serializer::AddExtendedSubEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_ExtendedEventData_t& sub_event, + QnnSystemProfile_ProfileEventV1_t* const parent_system_event) { + if (sub_event_data_map_.find(parent_system_event) == sub_event_data_map_.end()) { + return nullptr; + } + + auto& sub_event_list = sub_event_data_map_.at(parent_system_event); + return CreateSystemExtendedEvent(sub_event_list, event_id, sub_event); +} + +Status Serializer::SerializeEventsToQnnLog() { + bool result = nullptr == qnn_system_interface_.systemProfileCreateSerializationTarget || + nullptr == qnn_system_interface_.systemProfileSerializeEventData || + nullptr == qnn_system_interface_.systemProfileFreeSerializationTarget; + ORT_RETURN_IF(result, "Failed to get system profile API pointers."); + + auto method_type = profiling_info_.method_type; + ORT_RETURN_IF(method_type == ProfilingMethodType::UNKNOWN, "Invalid serialization method type"); + + QnnSystemProfile_SerializationTargetConfig_t config; + config.type = QNN_SYSTEM_PROFILE_SERIALIZATION_TARGET_CONFIG_SERIALIZATION_HEADER; + + std::string backend_version(std::to_string(QNN_API_VERSION_MAJOR) + "." + std::to_string(QNN_API_VERSION_MINOR) + "." 
+ std::to_string(QNN_API_VERSION_PATCH)); + + std::string app_version(std::to_string(ORT_API_VERSION)); + config.serializationHeader.appName = "OnnxRuntime"; + config.serializationHeader.appVersion = app_version.c_str(); + config.serializationHeader.backendVersion = backend_version.c_str(); + + QnnSystemProfile_SerializationTargetFile_t serialization_file{qnn_log_filename_.c_str(), output_directory_.c_str()}; + QnnSystemProfile_SerializationTarget_t serialization_target = { + QNN_SYSTEM_PROFILE_SERIALIZATION_TARGET_FILE, + {serialization_file}}; + + QnnSystemProfile_SerializationTargetHandle_t serialization_target_handle; + + auto status = qnn_system_interface_.systemProfileCreateSerializationTarget(serialization_target, &config, 1, + &serialization_target_handle); + ORT_RETURN_IF(QNN_SYSTEM_PROFILE_NO_ERROR != status, "Failed to create serialization target handle: ", + GetSystemProfileErrorString(status)); + + ManagedSerializationTargetHandle managed_target_handle(serialization_target_handle, qnn_system_interface_); + + // Set subevent data pointers for all event data + // Must be done here as underlying array ptrs can change as vectors are resized + for (auto it = sub_event_data_map_.begin(); it != sub_event_data_map_.end(); it++) { + it->first->profileSubEventData = it->second.data(); + it->first->numSubEvents = static_cast(it->second.size()); + } + + // Create QnnSystemProfile_ProfileData_t obj here + QnnSystemProfile_ProfileData_t system_profile_data = QNN_SYSTEM_PROFILE_DATA_INIT; + system_profile_data.version = QNN_SYSTEM_PROFILE_DATA_VERSION_1; + system_profile_data.v1.header.startTime = profiling_info_.start_time; + system_profile_data.v1.header.stopTime = profiling_info_.stop_time; + system_profile_data.v1.header.graphName = profiling_info_.graph_name.c_str(); + system_profile_data.v1.header.methodType = ParseMethodType(method_type); + system_profile_data.v1.profilingEvents = event_data_list_.data(); + system_profile_data.v1.numProfilingEvents = static_cast(event_data_list_.size()); + + std::vector system_profile_data_list = {&system_profile_data}; + status = qnn_system_interface_.systemProfileSerializeEventData(serialization_target_handle, + system_profile_data_list.data(), + 1); + + ORT_RETURN_IF(QNN_SYSTEM_PROFILE_NO_ERROR != status, "Failed to serialize QNN profiling data: ", + GetSystemProfileErrorString(status)); + + status = managed_target_handle.FreeHandle(); + ORT_RETURN_IF(QNN_SYSTEM_PROFILE_NO_ERROR != status, "Failed to free serialization target: ", + GetSystemProfileErrorString(status)); + + return Status::OK(); +} + +void Serializer::AddSubEventList(const uint32_t num_sub_events, QnnSystemProfile_ProfileEventV1_t* event_ptr) { + if (num_sub_events > 0U) { + auto it = sub_event_data_map_.emplace(event_ptr, std::vector()).first; + it->second.reserve(num_sub_events); + } +} + +QnnSystemProfile_ProfileEventV1_t* Serializer::GetSystemEventPointer(QnnProfile_EventId_t event_id) { + auto it = event_profile_id_lookup_map_.find(event_id); + if (it == event_profile_id_lookup_map_.end()) { + return nullptr; + } + + return it->second; +} + +Status Serializer::SetParentSystemEvent( + const QnnProfile_EventId_t event_id, + QnnSystemProfile_ProfileEventV1_t* const system_parent_event) { + ORT_RETURN_IF(!(system_parent_event_lookup_map_.emplace(event_id, system_parent_event).second), + "Failed to add subevent-parent event mapping"); + return Status::OK(); +} +QnnSystemProfile_ProfileEventV1_t* Serializer::GetParentSystemEvent(const QnnProfile_EventId_t event_id) { + if 
(system_parent_event_lookup_map_.find(event_id) == system_parent_event_lookup_map_.end()) { + return nullptr; + } + + return system_parent_event_lookup_map_.at(event_id); +} + +QnnSystemProfile_ProfileEventV1_t* Serializer::CreateSystemEvent( + std::vector& event_list, + QnnProfile_EventId_t event_id, + QnnProfile_EventData_t event_data) { + auto system_event = &(event_list.emplace_back()); + + system_event->type = QNN_SYSTEM_PROFILE_EVENT_DATA; + system_event->eventData = event_data; + system_event->profileSubEventData = NULL; + system_event->numSubEvents = 0; + + event_profile_id_lookup_map_.emplace(event_id, system_event); + + return system_event; +} + +QnnSystemProfile_ProfileEventV1_t* Serializer::CreateSystemExtendedEvent(std::vector& event_list, + QnnProfile_EventId_t event_id, + QnnProfile_ExtendedEventData_t event_data) { + auto system_event = &(event_list.emplace_back()); + + system_event->type = QNN_SYSTEM_PROFILE_EXTENDED_EVENT_DATA; + system_event->extendedEventData = event_data; + system_event->profileSubEventData = NULL; + system_event->numSubEvents = 0; + + event_profile_id_lookup_map_.emplace(event_id, system_event); + + return system_event; +} +#endif // QNN_SYSTEM_PROFILE_API_ENABLED + +} // namespace profile +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.h b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.h new file mode 100644 index 0000000000000..0b56d05289ccc --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_profile_serializer.h @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include + +#include + +#include "core/providers/qnn/ort_api.h" +#include "core/providers/qnn/builder/qnn_def.h" + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + +#include + +#include +#include +#include + +#endif + +namespace onnxruntime { + +namespace qnn { +namespace profile { + +struct ProfilingInfo { + std::string graph_name = ""; + std::string csv_output_filepath = ""; + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + uint64_t start_time = 0; + uint64_t stop_time = 0; + uint32_t num_events = 0; + + ProfilingMethodType method_type = ProfilingMethodType::UNKNOWN; +#endif +}; + +class Serializer { + public: + Serializer(const ProfilingInfo& profiling_info, + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface, + bool tracelogging_provider_ep_enabled); + + // Extracts all event/subevent data then: + // 1. Writes/appends data to a csv file defined in profiling_info_ + // 2. If QNN System Profile API is enabled, converts the data + // into a QNN System Profile Event and stores the new obj locally + Status ProcessEvent(const QnnProfile_EventId_t event_Id, const std::string& event_level, + const QnnProfile_EventData_t& event_data); + + // Extracts all event/subevent data then: + // 1. Writes/appends data to a csv file defined in profiling_info_ + // 2. 
If QNN System Profile API is enabled, converts the data + // into a QNN System Profile Event and stores the new obj locally + Status ProcessExtendedEvent(const QnnProfile_EventId_t event_id, const std::string& event_level, + const QnnProfile_ExtendedEventData_t& event_data); + + // If QNN API is too old, turn Serializer into an ofstream wrapper class + // Keeps code clean, any performance impacts can be ignored when profiling is enabled + ~Serializer() { +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + event_data_list_.clear(); + system_parent_event_lookup_map_.clear(); + event_profile_id_lookup_map_.clear(); + sub_event_data_map_.clear(); +#endif + } + + // Initializes outfile_ to output to a defined .csv file + // or appends to the defined .csv file if it already exists + // This is in its own function & not ctor for error checking/handling + Status InitCsvFile(); + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + // Serializes all locally stored QNN System Profile Event data into a + // qnn profiling .log file in the same directory as the .csv file defined + // in profilng_info_. The output file name will have the same name + // as the .csv file (sans extension) with _qnn.log appended to the end. + Status SerializeEventsToQnnLog(); + + QnnSystemProfile_ProfileEventV1_t* GetParentSystemEvent(const QnnProfile_EventId_t event_id); + + QnnSystemProfile_ProfileEventV1_t* GetSystemEventPointer(const QnnProfile_EventId_t event_id); + + void AddSubEventList(const uint32_t num_sub_events, QnnSystemProfile_ProfileEventV1_t* event_ptr); + + Status SetParentSystemEvent(const QnnProfile_EventId_t event_id, + QnnSystemProfile_ProfileEventV1_t* const system_parent_event); +#endif + + private: +#ifdef _WIN32 + void LogQnnProfileEventAsTraceLogging( + uint64_t timestamp, + const std::string& message, + const std::string& qnnScalarValue, + const std::string& unit, + const std::string& timingSource, + const std::string& eventLevel, + const char* eventIdentifier); +#endif + +#ifdef QNN_SYSTEM_PROFILE_API_ENABLED + class ManagedSerializationTargetHandle { + public: + ManagedSerializationTargetHandle(const QnnSystemProfile_SerializationTargetHandle_t& raw_handle, + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface) : qnn_system_interface_(qnn_system_interface), + handle_(raw_handle) {} + + ~ManagedSerializationTargetHandle() { + auto status = FreeHandle(); + ORT_UNUSED_PARAMETER(status); + } + + Qnn_ErrorHandle_t FreeHandle() { + return qnn_system_interface_.systemProfileFreeSerializationTarget(handle_); + } + + private: + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface_; + QnnSystemProfile_SerializationTargetHandle_t handle_; + }; // ManagedSerializationTargetHandle + + QnnSystemProfile_ProfileEventV1_t* AddEvent(const QnnProfile_EventId_t event_Id, + const QnnProfile_EventData_t event); + + QnnSystemProfile_ProfileEventV1_t* AddExtendedEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_ExtendedEventData_t event); + + QnnSystemProfile_ProfileEventV1_t* AddSubEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_EventData_t& sub_event, + QnnSystemProfile_ProfileEventV1_t* const system_parent_event); + + QnnSystemProfile_ProfileEventV1_t* AddExtendedSubEvent(const QnnProfile_EventId_t event_id, + const QnnProfile_ExtendedEventData_t& sub_event, + QnnSystemProfile_ProfileEventV1_t* const system_parent_event); + + QnnSystemProfile_ProfileEventV1_t* CreateSystemEvent(std::vector& event_list, + const QnnProfile_EventId_t event_id, + QnnProfile_EventData_t event_data); + + QnnSystemProfile_ProfileEventV1_t* 
CreateSystemExtendedEvent(std::vector& event_list, + const QnnProfile_EventId_t event_id, + QnnProfile_ExtendedEventData_t event_data); + + std::string qnn_log_filename_; + std::string output_directory_; + std::vector event_data_list_; + std::unordered_map system_parent_event_lookup_map_; + std::unordered_map event_profile_id_lookup_map_; + std::unordered_map > sub_event_data_map_; + +#endif // QNN_SYSTEM_PROFILE_API_ENABLED + const ProfilingInfo profiling_info_; + QNN_SYSTEM_INTERFACE_VER_TYPE qnn_system_interface_; + bool tracelogging_provider_ep_enabled_ = false; + std::ofstream outfile_; +}; + +} // namespace profile +} // namespace qnn +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc index afa5e3bdbb6d1..6a3234e507f1f 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.cc @@ -1426,6 +1426,11 @@ Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector& pe return Status::OK(); } +uint64_t GetTimeStampInUs() { + auto timestamp = std::chrono::steady_clock::now().time_since_epoch(); + return std::chrono::duration_cast(timestamp).count(); +} + } // namespace utils } // namespace qnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_utils.h b/onnxruntime/core/providers/qnn/builder/qnn_utils.h index b234f7df375e9..b86e4d27789b4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_utils.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_utils.h @@ -453,6 +453,12 @@ Status InsertConvertOp(QnnModelWrapper& qnn_model_wrapper, * @return execution status of this function */ Status GetPermToLastAxis(uint32_t axis, uint32_t rank, std::vector& perm); +/** + * Get the current timestamp in microseconds + * + * @return the current timestamp in microseconds + */ +uint64_t GetTimeStampInUs(); } // namespace utils } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 2bdbfb9c1c62e..2ff349bb8d765 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -94,6 +94,8 @@ static void ParseProfilingLevel(std::string profiling_level_string, profiling_level = qnn::ProfilingLevel::BASIC; } else if (profiling_level_string == "detailed") { profiling_level = qnn::ProfilingLevel::DETAILED; + } else if (profiling_level_string == "optrace") { + profiling_level = qnn::ProfilingLevel::OPTRACE; } else { LOGS_DEFAULT(WARNING) << "Profiling level not valid."; } @@ -400,6 +402,7 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio if (profiling_level_pos != provider_options_map.end()) { ParseProfilingLevel(profiling_level_pos->second, profiling_level); } + static const std::string PROFILING_FILE = "profiling_file_path"; auto profiling_file_pos = provider_options_map.find(PROFILING_FILE); if (profiling_file_pos != provider_options_map.end()) { diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc index 2993bfcebb1da..7dd3b50c656f4 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.cc +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.cc @@ -158,9 +158,9 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, int output_components, uint32_t tile_inner, bool split_k, - uint32_t splitted_dim_inner) { + 
uint32_t split_dim_inner) { ORT_UNUSED_PARAMETER(split_k); - ORT_UNUSED_PARAMETER(splitted_dim_inner); + ORT_UNUSED_PARAMETER(split_dim_inner); const std::string type_string = MakeScalarOrVectorType(4 /*components */, data_type); @@ -356,9 +356,9 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, bool need_handle_matmul, uint32_t tile_inner, bool split_k, - uint32_t splitted_dim_inner) { + uint32_t split_dim_inner) { ORT_UNUSED_PARAMETER(split_k); - ORT_UNUSED_PARAMETER(splitted_dim_inner); + ORT_UNUSED_PARAMETER(split_dim_inner); const auto elements_per_thread_x = elements_per_thread[0]; const auto elements_per_thread_y = elements_per_thread[1]; diff --git a/onnxruntime/core/providers/webgpu/math/gemm_utils.h b/onnxruntime/core/providers/webgpu/math/gemm_utils.h index 2244f2810c3bf..ed4cf997d2f00 100644 --- a/onnxruntime/core/providers/webgpu/math/gemm_utils.h +++ b/onnxruntime/core/providers/webgpu/math/gemm_utils.h @@ -43,7 +43,7 @@ Status MakeMatMulPackedVec4Source(ShaderHelper& shader, int output_components = 4, uint32_t tile_inner = 32, bool split_k = false, - uint32_t splitted_dim_inner = 32); + uint32_t split_dim_inner = 32); Status MakeMatMulPackedSource(ShaderHelper& shader, const InlinedVector& elements_per_thread, @@ -57,7 +57,7 @@ Status MakeMatMulPackedSource(ShaderHelper& shader, bool need_handle_matmul = true, uint32_t tile_inner = 32, bool split_k = false, - uint32_t splitted_dim_inner = 32); + uint32_t split_dim_inner = 32); } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/nn/conv.cc b/onnxruntime/core/providers/webgpu/nn/conv.cc index 7f82f85fc8f91..a2777979ae983 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "core/providers/webgpu/nn/conv.h" -#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/nn/conv2d_mm.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/tensor/transpose.h" @@ -9,14 +9,6 @@ #include "core/providers/webgpu/webgpu_utils.h" #include "core/providers/webgpu/math/matmul.h" -namespace { - -inline uint32_t ceil_div(int64_t numerator, int32_t denominator) { - return static_cast((numerator + denominator - 1) / denominator); -} - -} // namespace - namespace onnxruntime { namespace webgpu { @@ -27,37 +19,10 @@ Status TransposeKernel(ComputeContext& context, const Tensor* kernel, const Tens for (size_t i = 0; i < rank; ++i) { transposed_kernel_shape_vector[i] = kernel_shape[perm[i]]; } - uint32_t output_size = onnxruntime::narrow(kernel_shape.Size()); - - uint32_t dispatch_x = ceil_div(output_size, 64); - uint32_t dispatch_y = 1; - uint32_t dispatch_z = 1; - - // This temporary workaround addresses a significant performance bottleneck - // (~12x slower) for the shape (3, 3, 2560, 1280) due to an issue with Intel's - // GPU drivers. We manually normalize the dispatch group size to restore - // performance. - // - // TODO: Revert this change once the driver issue is fixed. 
- if (context.AdapterInfo().vendor == std::string_view{"intel"}) { - ORT_ENFORCE(rank == static_cast(4), "Input tensor must have rank 4."); - dispatch_x = ceil_div(transposed_kernel_shape_vector[0] * transposed_kernel_shape_vector[1], 2); - dispatch_y = ceil_div(transposed_kernel_shape_vector[2], 4); - dispatch_z = ceil_div(transposed_kernel_shape_vector[3], 8); - } - TensorShape transposed_kernel_shape(transposed_kernel_shape_vector); *transposed_kernel = context.CreateGPUTensor(kernel->DataType(), transposed_kernel_shape); - bool use_shared = false; - TransposeProgram program{perm, use_shared}; - program - .CacheHint(absl::StrJoin(perm, "-")) - .AddInput({kernel, ProgramTensorMetadataDependency::TypeAndRank, kernel_shape, 1}) - .AddOutput({transposed_kernel, ProgramTensorMetadataDependency::TypeAndRank}) - .AddUniformVariable({output_size}) - .SetWorkgroupSize(64) - .SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z); - return context.RunProgram(program); + const Tensor reshaped_kernel(kernel->DataType(), kernel_shape, const_cast(kernel->DataRaw()), kernel->Location()); + return Transpose::DoTranspose(context, perm, reshaped_kernel, *transposed_kernel); } template diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc similarity index 99% rename from onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc rename to onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc index bf5208883508f..2d5424c52a3f2 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.cc @@ -4,7 +4,7 @@ #include #include #include -#include "core/providers/webgpu/nn/conv2d_mm_webgpu.h" +#include "core/providers/webgpu/nn/conv2d_mm.h" #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/nn/activation_util.h" diff --git a/onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv2d_mm.h similarity index 100% rename from onnxruntime/core/providers/webgpu/nn/conv2d_mm_webgpu.h rename to onnxruntime/core/providers/webgpu/nn/conv2d_mm.h diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc b/onnxruntime/core/providers/webgpu/nn/conv_backprop.cc similarity index 99% rename from onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc rename to onnxruntime/core/providers/webgpu/nn/conv_backprop.cc index 8fedf748997e6..effc5cacf6b64 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv_backprop.cc @@ -4,7 +4,7 @@ #include #include #include "core/common/inlined_containers.h" -#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" +#include "core/providers/webgpu/nn/conv_backprop.h" #include "core/providers/webgpu/webgpu_utils.h" namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h b/onnxruntime/core/providers/webgpu/nn/conv_backprop.h similarity index 100% rename from onnxruntime/core/providers/webgpu/nn/conv_backprop_webgpu.h rename to onnxruntime/core/providers/webgpu/nn/conv_backprop.h diff --git a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc index 9cd290ef56013..84a0afd873d23 100644 --- a/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc +++ b/onnxruntime/core/providers/webgpu/nn/conv_transpose.cc @@ -6,7 +6,7 @@ #include 
"core/providers/cpu/nn/conv_attributes.h" #include "core/providers/webgpu/webgpu_supported_types.h" #include "core/providers/webgpu/tensor/transpose.h" -#include "core/providers/webgpu/nn/conv_backprop_webgpu.h" +#include "core/providers/webgpu/nn/conv_backprop.h" namespace onnxruntime { namespace webgpu { diff --git a/onnxruntime/core/providers/webgpu/program_manager.cc b/onnxruntime/core/providers/webgpu/program_manager.cc index 33c3514f8f6d3..aa0db6ad06301 100644 --- a/onnxruntime/core/providers/webgpu/program_manager.cc +++ b/onnxruntime/core/providers/webgpu/program_manager.cc @@ -81,6 +81,9 @@ Status ProgramManager::Build(const ProgramBase& program, ORT_RETURN_IF_ERROR(program.GenerateShaderCode(shader_helper)); + // Finalize inputs after GenerateShaderCode() to ensure indirect buffer is added as the last input + shader_helper.FinalizeInputs(); + ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForInputs()); ORT_RETURN_IF_ERROR(shader_helper.ValidateShapeForOutputs()); ORT_RETURN_IF_ERROR(shader_helper.ValidateIndices()); diff --git a/onnxruntime/core/providers/webgpu/shader_helper.cc b/onnxruntime/core/providers/webgpu/shader_helper.cc index 0e4a3e08e1c13..0182bdc607173 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.cc +++ b/onnxruntime/core/providers/webgpu/shader_helper.cc @@ -54,6 +54,7 @@ Status ShaderHelper::Init() { // init body string stream bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1; + bool use_indirect_dispatch = program_.IndirectDispatchTensor() != nullptr; body_.reserve(4096); additional_implementation_.reserve(1024); @@ -68,15 +69,25 @@ Status ShaderHelper::Init() { " @builtin(subgroup_invocation_id) sg_id : u32,\n" " @builtin(subgroup_size) sg_size : u32"; } - if (!is_1d_dispatch) { - body_ss_ << ",\n" - " @builtin(num_workgroups) num_workgroups : vec3"; - } - body_ss_ << ") {\n"; - if (is_1d_dispatch) { + // When using indirect dispatch, avoid @builtin(num_workgroups) to skip Dawn's validation + // and duplication overhead in TransformIndirectDispatchBuffer. + // Instead, the dispatch dimensions will be read from the indirect buffer at runtime. + if (use_indirect_dispatch) { + body_ss_ << ") {\n"; + // For indirect dispatch, read the actual dispatch dimensions from the indirect buffer. + // The indirect buffer format is: [x, y, z] where x, y, z are the workgroup counts. + // We read these values to calculate workgroup_idx accurately based on actual dispatch. + body_ss_ << " let num_workgroups_x = indirect_buffer[0];\n" + " let num_workgroups_y = indirect_buffer[1];\n" + " let workgroup_idx = workgroup_id.z * num_workgroups_x * num_workgroups_y + workgroup_id.y * num_workgroups_x + workgroup_id.x;\n" + " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; + } else if (is_1d_dispatch) { + body_ss_ << ") {\n"; body_ss_ << " let global_idx = global_id.x;\n" " let workgroup_idx = workgroup_id.x;\n"; } else { + body_ss_ << ",\n" + " @builtin(num_workgroups) num_workgroups : vec3) {\n"; body_ss_ << " let workgroup_idx = workgroup_id.z * num_workgroups[0] * num_workgroups[1] + workgroup_id.y * num_workgroups[0] + workgroup_id.x;\n" " let global_idx = workgroup_idx * (workgroup_size_x * workgroup_size_y * workgroup_size_z) + local_idx;\n"; } @@ -84,6 +95,13 @@ Status ShaderHelper::Init() { return Status::OK(); } +void ShaderHelper::FinalizeInputs() { + // Automatically add indirect buffer as the last shader input when using indirect dispatch. 
+ if (program_.IndirectDispatchTensor() != nullptr) { + AddInput("indirect_buffer", ShaderUsage::None); + } +} + const ShaderVariableHelper& ShaderHelper::AddInput(const std::string& name, ShaderUsage usage) { const size_t input_index = input_vars_.size(); ORT_ENFORCE(input_index < program_.Inputs().size(), diff --git a/onnxruntime/core/providers/webgpu/shader_helper.h b/onnxruntime/core/providers/webgpu/shader_helper.h index 6878f5236fddf..5a0398839267b 100644 --- a/onnxruntime/core/providers/webgpu/shader_helper.h +++ b/onnxruntime/core/providers/webgpu/shader_helper.h @@ -75,6 +75,11 @@ class ShaderHelper final { Status Init(); + // Finalize inputs by automatically adding the indirect buffer if needed. + // This should be called after GenerateShaderCode() to ensure the indirect buffer + // is registered as the last input. + void FinalizeInputs(); + // Add an input variable to the shader. // // depending on the usage of the variable, additional code may be generated. diff --git a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc index 136a1ba9776a0..63516d7e48d52 100644 --- a/onnxruntime/core/providers/webgpu/tensor/squeeze.cc +++ b/onnxruntime/core/providers/webgpu/tensor/squeeze.cc @@ -11,7 +11,43 @@ namespace webgpu { ONNX_OPERATOR_KERNEL_EX( Squeeze, kOnnxDomain, - 13, + 24, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 23, 23, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 21, 22, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", WebGpuSupportedNumberTypes()) + .TypeConstraint("axes", DataTypeImpl::GetTensorType()) + .Alias(0, 0) + .InputMemoryType(OrtMemTypeCPU, 1), + Squeeze); + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Squeeze, + kOnnxDomain, + 13, 20, kWebGpuExecutionProvider, (*KernelDefBuilder::Create()) .TypeConstraint("T", WebGpuSupportedNumberTypes()) diff --git a/onnxruntime/core/providers/webgpu/tensor/transpose.cc b/onnxruntime/core/providers/webgpu/tensor/transpose.cc index 5f1496ff7a40e..1458eceb3e244 100644 --- a/onnxruntime/core/providers/webgpu/tensor/transpose.cc +++ b/onnxruntime/core/providers/webgpu/tensor/transpose.cc @@ -8,6 +8,14 @@ #include "core/providers/webgpu/shader_helper.h" #include "core/providers/webgpu/webgpu_supported_types.h" +namespace { + +inline uint32_t ceil_div(int64_t numerator, int32_t denominator) { + return static_cast((numerator + denominator - 1) / denominator); +} + +} // namespace + namespace onnxruntime { namespace webgpu { @@ -134,22 +142,39 @@ Status Transpose::DoTranspose(onnxruntime::webgpu::ComputeContext& context, uint32_t output_size = onnxruntime::narrow(input_shape.Size()); TransposeProgram program{permutations, use_shared}; - if (use_shared) { - program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); - } program .CacheHint(absl::StrJoin(permutations, "-")) .AddInputs({{&input, ProgramTensorMetadataDependency::TypeAndRank, new_input_shape, 1}}) .AddOutputs({{&output, ProgramTensorMetadataDependency::None, new_output_shape, 1}}) - 
.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), - static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) .AddUniformVariables({ {static_cast(output_size)}, }); - use_shared ? program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), - static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))) - : program.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE); + if (use_shared) { + program.SetWorkgroupSize(TILE_SIZE, TILE_SIZE, 1); + program.SetDispatchGroupSize(static_cast((new_output_shape[1] + TILE_SIZE - 1) / TILE_SIZE), + static_cast(((new_output_shape[0] + TILE_SIZE - 1) / TILE_SIZE))); + } else { + program.SetWorkgroupSize(WORKGROUP_SIZE); + + uint32_t dispatch_x = ceil_div(output_size, WORKGROUP_SIZE); + uint32_t dispatch_y = 1; + uint32_t dispatch_z = 1; + + // This temporary workaround addresses a significant performance bottleneck + // (~12x slower) for the shape (3, 3, 2560, 1280) due to an issue with Intel's + // GPU drivers. We manually normalize the dispatch group size to restore + // performance. + // + // TODO: Revert this change once the driver issue is fixed. + if (context.AdapterInfo().vendor == std::string_view{"intel"}) { + ORT_ENFORCE(rank == static_cast(4), "Input tensor must have rank 4."); + dispatch_x = ceil_div(input_shape[0] * input_shape[1], 2); + dispatch_y = ceil_div(input_shape[2], 4); + dispatch_z = ceil_div(input_shape[3], 8); + } + program.SetDispatchGroupSize(dispatch_x, dispatch_y, dispatch_z); + } return context.RunProgram(program); } diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index f48b78c9adb91..0e4004db35b10 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -43,7 +43,7 @@ void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_confi // Create wgpu::Adapter wgpu::RequestAdapterOptions req_adapter_options = {}; req_adapter_options.backendType = static_cast(backend_type); - req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; + req_adapter_options.powerPreference = static_cast(power_preference_); #if !defined(__wasm__) auto enabled_adapter_toggles = GetEnabledAdapterToggles(); @@ -179,6 +179,11 @@ Status WebGpuContext::Wait(wgpu::Future f) { } Status WebGpuContext::Run(ComputeContext& context, ProgramBase& program) { + // Finalize program inputs by adding the indirect buffer as the last input if needed. 
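To make the Intel workaround relocated into transpose.cc above concrete, here is the dispatch arithmetic for the shape called out in its comment, (3, 3, 2560, 1280). This is a worked example, not additional patch code, and it assumes the 64-wide workgroup the removed conv.cc path used:

```cpp
#include <cstdint>

// Same helper as in transpose.cc above.
inline uint32_t ceil_div(int64_t numerator, int32_t denominator) {
  return static_cast<uint32_t>((numerator + denominator - 1) / denominator);
}

// output_size = 3 * 3 * 2560 * 1280 = 29,491,200 elements.
// Default 1-D dispatch:  ceil_div(29491200, 64) = 460,800 groups along x.
// Intel workaround, spread across three dimensions instead:
//   x = ceil_div(3 * 3, 2) = 5
//   y = ceil_div(2560, 4)  = 640
//   z = ceil_div(1280, 8)  = 160
```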
+ if (program.IndirectDispatchTensor() != nullptr) { + program.AddInput({program.IndirectDispatchTensor(), ProgramTensorMetadataDependency::None}); + } + const auto& inputs = program.Inputs(); const auto& outputs = program.Outputs(); @@ -950,7 +955,7 @@ WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& co auto it = contexts_.find(context_id); if (it == contexts_.end()) { GSL_SUPPRESS(r.11) - auto context = std::unique_ptr(new WebGpuContext(instance, device, config.validation_mode, config.preserve_device, config.small_storage_buffer_binding_size_for_testing)); + auto context = std::unique_ptr(new WebGpuContext(instance, device, config.validation_mode, config.preserve_device, config.small_storage_buffer_binding_size_for_testing, config.power_preference)); it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first; } else if (context_id != 0) { ORT_ENFORCE(it->second.context->instance_.Get() == instance && diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index e21a0e577311f..f1bebc0d52738 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -41,6 +41,7 @@ struct WebGpuContextConfig { ValidationMode validation_mode; bool preserve_device; bool small_storage_buffer_binding_size_for_testing; + int power_preference; }; struct WebGpuBufferCacheConfig { @@ -177,8 +178,8 @@ class WebGpuContext final { AtPasses }; - WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device, bool small_storage_buffer_binding_size_for_testing = false) - : instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device}, small_storage_buffer_binding_size_for_testing_{small_storage_buffer_binding_size_for_testing} {} + WebGpuContext(WGPUInstance instance, WGPUDevice device, webgpu::ValidationMode validation_mode, bool preserve_device, bool small_storage_buffer_binding_size_for_testing = false, int power_preference = static_cast(wgpu::PowerPreference::HighPerformance)) + : instance_{instance}, device_{device}, validation_mode_{validation_mode}, query_type_{TimestampQueryType::None}, preserve_device_{preserve_device}, small_storage_buffer_binding_size_for_testing_{small_storage_buffer_binding_size_for_testing}, power_preference_{power_preference} {} ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(WebGpuContext); void LaunchComputePipeline(const wgpu::ComputePassEncoder& compute_pass_encoder, @@ -267,6 +268,7 @@ class WebGpuContext final { bool is_profiling_ = false; bool preserve_device_; bool small_storage_buffer_binding_size_for_testing_; + int power_preference_; GraphCaptureState graph_capture_state_{GraphCaptureState::Default}; // External vector to store captured commands, owned by EP diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 0f7607ac1dbfe..135782ad577c4 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -246,7 +246,10 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, 
kOnnxDomain, 11, 12, Squeeze); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Squeeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Squeeze); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, Squeeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); @@ -529,7 +532,10 @@ std::unique_ptr RegisterKernels(bool enable_graph_capture = fals BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 60934bef574fa..efd3cf08364f9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -168,6 +168,20 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( } LOGS_DEFAULT(VERBOSE) << "WebGPU EP small storage buffer binding size for testing: " << small_storage_buffer_binding_size_for_testing; + // power preference + int power_preference = static_cast(WGPUPowerPreference_HighPerformance); // default + std::string power_preference_str; + if (config_options.TryGetConfigEntry(kPowerPreference, power_preference_str)) { + if (power_preference_str == kPowerPreference_HighPerformance) { + power_preference = static_cast(WGPUPowerPreference_HighPerformance); + } else if (power_preference_str == kPowerPreference_LowPower) { + power_preference = static_cast(WGPUPowerPreference_LowPower); + } else { + ORT_THROW("Invalid power preference: ", power_preference_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP power preference: " << power_preference; + webgpu::WebGpuContextConfig context_config{ context_id, reinterpret_cast(webgpu_instance), @@ -176,6 +190,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( validation_mode, preserve_device, small_storage_buffer_binding_size_for_testing, + power_preference, }; LOGS_DEFAULT(VERBOSE) << "WebGPU EP Device ID: " << context_id; @@ -184,6 +199,7 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( LOGS_DEFAULT(VERBOSE) << "WebGPU EP DawnProcTable: " << dawn_proc_table; LOGS_DEFAULT(VERBOSE) << "WebGPU EP ValidationMode: " << validation_mode; LOGS_DEFAULT(VERBOSE) << "WebGPU EP PreserveDevice: " << preserve_device; + LOGS_DEFAULT(VERBOSE) << "WebGPU EP PowerPreference: " << power_preference; // // STEP.3 - prepare parameters for WebGPU context initialization. 
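For reference, the new powerPreference entry is read from the session ConfigOptions like the other ep.webgpuexecutionprovider.* options, so an application opts into the low-power adapter with a single config entry; the key and value strings are the constants added to webgpu_provider_options.h just below. A hedged sketch using the C++ API, with EP registration itself left to whatever the application already does:

```cpp
#include <onnxruntime_cxx_api.h>

// Illustrative only: request the low-power adapter instead of the default
// high-performance one when building session options for the WebGPU EP.
Ort::SessionOptions MakeLowPowerWebGpuSessionOptions() {
  Ort::SessionOptions session_options;
  session_options.AddConfigEntry("ep.webgpuexecutionprovider.powerPreference", "low-power");
  // ... append the WebGPU execution provider as usual ...
  return session_options;
}
```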
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h index 761ff0d85fc98..a2ecd7f21f618 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_options.h +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_options.h @@ -15,6 +15,7 @@ constexpr const char* kEnableGraphCapture = "ep.webgpuexecutionprovider.enableGr constexpr const char* kDawnProcTable = "ep.webgpuexecutionprovider.dawnProcTable"; constexpr const char* kDawnBackendType = "ep.webgpuexecutionprovider.dawnBackendType"; +constexpr const char* kPowerPreference = "ep.webgpuexecutionprovider.powerPreference"; constexpr const char* kDeviceId = "ep.webgpuexecutionprovider.deviceId"; constexpr const char* kWebGpuInstance = "ep.webgpuexecutionprovider.webgpuInstance"; @@ -39,6 +40,9 @@ constexpr const char* kSmallStorageBufferBindingSizeForTesting = "ep.webgpuexecu constexpr const char* kDawnBackendType_D3D12 = "D3D12"; constexpr const char* kDawnBackendType_Vulkan = "Vulkan"; +constexpr const char* kPowerPreference_HighPerformance = "high-performance"; +constexpr const char* kPowerPreference_LowPower = "low-power"; + constexpr const char* kPreferredLayout_NCHW = "NCHW"; constexpr const char* kPreferredLayout_NHWC = "NHWC"; diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index 598eeea8d2e49..9056ac07cc286 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -75,6 +75,24 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision fp16 --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk ``` +Export + Quantize for INT8 CUDA +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cuda --use_gpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk +``` + +Export + Quantize for INT8 CPU +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --precision int8 --quantize_symmetric --provider cpu --use_external_data_format --optimize_onnx --no_beam_search_op --output_cross_qk +``` + ## Exporting Whisper with Beam Search There are several ways to export Whisper with beam search. 
@@ -143,13 +161,22 @@ $ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --o $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --optimize_onnx --precision fp16 --use_gpu --provider cuda ``` -Export + Quantize for INT8 +Export + Quantize for INT8 CUDA +``` +# From source: +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --use_gpu --provider cuda +``` + +Export + Quantize for INT8 CPU ``` # From source: -$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_embedding_layer +$ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/whisper-large-v3-turbo --output whisper-turbo --use_external_data_format --precision int8 --quantize_symmetric --provider cpu ``` Note: INT8 CPU is not compatible with `--output_cross_qk`. diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index a111db1edc257..88fdad01baf92 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -449,7 +449,7 @@ def parse_args(): type=str, required=True, default="fp32", - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision for model. 
For ONNX models, the model's precision should be set before running this script.", ) @@ -579,7 +579,7 @@ def main(): config = WhisperConfig.from_pretrained(args.model_name) processor = WhisperProcessor.from_pretrained(args.model_name) target_device = f"cuda:{args.device_id}" if args.device != "cpu" else args.device - use_fp16 = args.precision == "fp16" + use_fp16 = args.precision == "fp16" or (args.precision in {"int8", "int4"} and args.device != "cpu") setattr(args, "processor", processor) # noqa: B010 setattr(args, "target_device", target_device) # noqa: B010 diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index d2eb0d5259254..95d4b60fead99 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -97,7 +97,7 @@ def get_args(): "--precision", type=str, required=True, - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision to run model", ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py index e092285d57358..38fbd73e9c119 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/whisper/convert_to_onnx.py @@ -8,13 +8,18 @@ import logging import os +import onnx import torch from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger from whisper_chain import chain_model from whisper_encoder import WhisperEncoder from whisper_helper import PRETRAINED_WHISPER_MODELS, WhisperHelper -from onnxruntime import quantization +from onnxruntime.quantization.matmul_nbits_quantizer import ( + KQuantWeightOnlyQuantConfig, + MatMulNBitsQuantizer, + QuantFormat, +) logger = logging.getLogger("") @@ -94,8 +99,8 @@ def parse_arguments(argv=None): required=False, type=Precision, default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8], - help="Precision of model to run. fp32 for full precision, fp16 for half precision, int8 for quantization", + choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4], + help="Precision of model to run. 
fp32 for full precision, fp16 for half precision, int8/int4 for quantization", ) conversion_args.add_argument( @@ -289,28 +294,20 @@ def parse_arguments(argv=None): ################################### quant_args.add_argument( - "--quantize_embedding_layer", - required=False, - action="store_true", - help="Quantize MatMul, GEMM, and Gather.", - ) - quant_args.set_defaults(quantize_embedding_layer=False) - - quant_args.add_argument( - "--quantize_per_channel", + "--accuracy_level", + default=0, required=False, - action="store_true", - help="Quantize weights per each channel.", + type=int, + help="Accuracy level of the 4-bit quantized MatMul computation.", ) - quant_args.set_defaults(quantize_per_channel=False) quant_args.add_argument( - "--quantize_reduce_range", + "--quantize_symmetric", required=False, action="store_true", - help="Quantize weights with 7 bits.", + help="Quantize weights symmetrically", ) - quant_args.set_defaults(quantize_reduce_range=False) + quant_args.set_defaults(quantize_symmetric=False) args = parser.parse_args(argv) @@ -323,6 +320,22 @@ def parse_arguments(argv=None): return args +# quant_method is reserved for mixed precision in future +def make_quant_algo_config(precision, quant_method: str, matmul_nodes=None): + customized_weight_config = {} + quant_algo_config = None + + # need to use k_quant for int8 + if precision == Precision.INT8: + for node_name in matmul_nodes: + customized_weight_config[node_name] = {"bits": 8} + quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) + else: + quant_algo_config = KQuantWeightOnlyQuantConfig(customized_weight_config=customized_weight_config) + + return quant_algo_config + + def export_onnx_models( model_name_or_path, model_impl, @@ -340,19 +353,21 @@ def export_onnx_models( output_qk: bool = False, overwrite: bool = False, use_int32_inputs: bool = True, - quantize_embedding_layer: bool = False, - quantize_per_channel: bool = False, - quantize_reduce_range: bool = False, + accuracy_level: int = 0, + quantize_symmetric: bool = False, provider: str = "cpu", ): device = torch.device("cuda" if use_gpu else "cpu") + if not use_gpu: + accuracy_level = 4 # change to 4 for CPU EP + use_fp16_inputs = precision == Precision.FLOAT16 or (precision in (Precision.INT8, Precision.INT4) and use_gpu) models = WhisperHelper.load_model( model_name_or_path, model_impl, cache_dir, device, - torch.float16 if precision == Precision.FLOAT16 else torch.float32, + torch.float16 if use_fp16_inputs else torch.float32, merge_encoder_and_decoder_init, no_beam_search_op, output_qk, @@ -384,7 +399,7 @@ def export_onnx_models( PROVIDERS[provider], verbose, use_external_data_format, - use_fp16_inputs=(precision == Precision.FLOAT16), + use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, use_encoder_hidden_states=(name == "decoder_init"), use_kv_cache_inputs=(name == "decoder"), @@ -430,27 +445,43 @@ def export_onnx_models( model.verify_onnx( onnx_path, PROVIDERS[provider], - use_fp16_inputs=(precision == Precision.FLOAT16), + use_fp16_inputs=use_fp16_inputs, ) else: model.verify_onnx( onnx_path, PROVIDERS[provider], - use_fp16_inputs=(precision == Precision.FLOAT16), + use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, ) - if precision == Precision.INT8: - quantization.quantize_dynamic( - onnx_path, + if precision in (Precision.INT8, Precision.INT4): + onnx_model = onnx.load(onnx_path, load_external_data=True) + matmul_nodes = [node.name for node in onnx_model.graph.node if 
node.op_type == "MatMul"] + quant_algo_config = make_quant_algo_config(precision, "k_quant", matmul_nodes) + + quant = MatMulNBitsQuantizer( + model=onnx_model, + block_size=32, + is_symmetric=quantize_symmetric, + accuracy_level=accuracy_level, + quant_format=QuantFormat.QOperator, + op_types_to_quantize=("MatMul",), + algo_config=quant_algo_config, + ) + quant.process() + if os.path.exists(output_path): + os.remove(output_path) + if os.path.exists(output_path + ".data"): + os.remove(output_path + ".data") + onnx.save_model( + quant.model.model, output_path, - op_types_to_quantize=( - ["MatMul", "Gemm", "Gather"] if quantize_embedding_layer else ["MatMul", "Gemm"] - ), - use_external_data_format=use_external_data_format, - per_channel=quantize_per_channel, - reduce_range=quantize_reduce_range, - extra_options={"MatMulConstBOnly": True}, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=os.path.basename(output_path) + ".data", + size_threshold=0, + convert_attribute=False, ) else: logger.info(f"Skip optimizing: existing ONNX model {onnx_path}") @@ -493,9 +524,8 @@ def main(argv=None): args.output_cross_qk, args.overwrite, not args.use_int64_inputs, - args.quantize_embedding_layer, - args.quantize_per_channel, - args.quantize_reduce_range, + args.accuracy_level, + args.quantize_symmetric, args.provider, ) diff --git a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt index 37fc72cd26e07..37b23d9daabf4 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/whisper/requirements.txt @@ -1,4 +1,4 @@ -torch>=2.7.0 +torch==2.7.0 transformers==4.52.3 openai-whisper==20240927 ffmpeg-python diff --git a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py index 365a69ee4ec67..c28fa06e13c76 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py +++ b/onnxruntime/python/tools/transformers/models/whisper/whisper_chain.py @@ -54,16 +54,19 @@ def chain_model(args): config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + use_fp16_inputs = args.precision == Precision.FLOAT16 or ( + args.precision in (Precision.INT8, Precision.INT4) and args.use_gpu + ) # Create inputs/outputs for WhisperBeamSearch op - temperature_name = "temperature_fp16" if args.precision == Precision.FLOAT16 else "temperature" + temperature_name = "temperature_fp16" if use_fp16_inputs else "temperature" beam_inputs = [ - "input_features_fp16" if args.precision == Precision.FLOAT16 else "input_features", + "input_features_fp16" if use_fp16_inputs else "input_features", "max_length", "min_length", "num_beams", "num_return_sequences", - "length_penalty_fp16" if args.precision == Precision.FLOAT16 else "length_penalty", - "repetition_penalty_fp16" if args.precision == Precision.FLOAT16 else "repetition_penalty", + "length_penalty_fp16" if use_fp16_inputs else "length_penalty", + "repetition_penalty_fp16" if use_fp16_inputs else "repetition_penalty", "vocab_mask" if args.use_vocab_mask else "", "prefix_vocab_mask" if args.use_prefix_vocab_mask else "", "", # attention mask @@ -74,8 +77,8 @@ def chain_model(args): temperature_name if args.use_temperature else "", ] - sequence_scores_name = 
"sequence_scores_fp16" if args.precision == Precision.FLOAT16 else "sequence_scores" - scores_name = "scores_fp16" if args.precision == Precision.FLOAT16 else "scores" + sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores" + scores_name = "scores_fp16" if use_fp16_inputs else "scores" beam_outputs = [ "sequences", sequence_scores_name if args.output_sequence_scores else "", @@ -85,7 +88,7 @@ def chain_model(args): ] graph_nodes = [] - if args.precision == Precision.FLOAT16: + if use_fp16_inputs: input_features_cast_node = helper.make_node( "Cast", inputs=["input_features"], diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index fa1725d9003d7..080730a489ecd 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -334,9 +334,13 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device ORT_THROW("Please provide the valid file path."); } } else if (key == "profiling_level") { - std::set supported_profiling_level = {"off", "basic", "detailed"}; + std::set supported_profiling_level = {"off", "basic", "detailed", "optrace"}; if (supported_profiling_level.find(value) == supported_profiling_level.end()) { - ORT_THROW("Supported profiling_level: off, basic, detailed"); + std::ostringstream str_stream; + std::copy(supported_profiling_level.begin(), supported_profiling_level.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Supported profiling_level: " + str); } } else if (key == "backend_type" || key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index c7b68b7f25a91..87ca6e32c82f9 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -1088,14 +1088,19 @@ static GetTestModelFn BuildCastAddTestCase() { }; } +void VerifyFileExistsAndIsNonEmpty(const std::string& filepath) { + std::ifstream csv_file(filepath, std::ifstream::binary); + ASSERT_TRUE(csv_file.good()); + + csv_file.seekg(0, csv_file.end); + size_t buffer_size = static_cast(csv_file.tellg()); + EXPECT_NE(0, buffer_size); +} + TEST_F(QnnHTPBackendTests, ProfilingTest) { onnxruntime::ProviderOptions provider_options; -#if defined(_WIN32) - provider_options["backend_path"] = "QnnHtp.dll"; -#else - provider_options["backend_path"] = "libQnnHtp.so"; -#endif + provider_options["backend_type"] = "htp"; provider_options["offload_graph_io_quantization"] = "0"; provider_options["enable_htp_fp16_precision"] = "1"; provider_options["profiling_level"] = "detailed"; @@ -1108,6 +1113,42 @@ TEST_F(QnnHTPBackendTests, ProfilingTest) { 13, ExpectedEPNodeAssignment::All, 0.008f); + + VerifyFileExistsAndIsNonEmpty(provider_options["profiling_file_path"]); + std::remove(provider_options["profiling_file_path"].c_str()); + +#if QNN_API_VERSION_MAJOR > 2 || \ + (QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 29)) + VerifyFileExistsAndIsNonEmpty("detailed_profile_qnn.log"); + std::remove("detailed_profile_qnn.log"); +#endif +} + +TEST_F(QnnHTPBackendTests, OptraceTest) { + onnxruntime::ProviderOptions provider_options; + + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + provider_options["enable_htp_fp16_precision"] = "1"; + provider_options["profiling_level"] = 
"optrace"; + provider_options["profiling_file_path"] = "optrace_profile.csv"; + + auto input_defs = {TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f), + TestInputDef({1, 2, 2, 2}, false, -10.0f, 10.0f)}; + RunQnnModelTest(BuildOpTestCase("Add", input_defs, {}, {}, kOnnxDomain), + provider_options, + 13, + ExpectedEPNodeAssignment::All, + 0.008f); + + VerifyFileExistsAndIsNonEmpty(provider_options["profiling_file_path"]); + std::remove(provider_options["profiling_file_path"].c_str()); + +#if QNN_API_VERSION_MAJOR > 2 || \ + (QNN_API_VERSION_MAJOR == 2 && (QNN_API_VERSION_MINOR >= 29)) + VerifyFileExistsAndIsNonEmpty("optrace_profile_qnn.log"); + std::remove("optrace_profile_qnn.log"); +#endif } TEST_F(QnnHTPBackendTests, CastAddQDQU8) { diff --git a/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc new file mode 100644 index 0000000000000..99fb604d55521 --- /dev/null +++ b/onnxruntime/test/providers/qnn/qnn_node_group/gelu_fusion_test.cc @@ -0,0 +1,368 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" +#include "test/providers/qnn/qnn_test_utils.h" +#include "test/unittest_util/qdq_test_utils.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) + +namespace { + +// Helper function to build GELU Pattern 1: Mul(0.5) before the main sequence +// Pattern 1: +// +-------Mul(0.5)---------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul ==> +// (B=1.4142...) (1) +GetTestModelFn BuildGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Create Mul(0.5) branch: input * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* mul_half_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, half_initializer}, {mul_half_output}); + + // Create main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Final Mul: (mul_half_output) * (add_output) + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {mul_half_output, add_output}, {output}); + }; +} + +// Helper function to build GELU Pattern 2: Mul(0.5) after the main sequence +// Pattern 2: +// +------------------------------------+ +// | | +// | v +// [root] --> Div -----> Erf --> Add --> Mul -->Mul ==> +// (B=1.4142...) 
(1) (0.5) +GetTestModelFn BuildGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input + NodeArg* input = MakeTestInput(builder, input_def); + + // Main branch: input / sqrt(2) + NodeArg* sqrt2_initializer = builder.MakeScalarInitializer(sqrt_2); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input, sqrt2_initializer}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + NodeArg* one_initializer = builder.MakeScalarInitializer(one); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_initializer}, {add_output}); + + // Mul with input: input * add_output + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input, add_output}, {mul_output}); + + // Final Mul with 0.5: mul_output * 0.5 + NodeArg* half_initializer = builder.MakeScalarInitializer(half); + NodeArg* output = builder.MakeOutput(); + builder.AddNode("Mul", {mul_output, half_initializer}, {output}); + }; +} + +// Helper function to build QDQ GELU Pattern 1 +template +GetTestQDQModelFn BuildQDQGeluPattern1TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input with QDQ + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); + + // Create Mul(0.5) branch: input * 0.5 + // Quantize half constant with DequantizeLinear node (quant_value=255, scale=half/255) + NodeArg* half_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_initializer_quant, half / 255.0f, 0, half_dq); + NodeArg* mul_half_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_qdq, half_dq}, {mul_half_output}); + + // Create main branch: input / sqrt(2) + // Quantize sqrt(2) constant with DequantizeLinear node (quant_value=255, scale=sqrt_2/255) + NodeArg* sqrt2_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_initializer_quant, sqrt_2 / 255.0f, 0, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_qdq, sqrt2_dq}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + // Quantize one constant with DequantizeLinear node (quant_value=255, scale=one/255) + NodeArg* one_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_initializer_quant, one / 255.0f, 0, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_dq}, {add_output}); + + // Final Mul: (mul_half_output) * (add_output) + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_half_output, add_output}, {mul_output}); + + // Add output QDQ + 
AddQDQNodePairWithOutputAsGraphOutput(builder, mul_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +// Helper function to build QDQ GELU Pattern 2 +template +GetTestQDQModelFn BuildQDQGeluPattern2TestCase(const TestInputDef& input_def) { + return [input_def](ModelTestBuilder& builder, std::vector>& output_qparams) -> void { + constexpr float sqrt_2 = 1.4142135381698608f; + constexpr float half = 0.5f; + constexpr float one = 1.0f; + + // Create input with QDQ + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); + + // Main branch: input / sqrt(2) + // Quantize sqrt(2) constant with DequantizeLinear node (quant_value=255, scale=sqrt_2/255) + NodeArg* sqrt2_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* sqrt2_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(sqrt2_initializer_quant, sqrt_2 / 255.0f, 0, sqrt2_dq); + NodeArg* div_output = builder.MakeIntermediate(); + builder.AddNode("Div", {input_qdq, sqrt2_dq}, {div_output}); + + // Erf + NodeArg* erf_output = builder.MakeIntermediate(); + builder.AddNode("Erf", {div_output}, {erf_output}); + + // Add 1.0 + // Quantize one constant with DequantizeLinear node (quant_value=255, scale=one/255) + NodeArg* one_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* one_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(one_initializer_quant, one / 255.0f, 0, one_dq); + NodeArg* add_output = builder.MakeIntermediate(); + builder.AddNode("Add", {erf_output, one_dq}, {add_output}); + + // Mul with input: input * add_output + NodeArg* mul_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {input_qdq, add_output}, {mul_output}); + + // Final Mul with 0.5 + // Quantize half constant with DequantizeLinear node (quant_value=255, scale=half/255) + NodeArg* half_initializer_quant = builder.MakeInitializer({}, {static_cast(255)}); + NodeArg* half_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(half_initializer_quant, half / 255.0f, 0, half_dq); + NodeArg* mul_final_output = builder.MakeIntermediate(); + builder.AddNode("Mul", {mul_output, half_dq}, {mul_final_output}); + + // Add output QDQ + AddQDQNodePairWithOutputAsGraphOutput(builder, mul_final_output, output_qparams[0].scale, + output_qparams[0].zero_point); + }; +} + +ProviderOptions GetProviderOptions() { + ProviderOptions provider_options; + provider_options["backend_type"] = "htp"; + provider_options["offload_graph_io_quantization"] = "0"; + return provider_options; +} + +} // namespace + +// Test GELU Pattern 1 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-4f); +} + +// Test GELU Pattern 2 with float32 model (for baseline comparison) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_Float32) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + 
/*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-4f); +} + +// Test GELU Pattern 1 with QDQ (uint8) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern1TestCase(input_def), + BuildQDQGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/QDQTolerance(0.005f)); +} + +// Test GELU Pattern 2 with QDQ (uint8) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_QDQ_U8) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern2TestCase(input_def), + BuildQDQGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/QDQTolerance(0.005f)); +} + +// Test GELU Pattern 1 with QDQ (uint16) +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_QDQ_U16) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern1TestCase(input_def), + BuildQDQGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/QDQTolerance(0.002f)); +} + +// Test GELU Pattern 2 with QDQ (uint16) +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_QDQ_U16) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 2, 3, 4}, false, -1.0f, 1.0f); + + TestQDQModelAccuracy(BuildGeluPattern2TestCase(input_def), + BuildQDQGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*tolerance=*/QDQTolerance(0.002f)); +} + +// Test GELU Pattern 1 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -2.0f, 2.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-4f); +} + +// Test GELU Pattern 2 with larger input shape +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_LargeInput) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 128, 768}, false, -2.0f, 2.0f); + + RunQnnModelTest(BuildGeluPattern2TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-4f); +} + +// Test GELU Pattern 1 with different input ranges +TEST_F(QnnHTPBackendTests, GeluFusionPattern1_DifferentRange) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -3.0f, 3.0f); + + RunQnnModelTest(BuildGeluPattern1TestCase(input_def), + provider_options, + /*opset_version=*/13, + /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All, + /*fp32_abs_err=*/1e-4f); +} + +// Test GELU Pattern 2 with different input ranges +TEST_F(QnnHTPBackendTests, GeluFusionPattern2_DifferentRange) { + ProviderOptions provider_options = GetProviderOptions(); + auto input_def = TestInputDef({1, 16, 32}, false, -3.0f, 3.0f); + + 
RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-4f);
+}
+
+// Test GELU Pattern 1 with 2D input (typical for linear layers)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern1_2D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({32, 512}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern1TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-4f);
+}
+
+// Test GELU Pattern 2 with 2D input (typical for linear layers)
+TEST_F(QnnHTPBackendTests, GeluFusionPattern2_2D) {
+  ProviderOptions provider_options = GetProviderOptions();
+  auto input_def = TestInputDef<float>({32, 512}, false, -1.5f, 1.5f);
+
+  RunQnnModelTest(BuildGeluPattern2TestCase(input_def),
+                  provider_options,
+                  /*opset_version=*/13,
+                  /*expected_ep_assignment=*/ExpectedEPNodeAssignment::All,
+                  /*fp32_abs_err=*/1e-4f);
+}
+
+#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
+
+}  // namespace test
+}  // namespace onnxruntime
+
+#endif  // !defined(ORT_MINIMAL_BUILD)
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index a53ddbf500ffa..f76a6e036c661 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -8,6 +8,7 @@
 import os
 import shutil
 import unittest
+from importlib.util import find_spec
 
 import onnx
 import pytest
@@ -20,12 +21,16 @@
     from benchmark_helper import Precision
     from convert_generation import main as run
     from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
-    from models.whisper.convert_to_onnx import main as run_whisper
+
+    if not find_spec("onnxruntime.training"):
+        from models.whisper.convert_to_onnx import main as run_whisper
 else:
     from onnxruntime.transformers.benchmark_helper import Precision
     from onnxruntime.transformers.convert_generation import main as run
     from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
-    from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
+
+    if not find_spec("onnxruntime.training"):
+        from onnxruntime.transformers.models.whisper.convert_to_onnx import main as run_whisper
 
 
 def has_cuda_environment():
@@ -464,7 +469,7 @@ def setUp(self):
         self.int8_cpu_arguments = [
             "--precision",
             "int8",
-            "--quantize_embedding_layer",
+            "--quantize_symmetric",
         ]
 
     def tearDown(self):
@@ -509,21 +514,33 @@ def run_configs(self, optional_arguments):
         if "--model_impl" not in arguments:
             self.run_export(arguments)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_2bits"
+    )
     @pytest.mark.slow
     def test_required_args(self):
         optional_args = []
         self.run_configs(optional_args)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_2bits"
+    )
     @pytest.mark.slow
     def test_forced_decoder_ids(self):
         decoder_input_ids = ["--use_forced_decoder_ids"]
         self.run_configs(decoder_input_ids)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_2bits"
+    )
     @pytest.mark.slow
     def test_logits_processor(self):
         logits_processor = ["--use_logits_processor"]
         self.run_configs(logits_processor)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_2bits"
+    )
     @pytest.mark.slow
     def test_cross_qk_overall(self):
         cross_qk_input_args = [
@@ -540,6 +557,9 @@ def test_cross_qk_overall(self):
         ]
         self.run_configs(cross_qk_input_args + cross_qk_output_args)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't have quantize_matmul_2bits"
+    )
     @pytest.mark.slow
     def test_openai_impl_whisper(self):
         optional_args = ["--model_impl", "openai"]
diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py
index 591be538ac873..9447dcc2d0baf 100644
--- a/tools/ci_build/build.py
+++ b/tools/ci_build/build.py
@@ -505,6 +505,7 @@ def generate_build_tree(
         "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING=" + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"),
         "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"),
+        "-Donnxruntime_ENABLE_WEBASSEMBLY_JSPI=" + ("ON" if args.enable_wasm_jspi else "OFF"),
         "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"),
         "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"),
         "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"),
@@ -621,6 +622,7 @@ def generate_build_tree(
             build_dir,
             configs,
             emscripten_root_path,
+            args.enable_wasm_jspi,
             not args.disable_rtti,
             not args.disable_wasm_exception_catching,
             args.minimal_build is not None,
@@ -866,6 +868,12 @@ def generate_build_tree(
     # if args.use_jsep and args.use_webgpu:
     #     raise BuildError("JSEP (--use_jsep) and WebGPU (--use_webgpu) cannot be enabled at the same time.")
 
+    if args.enable_wasm_jspi:
+        if args.use_jsep:
+            raise BuildError("JSEP (--use_jsep) and WASM JSPI (--enable_wasm_jspi) cannot be enabled at the same time.")
+        if args.disable_wasm_exception_catching:
+            raise BuildError("WebAssembly exception catching cannot be disabled in a JSPI build.")
+
     if not args.use_webgpu:
         if args.use_external_dawn:
             raise BuildError("External Dawn (--use_external_dawn) must be enabled with WebGPU (--use_webgpu).")
diff --git a/tools/ci_build/build_args.py b/tools/ci_build/build_args.py
index 05d5052067b2e..6763973406294 100644
--- a/tools/ci_build/build_args.py
+++ b/tools/ci_build/build_args.py
@@ -360,6 +360,9 @@ def add_webassembly_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly.")
     parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build WebAssembly static library.")
     parser.add_argument("--emsdk_version", default="4.0.11", help="Specify version of emsdk.")
+    parser.add_argument(
+        "--enable_wasm_jspi", action="store_true", help="Enable WebAssembly JavaScript Promise Integration."
+ ) parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD.") parser.add_argument("--enable_wasm_relaxed_simd", action="store_true", help="Enable WebAssembly Relaxed SIMD.") parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threading.") diff --git a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml index e54216fe4ef4e..64f8146b25fe4 100644 --- a/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/build-perf-test-binaries-pipeline.yml @@ -32,4 +32,4 @@ stages: extra_build_arg: '' cmake_build_type: Release cuda_version: 12.8 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20250714.2 \ No newline at end of file + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 399b44a7f3cb2..b4e6040731d5f 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -57,6 +57,13 @@ parameters: type: string default: 2.39.0.250926 +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.8' + values: + - 12.8 + resources: repositories: - repository: onnxruntime-inference-examples # The name used to reference this repository in the checkout step @@ -70,11 +77,6 @@ resources: variables: - template: templates/common-variables.yml -- name: ReleaseVersionSuffix - value: '' -- name: win_trt_version - value: 12.8 - - name: win_trt_home value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda12 }} - name: win_cuda_home @@ -142,7 +144,7 @@ extends: - template: stages/nuget-combine-cuda-stage.yml parameters: - CudaVersion: 12.8 + CudaVersion: ${{ parameters.CudaVersion }} RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} UseIncreasedTimeoutForTests: ${{ parameters.UseIncreasedTimeoutForTests }} win_trt_home: ${{ variables.win_trt_home }} diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index d7fc0efbf45ea..8390295388c6d 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -51,19 +51,20 @@ parameters: default: '12.8' values: - 12.8 + - 13.0 variables: - template: templates/common-variables.yml - name: ReleaseVersionSuffix value: '' - name: win_trt_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: $(Agent.TempDirectory)\${{ variables.win_trt_folder_cuda12 }} - name: win_cuda_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: $(Agent.TempDirectory)\v13.0 ${{ if eq(parameters.CudaVersion, '12.8') }}: value: $(Agent.TempDirectory)\v12.8 diff --git 
a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml index eff2b4d885721..fe0f2427d31d9 100644 --- a/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/custom-nuget-packaging-pipeline.yml @@ -38,8 +38,8 @@ variables: - name: ReleaseVersionSuffix value: '' - name: win_cuda_home - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: $(Agent.TempDirectory)\v11.8 + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: $(Agent.TempDirectory)\v13.0 ${{ if eq(parameters.CudaVersion, '12.8') }}: value: $(Agent.TempDirectory)\v12.8 diff --git a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml index 463c02203e21a..9d831df54096a 100644 --- a/tools/ci_build/github/azure-pipelines/jar_package_testing.yml +++ b/tools/ci_build/github/azure-pipelines/jar_package_testing.yml @@ -45,7 +45,7 @@ stages: DownloadTRT: true - template: templates/setup-maven.yml - + - task: Maven@4 displayName: 'Download Java Dependencies' inputs: @@ -105,7 +105,7 @@ stages: - name: runCodesignValidationInjection value: false - name: docker_base_image - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20251008.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 timeoutInMinutes: 60 steps: - checkout: self diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index 5e6671e3797ce..b36df5748f3b3 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -34,20 +34,22 @@ parameters: default: '12.8' values: - 12.8 + - 13.0 variables: - template: templates/common-variables.yml - name: docker_base_image - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20251008.2 + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251017.1 ${{ if eq(parameters.CudaVersion, '12.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20251008.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} + jobs: - job: Linux_Build timeoutInMinutes: 180 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index b60ef7576184e..b26c96892952b 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -9,7 +9,7 @@ parameters: SpecificArtifact: false CustomOpArtifactName: 'onnxruntime-linux-x64' BuildId: '0' - 
CudaVersion: '11.8' + CudaVersion: '12.8' stages: - stage: NuGet_Test_Linux_${{ parameters.StageSuffix }}${{ parameters.MoreSuffix }} dependsOn: @@ -41,7 +41,7 @@ stages: - script: | mv $(Pipeline.Workspace)/build/drop-signed-nuget-${{ parameters.ArtifactSuffix }} $(Build.BinariesDirectory)/nuget-artifact mv $(Pipeline.Workspace)/build/${{ parameters.CustomOpArtifactName }} $(Build.BinariesDirectory)/testdata - + - template: get-nuget-package-version-as-variable.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index fdfafd4d9a179..71f8dd567793c 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -5,14 +5,20 @@ parameters: default: '12.8' values: - 12.8 + - 13.0 variables: - template: templates/common-variables.yml - name: win_trt_folder - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.win_trt_folder_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.win_trt_folder_cuda12 }} + - name: setup_env_script + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: 'setup_env_cuda13.bat' + ${{ if eq(parameters.CudaVersion, '12.8') }}: + value: 'setup_env_cuda12.bat' stages: - template: templates/web-ci.yml @@ -219,8 +225,9 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat + EnvSetupScript: '${{ variables.setup_env_script }}' buildArch: x64 + CudaVersion: ${{ parameters.CudaVersion }} additionalBuildFlags: --build_wheel --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --enable_cuda_profiling --enable_transformers_tool_test --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 @@ -242,8 +249,9 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_trt.bat + EnvSetupScript: '${{ variables.setup_env_script }}' buildArch: x64 + CudaVersion: ${{ parameters.CudaVersion }} additionalBuildFlags: --config RelWithDebInfo --parallel --use_binskim_compliant_compile_flags --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_wheel --enable_onnx_tests --use_tensorrt --tensorrt_home="$(Agent.TempDirectory)\${{ variables.win_trt_folder }}" --cuda_home="$(Agent.TempDirectory)\v${{ parameters.CudaVersion }}" --use_vcpkg --use_vcpkg_ms_internal_asset_cache --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 isX86: false diff --git a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml index 02b6a6df76611..8beae99218867 100644 --- a/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-cuda-package-test-pipeline.yml @@ -18,7 +18,7 @@ stages: machine_pool: 'Onnxruntime-Linux-GPU' python_wheel_suffix: '_gpu' timeout: 480 - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20251008.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 
cuda_version: '12.8' - stage: Republish_Wheels diff --git a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml index b53aee639372d..dde00c7a36852 100644 --- a/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml +++ b/tools/ci_build/github/azure-pipelines/stages/jobs/py-linux-cuda-package-test-job.yml @@ -43,13 +43,13 @@ jobs: variables: - template: ../../templates/common-variables.yml - name: docker_base_image - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda11_x64_almalinux8_gcc11:20251008.2 + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda13_x64_almalinux8_gcc14:20251017.1 ${{ if eq(parameters.CudaVersion, '12.8') }}: - value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20251008.2 + value: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} pool: ${{ parameters.machine_pool }} diff --git a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml index 8cbb81ba89c12..ff35d3e35ef6c 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nodejs-linux-packaging-stage.yml @@ -18,15 +18,15 @@ stages: variables: - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: '13' ${{ if eq(parameters.CudaVersion, '12.8') }}: value: '12' - name: CUDA_VERSION value: ${{ parameters.CudaVersion }} - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml index b1e5f541b90e0..1ab7155d8abc9 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-linux-cuda-packaging-stage.yml @@ -20,8 +20,8 @@ stages: os: linux variables: - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: '13' ${{ if eq(parameters.CudaVersion, '12.8') }}: value: '12' - name: CUDA_VERSION @@ -72,15 +72,15 @@ stages: variables: - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: '13' ${{ if eq(parameters.CudaVersion, '12.8') }}: value: 
'12' - name: CUDA_VERSION value: ${{ parameters.CudaVersion }} - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} steps: @@ -138,13 +138,13 @@ stages: variables: - template: ../templates/common-variables.yml - name: CUDA_VERSION_MAJOR - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: '11' + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: '13' ${{ if eq(parameters.CudaVersion, '12.8') }}: value: '12' - name: linux_trt_version - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index e7e541205ba0a..a948a3e6aff5a 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -13,7 +13,7 @@ parameters: - name: CudaVersion type: string - default: '11.8' + default: '13.0' - name: win_cuda_home type: string diff --git a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml index 3c5cf591039e0..755c6e0e88bd6 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-gpu-packaging-stage.yml @@ -48,4 +48,4 @@ stages: extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} cuda_version: ${{ parameters.cuda_version }} - docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12:20251008.2 + docker_base_image: onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14:20251017.1 diff --git a/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml index ab1fb919af413..f5b2c05931808 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-linux-gpu-stage.yml @@ -53,8 +53,8 @@ stages: value: '' - template: ../templates/common-variables.yml - name: trt_version - ${{ if eq(parameters.cuda_version, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 }} ${{ if eq(parameters.cuda_version, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} steps: @@ -81,16 +81,14 @@ stages: - script: | set -e -x - mv $(Build.BinariesDirectory)/${{ parameters.cmake_build_type }} ./${{ parameters.cmake_build_type }} + mv $(Build.BinariesDirectory)/${{ parameters.cmake_build_type }} ./${{ parameters.cmake_build_type }} mv $(Build.BinariesDirectory)/dist ./dist pushd dist find . -name \*.whl -exec unzip -qq -o {} \; rm -r onnxruntime popd - pushd ${{ parameters.cmake_build_type }} + pushd ${{ parameters.cmake_build_type }} find . 
-name \*.whl -exec unzip -qq -o {} \; popd workingDirectory: '$(Build.ArtifactStagingDirectory)' displayName: 'Move files' - - diff --git a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml index c3957fc8341de..0c163f74768ca 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-win-gpu-stage.yml @@ -74,8 +74,8 @@ stages: - name: CUDA_MODULE_LOADING value: 'LAZY' - name: win_trt_folder - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.win_trt_folder_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.win_trt_folder_cuda12 }} - name: trt_build_flag @@ -119,7 +119,7 @@ stages: --cmake_generator "$(VSGenerator)" --enable_pybind --enable_onnx_tests - --parallel 8 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --update --build + --parallel 8 --use_vcpkg --use_vcpkg_ms_internal_asset_cache --use_binskim_compliant_compile_flags --update --build $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} ${{ variables.trt_build_flag }} workingDirectory: '$(Build.BinariesDirectory)' @@ -213,7 +213,7 @@ stages: TMPDIR: "$(Agent.TempDirectory)" - powershell: | - $ErrorActionPreference = "Stop" + $ErrorActionPreference = "Stop" python -m pip uninstall -y onnxruntime onnxruntime-${{ parameters.EP_NAME }} -qq dir $(Build.ArtifactStagingDirectory) python -m pip --disable-pip-version-check install --no-index --find-links $(Build.ArtifactStagingDirectory) onnxruntime-${{ parameters.EP_NAME }} diff --git a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml index 39a958e848784..a0ac2cfad1a93 100644 --- a/tools/ci_build/github/azure-pipelines/templates/common-variables.yml +++ b/tools/ci_build/github/azure-pipelines/templates/common-variables.yml @@ -1,7 +1,8 @@ variables: - common_trt_version: '10.9.0.34' + cuda12_trt_version: '10.9.0.34' + cuda13_trt_version: '10.13.3.9' # As for Debian installation, replace '-1.' 
by '-1+' when assigning trt version below - linux_trt_version_cuda11: ${{ variables.common_trt_version }}-1.cuda11.8 - linux_trt_version_cuda12: ${{ variables.common_trt_version }}-1.cuda12.8 - win_trt_folder_cuda11: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-11.8 - win_trt_folder_cuda12: TensorRT-${{ variables.common_trt_version }}.Windows10.x86_64.cuda-12.8 \ No newline at end of file + linux_trt_version_cuda13: ${{ variables.cuda13_trt_version }}-1.cuda13.0 + linux_trt_version_cuda12: ${{ variables.cuda12_trt_version }}-1.cuda12.8 + win_trt_folder_cuda13: TensorRT-${{ variables.cuda13_trt_version }}.Windows10.win10.cuda-13.0 + win_trt_folder_cuda12: TensorRT-${{ variables.cuda12_trt_version }}.Windows10.x86_64.cuda-12.8 diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml index be213337091e8..631f40bcdd22a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_gpu_library.yml @@ -9,14 +9,14 @@ parameters: type: string default: '12.8' values: - - 11.8 + - 13.0 - 12.8 - name: TrtVersion type: string default: '10.9.0.34' values: - - 8.6.1.6 - 10.9.0.34 + - 10.13.3.9 steps: - ${{ if eq(parameters.DownloadCUDA, true) }}: @@ -42,37 +42,48 @@ steps: displayName: 'Print PATH after download CUDA SDK' - ${{ if eq(parameters.DownloadTRT, true) }}: - - ${{ if eq(parameters.CudaVersion, '11.8') }}: + - ${{ if eq(parameters.CudaVersion, '13.0') }}: - powershell: | - Write-Host "##vso[task.setvariable variable=trtCudaVersion;]11.8" + Write-Host "##vso[task.setvariable variable=trtCudaVersion;]13.0" displayName: Set trtCudaVersion - - ${{ if and(eq(parameters.CudaVersion, '12.8'), eq(parameters.TrtVersion, '8.6.1.6')) }}: + - ${{ if and(eq(parameters.CudaVersion, '12.8'), eq(parameters.TrtVersion, '10.13.3.9')) }}: - powershell: | - Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.0" + Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.9" displayName: Set trtCudaVersion - ${{ if and(eq(parameters.CudaVersion, '12.8'), eq(parameters.TrtVersion, '10.9.0.34')) }}: - powershell: | Write-Host "##vso[task.setvariable variable=trtCudaVersion;]12.8" displayName: Set trtCudaVersion - - script: | - echo $(trtCudaVersion) && echo TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) - displayName: Get trtCudaVersion and Directory Name + - ${{ if eq(parameters.TrtVersion, '10.9.0.34') }}: + - powershell: | + Write-Host "##vso[task.setvariable variable=trtPlatformString;]Windows10.x86_64" + displayName: 'Set TRT platform string for 10.9.0.34' + - ${{ if eq(parameters.TrtVersion, '10.13.3.9') }}: + - powershell: | + Write-Host "##vso[task.setvariable variable=trtPlatformString;]Windows.win10" + displayName: 'Set TRT platform string for 10.13.3.9' + + - powershell: | + $trtDirName = "TensorRT-${{ parameters.TrtVersion }}.$(trtPlatformString).cuda-$(trtCudaVersion)" + Write-Host "TensorRT Directory Name: $trtDirName" + Write-Host "##vso[task.setvariable variable=trtDirName;]$trtDirName" + displayName: 'Construct TensorRT Directory Name' - task: AzureCLI@2 - displayName: 'Download TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)' + displayName: 'Download $(trtDirName)' inputs: azureSubscription: AIInfraBuildOnnxRuntimeOSS scriptType: 'batch' scriptLocation: 'inlineScript' 
inlineScript: | set AZCOPY_AUTO_LOGIN_TYPE=AZCLI - azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion) $(Agent.TempDirectory) + azcopy.exe cp --recursive https://lotusscus.blob.core.windows.net/models/local/$(trtDirName) $(Agent.TempDirectory) - powershell: | - Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\TensorRT-${{ parameters.TrtVersion }}.Windows10.x86_64.cuda-$(trtCudaVersion)\lib" - displayName: 'Append TensorRT-${{ parameters.TrtVersion }} Directory to PATH' + Write-Host "##vso[task.prependpath]$(Agent.TempDirectory)\$(trtDirName)\lib" + displayName: 'Append $(trtDirName) Directory to PATH' - task: CmdLine@2 inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index d7c940cda30f4..00370eedb8d6e 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -7,15 +7,15 @@ parameters: - name: DownloadTRT type: boolean default: false - - name: PrimaryCUDAVersion + - name: CudaVersion type: string default: '12.8' -# - name: SecondaryCUDAVersion -# type: string -# default: '11.8' -# - name: win_trt_folder_cuda11 -# type: string -# default: 'TensorRT-10.9.0.34.Windows10.x86_64.cuda-11.8' + values: + - 13.0 + - 12.8 + - name: win_trt_folder_cuda13 + type: string + default: 'TensorRT-10.13.3.9.Windows.win10.cuda-13.0' - name: win_trt_folder_cuda12 type: string default: 'TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8' @@ -23,18 +23,18 @@ parameters: steps: - ${{ if eq(parameters.DownloadCUDA, 'true') }}: - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.PrimaryCUDAVersion }}" $(Agent.TempDirectory) - displayName: 'Download Primary CUDA SDK v${{ parameters.PrimaryCUDAVersion }}' -# - powershell: | -# azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.SecondaryCUDAVersion }}" $(Agent.TempDirectory) -# displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.CudaVersion }}" $(Agent.TempDirectory) + displayName: 'Download Primary CUDA SDK v${{ parameters.CudaVersion }}' + - ${{ if eq(parameters.DownloadTRT, 'true') }}: - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda12 }}" $(Agent.TempDirectory) - displayName: 'Download ${{ parameters.win_trt_folder_cuda12 }}' -# - powershell: | -# azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda11 }}" $(Agent.TempDirectory) -# displayName: 'Download ${{ parameters.win_trt_folder_cuda11 }}' + - ${{ if eq(parameters.CudaVersion, '12.8') }}: + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda12 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ parameters.win_trt_folder_cuda12 }}' + - ${{ if eq(parameters.CudaVersion, '13.0') }}: + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/local/${{ parameters.win_trt_folder_cuda13 }}" $(Agent.TempDirectory) + displayName: 'Download ${{ parameters.win_trt_folder_cuda13 }}' - task: BatchScript@1 
displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 39da103a2285b..ecafd578e1b6d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -27,6 +27,13 @@ parameters: - name: Today type: string +- name: CudaVersion + type: string + default: '12.8' + values: + - 13.0 + - 12.8 + steps: - task: UsePythonVersion@0 inputs: @@ -69,6 +76,7 @@ steps: - template: set-winenv.yml parameters: + CudaVersion: ${{ parameters.CudaVersion }} EnvSetupScript: ${{parameters.EnvSetupScript}} DownloadCUDA: ${{parameters.DownloadCUDA}} DownloadTRT: ${{parameters.DownloadTRT}} @@ -94,4 +102,3 @@ steps: (C:\ProgramData\chocolatey\bin\cl.exe -?) -match 'Compiler Version' displayName: Install ccache and update PATH to use linked versions of gcc, cc, etc - diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 083381817818b..eb6492f779b94 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -55,6 +55,13 @@ parameters: type: number default: 0 +- name: CudaVersion + displayName: CUDA version + type: string + default: '12.8' + values: + - 12.8 + - 13.0 jobs: - job: build_${{ parameters.job_name_suffix }} @@ -85,6 +92,7 @@ jobs: - template: win-ci-prebuild-steps.yml parameters: + CudaVersion: ${{ parameters.CudaVersion }} EnvSetupScript: ${{parameters.EnvSetupScript}} ${{ if contains(parameters.additionalBuildFlags, 'use_cuda') }}: DownloadCUDA: true @@ -142,7 +150,7 @@ jobs: restoreSolution: '$(Build.SourcesDirectory)\packages.config' nugetConfigPath: '$(Build.SourcesDirectory)\tools\ci_build\github\azure-pipelines\nuget\nuget_config\nuget.config' restoreDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}' - + - ${{ if eq(parameters.RunOnnxRuntimeTests, true) }}: - powershell: | python.exe $(Build.SourcesDirectory)\tools\ci_build\build.py --config ${{ parameters.BuildConfig }} --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_csharp --parallel --use_binskim_compliant_compile_flags --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }} @@ -160,7 +168,7 @@ jobs: Remove-Item "$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}" -Include "*.obj" -Recurse displayName: 'Build' - - script: + - script: python tools\ValidateNativeDelegateAttributes.py displayName: 'Validate C# native delegates' workingDirectory: '$(Build.SourcesDirectory)\csharp' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 263f73a9e29b0..3244008eb281a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -44,8 +44,8 @@ jobs: - name: skipComponentGovernanceDetection value: true - name: trt_version - ${{ if eq(parameters.cuda_version, '11.8') }}: - value: ${{ variables.linux_trt_version_cuda11 }} + ${{ if eq(parameters.cuda_version, '13.0') }}: + value: ${{ variables.linux_trt_version_cuda13 
}} ${{ if eq(parameters.cuda_version, '12.8') }}: value: ${{ variables.linux_trt_version_cuda12 }} workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8018da41fbc2d..337218e300fcd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -56,7 +56,6 @@ stages: - stage: ReactNative_CI_iOS displayName: ReactNative_CI_iOS dependsOn: '${{parameters.InitialStageDependsOn}}' - variables: jobs: - job: ReactNative_CI_iOS_build pool: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index ca698123a04e7..2d4d05755eb24 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -75,9 +75,9 @@ parameters: - name: CudaVersion type: string - default: '11.8' + default: '12.8' values: - - 11.8 + - 13.0 - 12.8 - name: SpecificArtifact @@ -123,7 +123,7 @@ stages: - output: pipelineArtifact targetPath: $(Build.ArtifactStagingDirectory) artifactName: 'onnxruntime${{ parameters.artifact_name_suffix }}-win-${{ parameters.packageName }}' - + - ${{ if eq(parameters.buildJava, 'true') }}: - output: pipelineArtifact targetPath: $(Build.BinariesDirectory)\onnxruntime-java-win-${{ parameters.msbuildPlatform }} diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml index 8b320b0ceb4ac..b75611e023c25 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-doc-gen-ci-pipeline.yml @@ -34,9 +34,16 @@ parameters: type: string default: '12.8' values: - - 11.8 + - 13.0 - 12.8 +variables: + - name: setup_env_script + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: 'setup_env_cuda13.bat' + ${{ if eq(parameters.CudaVersion, '12.8') }}: + value: 'setup_env_cuda12.bat' + stages: - stage: kernelDocumentation dependsOn: [] @@ -44,8 +51,9 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda.bat + EnvSetupScript: '${{ variables.setup_env_script }}' buildArch: x64 + CudaVersion: ${{ parameters.CudaVersion }} # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags additionalBuildFlags: >- --gen_doc validate --skip_tests --build_wheel --use_dml --use_cuda diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-cuda-minimal-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-cuda-minimal-ci-pipeline.yml index 08953749f6527..459951893433e 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-cuda-minimal-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-tensorrt-cuda-minimal-ci-pipeline.yml @@ -34,21 +34,26 @@ parameters: default: '12.8' values: - 12.8 + - 13.0 variables: - template: templates/common-variables.yml - name: win_trt_folder - ${{ if eq(parameters.CudaVersion, '11.8') }}: - value: ${{ variables.win_trt_folder_cuda11 }} + ${{ if eq(parameters.CudaVersion, '13.0') }}: + value: ${{ variables.win_trt_folder_cuda13 }} ${{ if eq(parameters.CudaVersion, '12.8') }}: value: ${{ variables.win_trt_folder_cuda12 }} + - name: setup_env_script + ${{ if eq(parameters.CudaVersion, '13.0') }}: 
+ value: 'setup_env_cuda13.bat' + ${{ if eq(parameters.CudaVersion, '12.8') }}: + value: 'setup_env_cuda12.bat' jobs: - job: 'build' pool: 'onnxruntime-Win2022-GPU-A10' variables: MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - EnvSetupScript: setup_env_trt.bat skipComponentGovernanceDetection: true TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] timeoutInMinutes: 150 @@ -57,7 +62,8 @@ jobs: steps: - template: templates/jobs/win-ci-prebuild-steps.yml parameters: - EnvSetupScript: $(EnvSetupScript) + CudaVersion: ${{ parameters.CudaVersion }} + EnvSetupScript: '${{ variables.setup_env_script }}' DownloadCUDA: true DownloadTRT: true BuildArch: 'x64' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index a277286866e41..56c10d87366fb 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,4 +1,4 @@ -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251017.1 FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 489e4ce9f3913..6c70549bd3c2b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -1,8 +1,8 @@ -# The default ARGs are for cuda 11.8 with cudnn8, TensorRT is optional +# The default ARGs are for cuda 12.8 with cudnn9, TensorRT is optional # Please overwrite BASEIMAGE, TRT_VERSION and other arguments with # --docker-build-args ' --build-arg BASEIMAGE=other_base_image --build-arg TRT_VERSION=other_trt_version etc...'
# for other cuda version and TRT version -ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:12.8.1-cudnn-devel-ubi8 FROM $BASEIMAGE ARG TRT_VERSION=10.9.0.34-1.cuda12.8 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 5410bd64036ce..ac72d043eb182 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -1,4 +1,4 @@ -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251017.1 FROM $BASEIMAGE ARG ROCM_VERSION=6.2.3 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu index 07ad8e933baf0..dc6a355f32754 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_webgpu @@ -1,4 +1,4 @@ -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251017.1 FROM $BASEIMAGE ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-17 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 index 1933fd371d3bc..83b1e97096fee 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0 @@ -5,7 +5,7 @@ # Dockerfile to Test ONNX Runtime on UBI8 with TensorRT 10 and CUDA 12 by default # Build base image with required system packages -ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:12.8.1-cudnn-devel-ubi8 ARG TRT_VERSION=10.9.0.34-1.cuda12.8 FROM $BASEIMAGE AS base ARG TRT_VERSION diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch deleted file mode 100644 index 62562705c92b2..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda_tensorrt10_0_torch +++ /dev/null @@ -1,57 +0,0 @@ -# -------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. 
-# -------------------------------------------------------------- -# Dockerfile to Test ONNX Runtime on UBI8 with TensorRT 10.0 and CUDA 11.8 by default - -# Build base image with required system packages -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG TRT_VERSION=10.9.0.34-1.cuda11.8 -FROM $BASEIMAGE AS base -ARG TRT_VERSION -ENV PATH=/opt/python/cp310-cp310/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} - -RUN dnf install -y bash wget &&\ - dnf clean dbcache - -RUN pip3 install --upgrade pip -RUN pip3 install setuptools>=68.2.2 - -#Install TensorRT only if TRT_VERSION is not empty -RUN if [ -n "$TRT_VERSION" ]; then \ - echo "TRT_VERSION is $TRT_VERSION" && \ - dnf -y install \ - libnvinfer10-${TRT_VERSION} \ - libnvinfer-headers-devel-${TRT_VERSION} \ - libnvinfer-devel-${TRT_VERSION} \ - libnvinfer-lean10-${TRT_VERSION} \ - libnvonnxparsers10-${TRT_VERSION} \ - libnvonnxparsers-devel-${TRT_VERSION} \ - libnvinfer-dispatch10-${TRT_VERSION} \ - libnvinfer-plugin10-${TRT_VERSION} \ - libnvinfer-vc-plugin10-${TRT_VERSION} \ - libnvinfer-bin-${TRT_VERSION} \ - libnvinfer-plugin10-${TRT_VERSION} \ - libnvinfer-plugin-devel-${TRT_VERSION} \ - libnvinfer-vc-plugin-devel-${TRT_VERSION} \ - libnvinfer-lean-devel-${TRT_VERSION} \ - libnvinfer-dispatch-devel-${TRT_VERSION} \ - libnvinfer-headers-plugin-devel-${TRT_VERSION} && \ - dnf clean dbcache ; \ -else \ - echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ -fi - -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && /tmp/scripts/install_java.sh && rm -rf /tmp/scripts - -RUN python3 -m pip uninstall -y torch -RUN python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 - -# Build final image from base. 
-FROM base as final -ARG BUILD_USER=onnxruntimedev -ARG BUILD_UID=1000 -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 index 1d3575411a692..80432f2dcea66 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda12_tensorrt10 @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with TensorRT integration # Build base image with required system packages -FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS base +FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index 03f14732b70f8..511955cd767c5 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -5,7 +5,7 @@ # Dockerfile to run ONNXRuntime with TensorRT installed from provided binaries # Build base image with required system packages -FROM nvidia/cuda:12.8.0-cudnn-devel-ubuntu22.04 AS base +FROM nvidia/cuda:12.8.1-cudnn-devel-ubuntu22.04 AS base # The local directory into which to build and install CMAKE ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code @@ -27,7 +27,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/python3 python &&\ ln -s /usr/bin/pip3 pip; -RUN pip install --upgrade pip +RUN pip install --upgrade pip RUN pip install setuptools>=68.2.2 # Install TensorRT @@ -100,4 +100,4 @@ RUN /bin/sh build.sh ${PARSER_CONFIG} --parallel --build_shared_lib --cuda_home USER root # Intall ORT wheel -RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl \ No newline at end of file +RUN pip install ${ONNXRUNTIME_LOCAL_CODE_DIR}/onnxruntime/build/Linux/Release/dist/*.whl diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile index 79d99d08dcc4e..d39e39a5a429d 100644 --- a/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/aarch64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_aarch64_almalinux8_gcc14:20251017.1 FROM $BASEIMAGE ADD scripts /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile index 72d98206f9205..d2d5ae684e10e 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cpu/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
# This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14_dotnet:20251017.1 FROM $BASEIMAGE ENV LANG=en_US.UTF-8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile index 85f4a074e30bf..44fcdb33c2d51 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/default/cuda12/Dockerfile @@ -2,7 +2,7 @@ # Licensed under the MIT License. # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc12_dotnet:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cuda12_x64_almalinux8_gcc14_dotnet:20251017.1 FROM $BASEIMAGE ARG TRT_VERSION @@ -36,7 +36,7 @@ fi ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 -ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-12/root/usr/bin/g++ +ENV CUDAHOSTCXX=/opt/rh/gcc-toolset-14/root/usr/bin/g++ ADD scripts /tmp/scripts RUN sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/almalinux.repo ENV PATH=/usr/lib/jvm/msopenjdk-17/bin:$PATH diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile index 81ba47f397f91..8288a98ed2adc 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cpu/Dockerfile @@ -1,4 +1,4 @@ -ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251008.2 +ARG BASEIMAGE=onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251017.1 FROM $BASEIMAGE ADD scripts /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile index d87870db0bca8..c65febda1b33a 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/cuda/Dockerfile @@ -1,8 +1,8 @@ -# The default ARGs are for cuda 11.8 with cudnn8, TensorRT is optional +# The default ARGs are for cuda 12.8 with cudnn9, TensorRT is optional # Please overwrite BASEIMAGE, TRT_VERSION and other arguments with # --docker-build-args ' --build-arg BASEIMAGE=other_base_image --build-arg TRT_VERSION=other_trt_version etc...' 
# for other cuda version and TRT version -ARG BASEIMAGE=nvidia/cuda:12.5.1-cudnn-devel-ubi8 +ARG BASEIMAGE=nvidia/cuda:12.8.1-cudnn-devel-ubi8 FROM $BASEIMAGE ARG TRT_VERSION=10.9.0.34-1.cuda12.8 diff --git a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile index 5ad1023bfb5b2..44b14d31919b2 100644 --- a/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x86_64/python/openvino/Dockerfile @@ -1,5 +1,5 @@ # Use the specified UBI8 base image with GCC 14 -ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251008.2" +ARG BASEIMAGE="onnxruntimebuildcache.azurecr.io/internal/azureml/onnxruntime/build/cpu_x64_almalinux8_gcc14:20251017.1" FROM ${BASEIMAGE} ARG BUILD_UID=1000 diff --git a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh index b3acc4da57a4c..b0a7f4baaaff0 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_python_deps.sh @@ -3,21 +3,16 @@ set -e -x INSTALL_DEPS_TRAINING=false INSTALL_DEPS_DISTRIBUTED_SETUP=false -TARGET_ROCM=false -CU_VER="11.8" -TORCH_VERSION='2.0.0' USE_CONDA=false -while getopts p:h:d:v:tmurc parameter_Option +while getopts p:d:v:tmuc parameter_Option do case "${parameter_Option}" in p) PYTHON_VER=${OPTARG};; -h) TORCH_VERSION=${OPTARG};; d) DEVICE_TYPE=${OPTARG};; v) CU_VER=${OPTARG};; t) INSTALL_DEPS_TRAINING=true;; m) INSTALL_DEPS_DISTRIBUTED_SETUP=true;; -r) TARGET_ROCM=true;; c) USE_CONDA=true;; esac done diff --git a/tools/ci_build/github/windows/setup_env_cuda.bat b/tools/ci_build/github/windows/setup_env_cuda.bat deleted file mode 100644 index f095f58f9920e..0000000000000 --- a/tools/ci_build/github/windows/setup_env_cuda.bat +++ /dev/null @@ -1,17 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. - -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.8\ ( - set PATH=%AGENT_TEMPDIRECTORY%\v12.8\bin;%AGENT_TEMPDIRECTORY%\v12.8\extras\CUPTI\lib64;%PATH% -) else ( - set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\extras\CUPTI\lib64;%PATH% -) - -@REM The default version is still cuda v12.8, because set cuda v11.8 after it -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 -) else ( - set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64 -) - -set GRADLE_OPTS=-Dorg.gradle.daemon=false diff --git a/tools/ci_build/github/windows/setup_env_cuda12.bat b/tools/ci_build/github/windows/setup_env_cuda12.bat new file mode 100644 index 0000000000000..7a9f3181fb36f --- /dev/null +++ b/tools/ci_build/github/windows/setup_env_cuda12.bat @@ -0,0 +1,25 @@ +REM Copyright (c) Microsoft Corporation. All rights reserved. +REM Licensed under the MIT License. + +@REM --- Setup CUDA 12.8 --- +@REM Check if a local/agent-specific version exists +if exist "%AGENT_TEMPDIRECTORY%\v12.8\" ( + echo "Using CUDA 12.8 from AGENT_TEMPDIRECTORY." 
+ set "PATH=%AGENT_TEMPDIRECTORY%\v12.8\bin;%AGENT_TEMPDIRECTORY%\v12.8\extras\CUPTI\lib64;%PATH%" +) else ( + echo "Using system default CUDA 12.8." + set "PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\extras\CUPTI\lib64;%PATH%" +) + +@REM --- Setup TensorRT for CUDA 12.8 --- +set "TRT_12_8_PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib" +if exist "%TRT_12_8_PATH%\" ( + echo "Adding TensorRT 10.9.0 for CUDA 12.8 to PATH." + set "PATH=%TRT_12_8_PATH%;%PATH%" +) else ( + echo "Warning: TensorRT 10.9.0 directory not found at %TRT_12_8_PATH%" +) + + +set GRADLE_OPTS=-Dorg.gradle.daemon=false +set CUDA_MODULE_LOADING=LAZY diff --git a/tools/ci_build/github/windows/setup_env_cuda13.bat b/tools/ci_build/github/windows/setup_env_cuda13.bat new file mode 100644 index 0000000000000..63c33cc233d60 --- /dev/null +++ b/tools/ci_build/github/windows/setup_env_cuda13.bat @@ -0,0 +1,23 @@ +REM Copyright (c) Microsoft Corporation. All rights reserved. +REM Licensed under the MIT License. + +@REM --- Setup for CUDA 13.0 --- +if exist "%AGENT_TEMPDIRECTORY%\v13.0\" ( + echo "Using CUDA 13.0 from AGENT_TEMPDIRECTORY." + set "PATH=%AGENT_TEMPDIRECTORY%\v13.0\bin;%AGENT_TEMPDIRECTORY%\v13.0\extras\CUPTI\lib64;%PATH%" +) else ( + echo "Using system default CUDA 13.0." + set "PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v13.0\extras\CUPTI\lib64;%PATH%" +) + +@REM --- Setup TensorRT for CUDA 13.0 --- +set "TRT_13_0_PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.13.3.9.Windows.win10.cuda-13.0\lib" +if exist "%TRT_13_0_PATH%\" ( + echo "Adding TensorRT 10.13.3.9 for CUDA 13.0 to PATH." + set "PATH=%TRT_13_0_PATH%;%PATH%" +) else ( + echo "Warning: TensorRT 10.13.3.9 directory not found at %TRT_13_0_PATH%" +) + +set GRADLE_OPTS=-Dorg.gradle.daemon=false +set CUDA_MODULE_LOADING=LAZY diff --git a/tools/ci_build/github/windows/setup_env_gpu.bat b/tools/ci_build/github/windows/setup_env_gpu.bat deleted file mode 100644 index 115a19b6f3a01..0000000000000 --- a/tools/ci_build/github/windows/setup_env_gpu.bat +++ /dev/null @@ -1,21 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. 
- -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.8\ ( - set PATH=%AGENT_TEMPDIRECTORY%\v12.8\bin;%AGENT_TEMPDIRECTORY%\v12.8\extras\CUPTI\lib64;%PATH% -) else ( - set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\extras\CUPTI\lib64;%PATH% -) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib;%PATH% - -@REM The default version is still cuda v12.8, because set cuda v11.8 after it -set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\TensorRT-10.9.0.34.Windows10.x86_64.cuda-11.8\lib -if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v11.8\bin;%AGENT_TEMPDIRECTORY%\v11.8\extras\CUPTI\lib64 -) else ( - set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\\extras\CUPTI\lib64 -) - - -set GRADLE_OPTS=-Dorg.gradle.daemon=false -set CUDA_MODULE_LOADING=LAZY diff --git a/tools/ci_build/github/windows/setup_env_trt.bat b/tools/ci_build/github/windows/setup_env_trt.bat deleted file mode 100644 index 6110249a9cde6..0000000000000 --- a/tools/ci_build/github/windows/setup_env_trt.bat +++ /dev/null @@ -1,11 +0,0 @@ -REM Copyright (c) Microsoft Corporation. All rights reserved. -REM Licensed under the MIT License. - -if exist PATH=%AGENT_TEMPDIRECTORY%\v12.8\ ( - set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.8\bin;%AGENT_TEMPDIRECTORY%\v12.8\extras\CUPTI\lib64 -) else ( - set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.8\extras\CUPTI\lib64 -) -set PATH=%AGENT_TEMPDIRECTORY%\TensorRT-10.9.0.34.Windows10.x86_64.cuda-12.8\lib;%PATH% -set GRADLE_OPTS=-Dorg.gradle.daemon=false -set CUDA_MODULE_LOADING=LAZY diff --git a/tools/python/util/vcpkg_helpers.py b/tools/python/util/vcpkg_helpers.py index 5c336ace4c7fc..c2839dc909cb7 100644 --- a/tools/python/util/vcpkg_helpers.py +++ b/tools/python/util/vcpkg_helpers.py @@ -437,6 +437,7 @@ def generate_vcpkg_triplets_for_emscripten( configs: set[str], emscripten_root: str, # Parameters defining the specific build configuration + enable_jspi: bool, enable_rtti: bool, enable_wasm_exception_catching: bool, # Controls -sDISABLE_EXCEPTION_CATCHING=... enable_minimal_onnx_build: bool, # Controls ONNX port setting AND C++ exceptions (-fno-exceptions) @@ -451,18 +452,22 @@ def generate_vcpkg_triplets_for_emscripten( - If enable_minimal_onnx_build=True, C++ exceptions are disabled (-fno-exceptions). - If enable_minimal_onnx_build=False, C++ exceptions are assumed enabled (-fexceptions). - This supports three main effective EH scenarios depending on the combination of - 'enable_minimal_onnx_build' and 'enable_wasm_exception_catching': + This supports 4 main effective EH scenarios depending on the combination of + 'enable_minimal_onnx_build', 'enable_jspi' and 'enable_wasm_exception_catching': 1. No EH (-fno-exceptions, -sDISABLE_EXCEPTION_CATCHING=1): Set enable_minimal_onnx_build=True, enable_wasm_exception_catching=False 2. Full EH (-fexceptions, -sDISABLE_EXCEPTION_CATCHING=0): Set enable_minimal_onnx_build=False, enable_wasm_exception_catching=True 3. Throw Only EH (-fexceptions, -sDISABLE_EXCEPTION_CATCHING=1): Set enable_minimal_onnx_build=False, enable_wasm_exception_catching=False + 4. Use the new Wasm EH (-fwasm-exceptions -sWASM_LEGACY_EXCEPTIONS=0): + Set enable_minimal_onnx_build=False, enable_jspi=True Args: build_dir (str): The directory to save the generated triplet files. 
emscripten_root (str): The root path of Emscripten. + enable_jspi (bool): Flag indicating if JSPI is enabled. If JSPI is enabled, the new + Wasm EH will be used and enable_wasm_exception_catching is ignored. enable_rtti (bool): Flag indicating if RTTI is enabled for dependencies. enable_wasm_exception_catching (bool): Flag indicating if the Emscripten runtime exception catching mechanism should be enabled @@ -479,6 +484,12 @@ def generate_vcpkg_triplets_for_emscripten( # Derive C++ exception enablement from the minimal build flag cpp_exceptions_enabled = not enable_minimal_onnx_build + # When JSPI is enabled, use the new Wasm EH + if enable_jspi: + if enable_minimal_onnx_build: + # TODO: support minimal build with JSPI if needed + raise ValueError("Currently minimal build cannot be used with JSPI.") + for target_abi in ["wasm32", "wasm64"]: os_name = "emscripten" file_name = f"{target_abi}-{os_name}.cmake" @@ -522,7 +533,9 @@ def generate_vcpkg_triplets_for_emscripten( # Wasm Exception Catching Runtime (-s flag, apply to Base and Linker flags) exception_catching_flag = "" - if enable_wasm_exception_catching: + if enable_jspi: + exception_catching_flag = "-fwasm-exceptions -sWASM_LEGACY_EXCEPTIONS=0" + elif enable_wasm_exception_catching: exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=0" else: exception_catching_flag = "-sDISABLE_EXCEPTION_CATCHING=1"
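For reference, the exception-handling selection that the vcpkg_helpers.py change above introduces reduces to a single decision over the three booleans. The sketch below restates that decision as a standalone Python helper; the function name select_exception_catching_flag and the example call are illustrative only, while the flag strings, the JSPI guard, and the three parameters come directly from the change itself.

# A minimal sketch of the exception_catching_flag logic added above; the helper
# name and the example call are illustrative, not part of vcpkg_helpers.py.
def select_exception_catching_flag(enable_jspi: bool,
                                   enable_wasm_exception_catching: bool,
                                   enable_minimal_onnx_build: bool) -> str:
    """Return the Emscripten EH runtime flags for one triplet configuration."""
    if enable_jspi:
        if enable_minimal_onnx_build:
            # Mirrors the guard in the diff: minimal build + JSPI is rejected.
            raise ValueError("Currently minimal build cannot be used with JSPI.")
        # Scenario 4: the new Wasm EH; enable_wasm_exception_catching is ignored.
        return "-fwasm-exceptions -sWASM_LEGACY_EXCEPTIONS=0"
    if enable_wasm_exception_catching:
        # Scenario 2: full EH (the -fexceptions / -fno-exceptions choice is
        # derived separately from enable_minimal_onnx_build).
        return "-sDISABLE_EXCEPTION_CATCHING=0"
    # Scenarios 1 and 3: runtime catching disabled; whether this is "no EH" or
    # "throw-only EH" depends on whether C++ exceptions were enabled.
    return "-sDISABLE_EXCEPTION_CATCHING=1"

# Example: a JSPI-enabled, non-minimal build selects the new Wasm EH flags.
assert select_exception_catching_flag(True, False, False) == \
    "-fwasm-exceptions -sWASM_LEGACY_EXCEPTIONS=0"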