diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml deleted file mode 100644 index 1416f5a4d33a9..0000000000000 --- a/.github/workflows/sca.yml +++ /dev/null @@ -1,133 +0,0 @@ -name: Windows_SCA -on: - push: - branches: - - main - - rel-* - pull_request: - -concurrency: - group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -env: - AZCOPY_AUTO_LOGIN_TYPE: MSI - AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4 - -jobs: - Onnxruntime-SCA-training-CUDA: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Download cuda - run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk - - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA - - # No python - Onnxruntime-SCA-win32-WINML-x64: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x64' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X64 - - # No java, No python - Onnxruntime-SCA-win32-WINML-x86: - runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"] - steps: - - uses: actions/checkout@v3 - with: - submodules: false - - uses: actions/setup-python@v4 - with: - python-version: '3.11.x' - architecture: 'x86' - - - uses: actions/setup-node@v3 - with: - node-version: 18 - - - name: Delete build folder - run: | - if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b } - &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug - - # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter. - - name: Build code - env: - CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake' - run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib - - - name: Generate sarif - working-directory: D:\b - run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output - - - name: Upload SARIF to GitHub - uses: github/codeql-action/upload-sarif@v2 - continue-on-error: true - with: - sarif_file: ${{ github.workspace }}\output\MergeResult.sarif - category: VS_SCA_WIN32_WINML_X86 diff --git a/.gitmodules b/.gitmodules index 036a248070855..7bb49e98bfec1 100644 --- a/.gitmodules +++ b/.gitmodules @@ -8,6 +8,3 @@ path = cmake/external/emsdk url = https://github.com/emscripten-core/emsdk.git branch = 3.1.44 -[submodule "cmake/external/onnxruntime-extensions"] - path = cmake/external/onnxruntime-extensions - url = https://github.com/microsoft/onnxruntime-extensions.git diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt index 6f6faa3a2e56f..985eb645664c8 100644 --- a/ThirdPartyNotices.txt +++ b/ThirdPartyNotices.txt @@ -6230,3 +6230,37 @@ https://github.com/intel/neural-compressor terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "THIRD-PARTY-PROGRAMS" file. + +_____ + +FlashAttention, https://github.com/Dao-AILab/flash-attention + +BSD 3-Clause License + +Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/VERSION_NUMBER b/VERSION_NUMBER index 15b989e398fc7..4a02d2c3170bd 100644 --- a/VERSION_NUMBER +++ b/VERSION_NUMBER @@ -1 +1 @@ -1.16.0 +1.16.2 diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index b01ed00350bb0..82a454791d159 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -84,7 +84,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF) option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF) -option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) +cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF) +option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON) option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF) option(onnxruntime_USE_AVX "Use AVX instructions" OFF) @@ -666,13 +667,16 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_DISABLE_CONTRIB_OPS) set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6) message( STATUS "Turn off flash attention since CUDA compiler version < 11.6") set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() else() set(onnxruntime_USE_FLASH_ATTENTION OFF) + set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF) endif() if (onnxruntime_USE_CUDA) @@ -685,6 +689,11 @@ if (onnxruntime_USE_CUDA) list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1) list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1) endif() + if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) + message( STATUS "Enable memory efficient attention for CUDA EP") + list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1) + list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1) + endif() endif() if (onnxruntime_USE_VITISAI) diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake index 18ac668bb1592..8c5d81d638ced 100644 --- a/cmake/external/cutlass.cmake +++ b/cmake/external/cutlass.cmake @@ -1,4 +1,4 @@ -if (onnxruntime_USE_FLASH_ATTENTION) +if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) include(FetchContent) FetchContent_Declare( cutlass diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index c087ad8f6d81e..8e412c7847b70 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -46,8 +46,8 @@ if (onnxruntime_BUILD_UNIT_TESTS) FetchContent_Declare( googletest URL ${DEP_URL_googletest} + FIND_PACKAGE_ARGS 1.13.0...<2.0.0 NAMES GTest URL_HASH SHA1=${DEP_SHA1_googletest} - OVERRIDE_FIND_PACKAGE ) endif() @@ -528,4 +528,3 @@ endif() FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR) FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR) - diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 5adfc7ba03923..03360ff30c4c4 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -529,7 +529,7 @@ if (onnxruntime_USE_CUDA) target_link_libraries(${target} PRIVATE cuda) endif() - if (onnxruntime_USE_FLASH_ATTENTION) + if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION) include(cutlass) target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples) endif() diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index c8592a4019461..ecee52f642b1f 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -201,6 +201,10 @@ set(training_ops_excluded_files "reduction/reduction_ops.cc" # no double type support "cuda_training_kernels.cc" "cuda_training_kernels.h" + "nn/conv_shared.cc" + "nn/conv_shared.h" + "nn/conv_transpose_grad.cc" + "nn/conv_transpose_grad.h" ) function(auto_set_source_files_hip_language) diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs index b374371446a90..86b44a6784817 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs @@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem /// /// Creates an OrtValue that contains a string tensor of specified shape, and /// containing empty strings. String tensors are always on CPU. - /// Use FillStringTensorElement to assign individual elements values. + /// Use StringTensorSetElementAt to assign individual elements values. /// /// /// disposable OrtValue diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index c52ca4d1a4631..ac790242409e3 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -15,6 +15,7 @@ public struct OrtTrainingApi public IntPtr LoadCheckpoint; public IntPtr SaveCheckpoint; public IntPtr CreateTrainingSession; + public IntPtr CreateTrainingSessionFromBuffer; public IntPtr TrainingSessionGetTrainingModelOutputCount; public IntPtr TrainingSessionGetEvalModelOutputCount; public IntPtr TrainingSessionGetTrainingModelOutputName; diff --git a/docs/python/README.rst b/docs/python/README.rst index 7d978b0941235..bcf7c635afd82 100644 --- a/docs/python/README.rst +++ b/docs/python/README.rst @@ -8,6 +8,16 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime ((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 126; } else { val |= 0x7f; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { - if (e < 117) { // 0b1110101 - } else if (e < 118) { // 0b1110110 - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; + if (e < 117) { + } else if (e < 121) { + // denormalized number + auto d = 120 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; } - } else if (e < 121) { // 127 - 7 + 1 // 0b1111001 - auto d = 120 - e; // 0b1111000 - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { + auto mask = 1 << (20 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 136) { // 127 + 8 + 1 // 0b10001000 - auto ex = e - 120; // 127 - 7 + } else if (e < 136) { + // normalized number + auto ex = e - 120; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -83,7 +85,7 @@ struct Float8E4M3FN { val &= 0xFE; } } - if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7E) { // rounding val += 1; @@ -147,14 +149,22 @@ struct Float8E4M3FN { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } #if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 - explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { + val = *reinterpret_cast(&value); + } explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast(&val); } #endif }; -inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char @@ -164,9 +174,7 @@ inline Float8E4M3FN operator"" _f8e4m3fn(unsigned long long int v) { return Float8E4M3FN(narrow(v), Float8E4M3FN::FromBits()); } -inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { - return Float8E4M3FN(static_cast(v), true); -} +inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { return Float8E4M3FN(static_cast(v), true); } #endif @@ -205,36 +213,38 @@ struct Float8E4M3FNUZ { std::memcpy(&b, &v, sizeof(b)); val = static_cast((b & 0x80000000) >> 24); // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7fffffff) == 0x7f800000) { + if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { val |= 0x7F; } else { // infinity val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent uint32_t m = static_cast(b & 0x007FFFFF); // mantissa if (e != 0) { if (e < 116) { - } else if (e < 117) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 120) { // 127 - 8 + 1 + } else if (e < 120) { + // denormalized number auto d = 119 - e; - val |= 1 << (2 - d); - val |= m >> (21 + d); - if ((m >> (20 + d)) & 1) { + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 135) { // 127 + 8 - auto ex = e - 119; // 127 - 7 + } else if (e < 135) { + // normalized number + auto ex = e - 119; if (ex == 0) { val |= 0x4; val |= m >> 21; @@ -242,7 +252,7 @@ struct Float8E4M3FNUZ { val |= ex << 3; val |= m >> 20; } - if (m & 0x80000) { + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -303,9 +313,15 @@ struct Float8E4M3FNUZ { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } }; -inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char @@ -315,9 +331,7 @@ inline Float8E4M3FNUZ operator"" _f8e4m3p8fnuz(unsigned long long int v) { return Float8E4M3FNUZ(narrow(v), Float8E4M3FNUZ::FromBits()); } -inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { - return Float8E4M3FNUZ(static_cast(v), true); -} +inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { return Float8E4M3FNUZ(static_cast(v), true); } #endif @@ -357,32 +371,33 @@ struct Float8E5M2 { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val |= 0x7f; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7B; } else { val |= 0x7C; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 110) { - } else if (e < 111) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 113) { // 127 - 15 + 1 + } else if (e < 113) { + // denormalized number auto d = 112 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } @@ -461,8 +476,12 @@ struct Float8E5M2 { #endif }; -inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { + return left.val != right.val; +} inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; } // User defined suffixes to make it easier to declare @@ -473,9 +492,7 @@ inline Float8E5M2 operator"" _f8e5m2fn(unsigned long long int v) { return Float8E5M2(narrow(v), Float8E5M2::FromBits()); } -inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { - return Float8E5M2(static_cast(v), true); -} +inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { return Float8E5M2(static_cast(v), true); } #endif @@ -513,40 +530,42 @@ struct Float8E5M2FNUZ { uint32_t b; std::memcpy(&b, &v, sizeof(b)); - val = (b & 0x80000000) >> 24; // sign - if ((b & 0x7fc00000) == 0x7fc00000) { - val = 0x80; - } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { val |= 0x7F; } else { val = 0x80; } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; } else { uint32_t e = (b & 0x7F800000) >> 23; // exponent uint32_t m = b & 0x007FFFFF; // mantissa if (e != 0) { if (e < 109) { - } else if (e < 110) { - val |= 1; - if ((m >> 23) & 1) { - // rounding - val += 1; - } - } else if (e < 112) { // 127 - 16 + 1 + } else if (e < 112) { + // denormalized number auto d = 111 - e; - val |= 1 << (1 - d); - val |= m >> (22 + d); - if ((m >> (21 + d)) & 1) { + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && + ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { // rounding val += 1; } - } else if (e < 143) { // 127 + 15 + 1 + } else if (e < 143) { + // normalized number auto ex = e - 111; val |= ex << 2; val |= m >> 21; - if (m & 0x100000) { + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { if ((val & 0x7F) < 0x7F) { // rounding val += 1; @@ -554,7 +573,7 @@ struct Float8E5M2FNUZ { val = 0x80; } } - } else if ((e == 255) && (m == 0)) { // inf + } else if ((e == 255) && (m == 0)) { val = 0x80; } else if (saturate) { val |= 0x7F; @@ -605,9 +624,15 @@ struct Float8E5M2FNUZ { inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } }; -inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; } -inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; } -inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; } +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val == right.val; +} +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val != right.val; +} +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { + return left.val < right.val; +} // User defined suffixes to make it easier to declare // initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char @@ -617,9 +642,7 @@ inline Float8E5M2FNUZ operator"" _f8e5m2fnuz(unsigned long long int v) { return Float8E5M2FNUZ(narrow(v), Float8E5M2FNUZ::FromBits()); } -inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { - return Float8E5M2FNUZ(static_cast(v), true); -} +inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { return Float8E5M2FNUZ(static_cast(v), true); } #endif diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h index 48c4e4320dfd7..a071f3182faad 100644 --- a/include/onnxruntime/core/framework/ort_value.h +++ b/include/onnxruntime/core/framework/ort_value.h @@ -68,11 +68,7 @@ struct OrtValue { } bool IsSparseTensor() const { -#if !defined(DISABLE_SPARSE_TENSORS) return (type_ != nullptr && type_->IsSparseTensorType()); -#else - ORT_THROW("Sparse tensor is not supported in this build."); -#endif } onnxruntime::MLDataType Type() const { diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bc7792ba4366b..456a11603de65 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4333,8 +4333,12 @@ struct OrtApi { * \param[in] input_len Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr - * The array will be passed back to run_async_callback + * \param[out] output OrtValue* array of size output_names_len. + * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. + * Later, the output array will be passed to run_async_callback with all null(s) filled with valid + * OrtValue pointer(s) allocated by onnxruntime. + * NOTE: it is customer's duty to finally release the output array and each of its member, + * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. * \param[in] run_async_callback Callback function on model run completion * \param[in] user_data User data that pass back to run_async_callback */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b9b6676c0072d..47356c3fe3608 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1073,11 +1073,15 @@ struct SessionImpl : ConstSessionImpl { * * \param[in] run_options * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input_values Array of ::OrtValue%s of the input values + * \param[in] input_values Array of Value objects of length input_count * \param[in] input_count Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr - * The array will be passed back to the callback + * \param[out] output_values Array of provided Values to be filled with outputs. + * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. + * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. + * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. + * NOTE: it is customer's duty to finally release output_values and each of its member, + * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. * \param[in] output_count Number of elements in the output_names and outputs array * \param[in] callback Callback function on model run completion * \param[in] user_data User data that pass back to the callback diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts index 8f597765ebe8a..3e303bcf64b8e 100644 --- a/js/common/lib/version.ts +++ b/js/common/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/common/package-lock.json b/js/common/package-lock.json index b9e5fd6082457..69cb6b60aaf35 100644 --- a/js/common/package-lock.json +++ b/js/common/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/common/package.json b/js/common/package.json index 331f17dbc44be..06616c3247c07 100644 --- a/js/common/package.json +++ b/js/common/package.json @@ -2,7 +2,7 @@ "license": "MIT", "type": "module", "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "repository": { "url": "https://github.com/Microsoft/onnxruntime.git", "type": "git" diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts index 8f597765ebe8a..3e303bcf64b8e 100644 --- a/js/node/lib/version.ts +++ b/js/node/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/node/package-lock.json b/js/node/package-lock.json index bd01302262273..6994f70a45233 100644 --- a/js/node/package-lock.json +++ b/js/node/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-node", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-node", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "os": [ "win32", @@ -27,7 +27,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/node/package.json b/js/node/package.json index c898aeb56c0f5..faa07d1149fab 100644 --- a/js/node/package.json +++ b/js/node/package.json @@ -13,7 +13,7 @@ 3 ] }, - "version": "1.16.0", + "version": "1.16.2", "dependencies": { "onnxruntime-common": "file:../common" }, diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 78f32ec09250b..f8aeadbe27c56 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -68,7 +68,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) { int64_t bytesOffset = info[1].As().Int64Value(); int64_t bytesLength = info[2].As().Int64Value(); - ParseSessionOptions(info[1].As(), sessionOptions); + ParseSessionOptions(info[3].As(), sessionOptions); this->session_.reset( new Ort::Session(OrtEnv(), reinterpret_cast(buffer) + bytesOffset, bytesLength, sessionOptions)); } else { diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c466308a5..058531f415d61 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts index 8f597765ebe8a..3e303bcf64b8e 100644 --- a/js/react_native/lib/version.ts +++ b/js/react_native/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/react_native/package.json b/js/react_native/package.json index 3020a04f0af31..2c19037257051 100644 --- a/js/react_native/package.json +++ b/js/react_native/package.json @@ -36,7 +36,7 @@ "registry": "https://registry.npmjs.org/" }, "source": "lib/index", - "version": "1.16.0", + "version": "1.16.2", "main": "dist/commonjs/index", "homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md", "files": [ diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock index 21734bc50b000..ff2cfd2c8f98f 100644 --- a/js/react_native/yarn.lock +++ b/js/react_native/yarn.lock @@ -5188,7 +5188,7 @@ onetime@^5.1.0, onetime@^5.1.2: mimic-fn "^2.1.0" "onnxruntime-common@file:../common": - version "1.16.0" + version "1.16.2" open@^6.2.0: version "6.4.0" diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 4a1109b9ec5dc..e33854819c5db 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -38,7 +38,7 @@ Do not modify directly.* | Floor | ai.onnx(6-12,13+) | | | Gather | ai.onnx(1-10,11-12,13+) | | | Gelu | com.microsoft(1+) | | -| Gemm | ai.onnx(7-8,9-10,11+) | | +| Gemm | ai.onnx(7-8,9-10,11-12,13+) | | | GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts index 8f597765ebe8a..3e303bcf64b8e 100644 --- a/js/web/lib/version.ts +++ b/js/web/lib/version.ts @@ -4,4 +4,4 @@ // This file is generated by /js/scripts/update-version.ts // Do not modify file content manually. -export const version = '1.16.0'; +export const version = '1.16.2'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index 1d490aa9028ff..82fe3d5b6af43 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -26,43 +26,41 @@ import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, elementsPerThread: readonly number[]): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; const channelDim = isChannelsLast ? 3 : 1; const outputSize = ShapeUtil.size(outputShape); - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0; const workPerThread = isVec4 ? 2 : 1; + const group = attributes.group; + const wShape = inputs[1].dims; + const inputChannelsPerGroup = wShape[0] / group; + const outputChannelsPerGroup = wShape[1]; - const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0]; - - const declareInputs = [ - `@group(0) @binding(0) var Dy: array<${ - isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var W: array<${isVec4 ? 'vec4' : 'f32'}>;` - ]; let declareFunctions = ` fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); declareFunctions += ` fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } - const w = inputVariable('W', inputs[1].dataType, inputs[1].dims); - const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims); - const output = outputVariable('result', inputs[0].dataType, outputShape); + const components = isVec4 ? 4 : 1; + const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components); + const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components); + const inputVariables = [dy, w]; + if (hasBias) { + inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components)); + } + const output = outputVariable('result', inputs[0].dataType, outputShape, components); const codeSnippet4 = `{ - let batch: u32 = global_id.z / outShape[1]; - let r = global_id.z % outShape[1]; - let c = global_id.y * ${workPerThread}; - let d1: u32 = global_id.x * 4; + let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1]; + let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1]; + let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread}; + let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4; let dyCorner = vec2(i32(r), i32(c)) - vec2(pads); @@ -73,18 +71,21 @@ const createConvTranspose2DOpProgramShaderSource = dotProd[i] = vec4(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = f32(dyCorner.x + wR) / f32(strides.x); - let wRPerm: u32= filterDims[0] - 1 - wR; + var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + let wRPerm = filterDims[0] - 1 - wR; if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || - fract(dyR) > 0.0) { + fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = f32(dyCorner.y + wC) / f32(strides.y); - let dyC2 = f32(dyCorner.y + 1 + wC) / f32(strides.y); - let wCPerm: u32 = filterDims[1] - 1 - wC; + let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); + let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let wCPerm = filterDims[1] - 1 - wC; + if (wCPerm < 0) { + continue; + } var bDyCVal = true; var bDyCVal2 = true; if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || @@ -101,57 +102,53 @@ const createConvTranspose2DOpProgramShaderSource = if (bDyCVal && bDyCVal2) { let d2Length = outBackprop[3]; for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[0] = dotProd[0] + tmpval; - xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC2', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC2')}; + xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); } } else if (bDyCVal) { - let d2Length = outBackprop[3]; + let d2Length = outBackprop[${channelDim}]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[0] = dotProd[0] + tmpval; } } else if (bDyCVal2) { let d2Length = outBackprop[3]; for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) { - let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; - let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')}; - let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')}; - let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')}; + let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')}; + let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')}; + let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')}; + let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; - var xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let tmpval = vec4(xValue * wValue0, - xValue * wValue1, - xValue * wValue2, - xValue * wValue3); + var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; + let tmpval = vec4(dot(xValue, wValue0), + dot(xValue, wValue1), + dot(xValue, wValue2), + dot(xValue, wValue3)); dotProd[1] = dotProd[1] + tmpval; } } @@ -159,16 +156,21 @@ const createConvTranspose2DOpProgramShaderSource = } for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) { - ${output.set('batch', 'r', 'c+i', 'd1', 'dotProd[i]')}; + let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : '0.0'}; + ${output.set('batch', 'r', 'c + i', 'd1', 'value')}; } }`; const codeSnippet = ` let outputIndices = ${output.offsetToIndices('global_idx')}; - let batch = outputIndices[0]; - let d1 = outputIndices[${channelDim}]; - let dyCorner = vec2(i32(outputIndices[${rowDim}]), i32(outputIndices[${colDim}])) - pads; + let batch = ${output.indicesGet('outputIndices', 0)}; + let d1 = ${output.indicesGet('outputIndices', channelDim)}; + let r = ${output.indicesGet('outputIndices', rowDim)}; + let c = ${output.indicesGet('outputIndices', colDim)}; + let dyCorner = vec2(i32(r), i32(c)) - pads; let dyRCorner = dyCorner.x; let dyCCorner = dyCorner.y; + let groupId = d1 / ${outputChannelsPerGroup}; + let wOutChannel = d1 - groupId * ${outputChannelsPerGroup}; // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. var dotProd = 0.0; @@ -178,7 +180,7 @@ const createConvTranspose2DOpProgramShaderSource = } let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,30 +192,29 @@ const createConvTranspose2DOpProgramShaderSource = } let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - for (var d2: u32 = 0; d2 < outBackprop[3]; d2 = d2 + 1) { + for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { + let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ - isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')}; - let wValue = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')}; + isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : + dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; + let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; } } } - ${output.setByOffset('global_idx', 'dotProd')}; + let value = dotProd + ${hasBias ? 'bias[d1]' : '0.0'}; + ${output.setByOffset('global_idx', 'value')}; `; return ` - ${w.impl('indicesToOffset', 'get')} - ${dy.impl('indicesToOffset', 'get')} - ${output.impl('offsetToIndices')} + ${shaderHelper.declareVariables(...inputVariables, output)} ${declareFunctions} - ${declareInputs.join('\n')} - @group(0) @binding(${declareInputs.length}) var result: array<${isVec4 ? 'vec4' : 'f32'}>; const outShape : vec4 = vec4(${outputShape.join(',')}); const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); @@ -240,25 +241,18 @@ export const createConvTranspose2DProgramInfo = (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => { const hasBias = inputs.length > 2; - const isChannelsLast = attributes.format === 'NHWC'; + // const isChannelsLast = attributes.format === 'NHWC'; const outputShape = attributes.outputShape; - const batchSize = outputShape[0]; - const outWidth = outputShape[isChannelsLast ? 1 : 2]; - const outHeight = outputShape[isChannelsLast ? 2 : 3]; - const outChannels = outputShape[isChannelsLast ? 3 : 1]; - const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; - const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0; + const outputSize = ShapeUtil.size(outputShape); - const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; - const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; - const workGroupSize: [number, number, number] = - isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; - const elementsPerThread = - isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1]; + // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + // TODO Enable isVec4 for performance + // Disabled due to weight matrix layout issue + // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0; const dispatch = [ - Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), - Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), - Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[1]) + Math.ceil(outputSize / 64), + 1, + 1, ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); @@ -271,6 +265,6 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, elementsPerThread), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 5f3d1564664bf..02b978a381de5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -50,8 +50,6 @@ const createBinaryOpProgramShader = }; broadcastImpl = ` - ${output.impl('offsetToIndices')} - fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 { return ${calcOffsetImpl(dimsA)}; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index e64c74972581d..7da57bcb9c647 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util'; **/ export const WORKGROUP_SIZE = 64; -interface IndicesHelperImplementations { - /** - * implementation of `offsetToIndices` function. - */ - readonly offsetToIndices: string; - - /** - * implementation of `indicesToOffset` function. - */ - readonly indicesToOffset: string; - - /** - * implementation of `set`, `setByIndices` and `setByOffset` function. - */ - readonly set: string; - - /** - * implementation of `get`, `getByIndices` and `getByOffset` function. - */ - readonly get: string; -} - interface IndicesHelperTypes { /** * WGSL type of indices expression @@ -96,12 +74,10 @@ interface IndicesHelperTypes { */ export interface IndicesHelper { /** - * get WGSL code of function implementation for the util functions + * get WGSL code of function implementation for the util functions. * - * @param functions - a list of function names to get implementation for. If not specified, all functions will be - * returned. */ - readonly impl: (...functions: ReadonlyArray) => string; + readonly impl: () => string; /** * get type info @@ -215,9 +191,12 @@ export interface IndicesHelper { readonly shape: readonly number[]; } -const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => { +const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => { // return type is [ storage type, runtime type ] or a single string for both switch (type) { + // TODO: enable after "shader-f16" WSGL extension release + // case DataType.float16: + // return components > 1 ? `vec${components}` : 'f16'; case DataType.float: return components > 1 ? `vec${components}` : 'f32'; case DataType.int32: @@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st } }; +export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => { + const mappedType = getWgslMappedType(type, components); + return typeof mappedType === 'string' ? mappedType : mappedType[0]; +}; + /** * A helper function to get a IndicesHelper for a given input or output. * @@ -260,13 +244,22 @@ const createIndicesHelper = components: 1|2|3|4): IndicesHelper => { const rank = shape.length; const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}` : `array`; - const mappedType = getWgslValueType(tensorType, components); + const mappedType = getWgslMappedType(tensorType, components); const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1]; const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0]; const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType}; const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`; + const implementationUsed = { + offsetToIndices: false, + indicesToOffset: false, + set: false, + setByIndices: false, + get: false, + getByIndices: false, + }; + const strides = ShapeUtil.computeStrides(shape); let o2iSnippet = ''; for (let i = 0; i < rank - 1; i++) { @@ -287,7 +280,10 @@ const createIndicesHelper = return indices; }`; - const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + const offsetToIndices = (varOffset: string) => { + implementationUsed.offsetToIndices = true; + return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`; + }; const offsets: string[] = []; if (rank >= 2) { @@ -301,7 +297,10 @@ const createIndicesHelper = return ${offsets.join('+')}; }`; - const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + const indicesToOffset = (varIndices: string) => { + implementationUsed.indicesToOffset = true; + return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; + }; const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -357,17 +356,18 @@ const createIndicesHelper = } })(); + const getByIndicesImplementation = rank < 2 ? '' : ` + fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { + return ${name}[i2o_${name}(indices)]; + }`; + const getImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; - } fn get_${name}(${params}) -> ${valueType} { return get_${name}ByIndices(${indices(dims)}); - } - `; + }`; })(); const get = (...indices: ReadonlyArray) => { @@ -376,14 +376,16 @@ const createIndicesHelper = } const normalizedIndices = indices.map(normalizeDim).join(','); - const funcName = `get_${name}`; if (rank === 0) { return getByOffset('0u'); } else if (rank === 1) { return getByOffset(normalizedIndices[0]); } else { - return `${funcName}(${normalizedIndices})`; + implementationUsed.get = true; + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; + return `get_${name}(${normalizedIndices})`; } }; @@ -391,21 +393,24 @@ const createIndicesHelper = if (rank < 2) { return getByOffset(varIndices); } else { + implementationUsed.getByIndices = true; + implementationUsed.indicesToOffset = true; return `get_${name}ByIndices(${varIndices})`; } }; + const setByIndicesImplementation = rank < 2 ? '' : ` + fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { + ${setByOffset(`i2o_${name}(indices)`, 'value')} + }`; + const setImplementation = rank < 2 ? '' : (() => { const params = shape.map((_, i) => `d${i}: u32`).join(', '); const dims = shape.map((_, i) => `d${i}`).join(', '); return ` - fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) { - ${setByOffset(`i2o_${name}(indices)`, 'value')} - } fn set_${name}(${params}, value: ${valueType}) { set_${name}ByIndices(${indices(dims)}, value); - } - `; + }`; })(); const set = (...indicesAndValue: ReadonlyArray) => { @@ -424,6 +429,9 @@ const createIndicesHelper = } else if (rank === 1) { return setByOffset(normalizedIndices[0], value); } else { + implementationUsed.set = true; + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}(${normalizedIndices}, ${value})`; } }; @@ -432,32 +440,34 @@ const createIndicesHelper = if (rank < 2) { return setByOffset(varIndices, value); } else { + implementationUsed.setByIndices = true; + implementationUsed.indicesToOffset = true; return `set_${name}ByIndices(${varIndices}, ${value});`; } }; - const funcImpls = { - offsetToIndices: offsetToIndicesImplementation, - indicesToOffset: indicesToOffsetImplementation, - set: setImplementation, - get: getImplementation, - }; - const impl = (...functions: Array) => { + const impl = () => { const impls = []; - if (functions.length === 0) { - functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get'); + if (implementationUsed.offsetToIndices) { + impls.push(offsetToIndicesImplementation); + } + if (implementationUsed.indicesToOffset) { + impls.push(indicesToOffsetImplementation); + } + if (implementationUsed.set) { + impls.push(setImplementation); + } + if (implementationUsed.setByIndices) { + impls.push(setByIndicesImplementation); + } + if (implementationUsed.get) { + impls.push(getImplementation); } - for (const func of functions) { - const impl = funcImpls[func]; - if (impl === undefined) { - throw new Error(`unknown function ${func}`); - } else { - impls.push(impl); - } + if (implementationUsed.getByIndices) { + impls.push(getByIndicesImplementation); } return impls.join('\n'); }; - impl.toString = () => impl(); return { impl, @@ -552,6 +562,11 @@ export interface ShaderHelper { * @param variables - an array of IndicesHelper for the variables. */ declareVariables(...variables: IndicesHelper[]): string; + + /** + * Get additional implementation that needs to be added to the shader source. + */ + readonly additionalImplementations: string; } class ShaderHelperImpl implements ShaderHelper { @@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper { } declareVariable(variable: IndicesHelper, bindingIndex: number): string { + this.indicesHelpers.push(variable); const access = variable.usage === 'input' ? 'read' : 'read_write'; const storageType = variable.type.storage; return `@group(0) @binding(${bindingIndex}) var ${variable.name}: array<${storageType}>;`; @@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper { let i = 0; return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n'); } + + private indicesHelpers: IndicesHelper[] = []; + + get additionalImplementations(): string { + return this.indicesHelpers.map(i => i.impl()).join('\n'); + } } export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper => diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts index 8b91b64a09200..9b294803d3787 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts @@ -109,9 +109,6 @@ const createConcatProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(...inputVars, output)} - ${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')} - ${output.impl('offsetToIndices')} - const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateInputIndexImpl(sizeInConcatAxis.length)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts index 7a0e1f01c461f..8a794ce16a0b5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts @@ -47,9 +47,6 @@ const createGroupedConvProgramInfo = ${shaderHelper.declareVariables(...inputVars, output)} ${activationFunction} - ${output.impl('offsetToIndices')} - ${x.impl('indicesToOffset', 'get')} - ${w.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts index b07fe3a90f3b9..2d845775f1c62 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts @@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} let outputIndices = ${output.offsetToIndices('global_idx')}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts index 2ce8427bb6e7f..f62c766aa9ed0 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface InstanceNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -45,7 +45,7 @@ const createInstanceNormProgramInfo = Got scale size of ${scaleSize} and bias size of ${biasSize}`); } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const getShaderSource = (shaderHelper: ShaderHelper) => ` const C: u32 = ${C}; @@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo = const C = xShape[xShape.length - 1]; const H = ShapeUtil.sizeFromDimension(xShape, 1) / C; - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const normCount = C * N; const getShaderSource = (shaderHelper: ShaderHelper) => ` diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts index 48627bfaec401..8a9927b25a52e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface LayerNormAttributes extends AttributeWithCacheKey { axis: number; @@ -54,7 +54,7 @@ const createLayerNormProgramInfo = } } - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const hasMeanDataOutput = outputCount > 1; const hasInvStdOutput = outputCount > 2; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 9af8fc7b6d33d..79071d32443d6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -128,9 +128,6 @@ const generatePoolingCode = (${attributes.pads.map(i => `${i}u`).join(',')}); const inputDims = array(${inputDims.map(i => `${i}u`).join(',')}); const kernelStrides = array(${kernelStrides.map(i => `${i}u`).join(',')}); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index b645510d8384b..cb592c838dd97 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -85,9 +85,6 @@ export const createReduceProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} - ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} var inputIndices: ${input.type.indices}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts index 505bae7ce2302..1d0b8229a76f7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts @@ -484,8 +484,6 @@ const createResizeProgramInfo = } })()}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} if (${noScale}) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index 96bf1cd9a6ef6..4b845bcf2121b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType, tensorTypeToWsglType} from '../../../wasm-common'; +import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {ShaderHelper, tensorTypeToWsglStorageType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -84,7 +84,7 @@ const createSkipLayerNormProgramInfo = const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : []; const hasBetaInput = inputs.length > 3; const hasBiasInput = inputs.length > 4; - const dataType = tensorTypeToWsglType(inputs[0].dataType); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const hasMeanOutput = isTraining && outputCount > 1; const hasInvStdDevOutput = isTraining && outputCount > 2; const hasInputSkipBiasSumOutput = outputCount > 3; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts index 1f881a75ffbde..4211e526898e6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts @@ -153,8 +153,6 @@ const createSliceProgramInfo = const steps = array(${steps.map(i => `${i}u`).join(',')}); const inputShape = array(${inputShape.map(i => `${i}u`).join(',')}); - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${calculateInputIndicesImpl(input, output, inputShape, outputShape)} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts index 54f493422816f..9a150d21ea02e 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts @@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => { const createSplitAttributesFromInputs = (inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => { const splitSizes: number[] = []; + let numOutputs: number = attributes.numOutputs; if (inputs[1].dims[0] > 0) { inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v))); + numOutputs = splitSizes.length; } - return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes}); + return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes}); }; const calculateOutputIndexImpl = (numberOfTensors: number): string => ` @@ -85,8 +87,6 @@ const createSplitProgramInfo = const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`; const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.declareVariables(input, ...outputs)} - ${input.impl('indicesToOffset', 'offsetToIndices', 'get')} - ${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')} const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')}); ${calculateOutputIndexImpl(sizeInConcatAxis.length)} ${writeBufferDataImpl(outputs)} @@ -114,7 +114,7 @@ const createSplitProgramInfoLoader = const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes); const metadata: ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey}; - return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)}; + return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)}; }; export const split = (context: ComputeContext, attributes: SplitAttributes): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts index 2b80ce173245b..99d9668757caa 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts @@ -66,8 +66,6 @@ export const createTileProgramInfo = const getShaderSource = (shaderHelper: ShaderHelper) => ` const inputShape = ${input.indices(...inputShape)}; ${shaderHelper.declareVariables(input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} let outputIndices = ${output.offsetToIndices('global_idx')}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index 0b0185fc17c9b..ebedc61712e8a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu ${shaderHelper.declareVariables(input, output)} ${permFunctionBody(perm, rank, input, output)} - ${output.impl('offsetToIndices')} - ${input.impl('indicesToOffset', 'get')} ${shaderHelper.mainStart()} ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts index b46b35b71412e..da710b7dc2596 100644 --- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts @@ -114,7 +114,9 @@ export class ProgramManager { build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact { const device = this.backend.device; - const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize)); + const shaderHelper = createShaderHelper(normalizedDispatchGroupSize); + const userCode = programInfo.getShaderSource(shaderHelper); + const code = `${shaderHelper.additionalImplementations}\n${userCode}`; const shaderModule = device.createShaderModule({code}); LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`); diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index a89a585906f9d..389773f3e8884 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; - -export const tensorTypeToWsglType = (type: DataType) => { - switch (type) { - case DataType.float: - return 'f32'; - // TODO: enable after "shader-f16" WSGL extension release - // case DataType.float16: - // return 'f16'; - case DataType.int32: - return 'i32'; - case DataType.uint32: - return 'u32'; - default: - throw new Error(`Unsupported type: ${type}`); - } -}; diff --git a/js/web/package-lock.json b/js/web/package-lock.json index 4c5649d8806c9..8ad55996f7455 100644 --- a/js/web/package-lock.json +++ b/js/web/package-lock.json @@ -1,12 +1,12 @@ { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.16.2", "lockfileVersion": 2, "requires": true, "packages": { "": { "name": "onnxruntime-web", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "dependencies": { "flatbuffers": "^1.12.0", @@ -49,7 +49,7 @@ }, "../common": { "name": "onnxruntime-common", - "version": "1.16.0", + "version": "1.16.2", "license": "MIT", "devDependencies": { "typedoc": "^0.23.22" diff --git a/js/web/package.json b/js/web/package.json index ce06475f672fd..76f793263e01a 100644 --- a/js/web/package.json +++ b/js/web/package.json @@ -8,7 +8,7 @@ "type": "git" }, "author": "fs-eire", - "version": "1.16.0", + "version": "1.16.2", "jsdelivr": "dist/ort.min.js", "dependencies": { "flatbuffers": "^1.12.0", diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc new file mode 100644 index 0000000000000..a249dc807fa0b --- /dev/null +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -0,0 +1,289 @@ +[ + { + "name": "ConvTranspose without bias addition A", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose without bias addition B", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + "dims": [2, 2, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 400, 940, 560, 1080, 2520, 1480, 760, 1740, 1000, 640, 1500, 880, 1720, 3960, 2280, 1160, 2620, 1480 + ], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition A", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 4, 1, 1], + "type": "float32" + }, + { + "data": [ + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + ], + "dims": [4, 4, 2, 2], + "type": "float32" + }, + { + "data": [0.1, 0.2, 0.3, 0.4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, + 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, + 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, + 100.4000015258789 + ], + "dims": [1, 4, 2, 2], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 1, 1, 1], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - A", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [1, 1], "type": "ints" }, + { "name": "group", "data": 2, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0], + "dims": [1, 2, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0], + "dims": [2, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68], + "dims": [1, 2, 3, 3], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - B", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0 + ], + "dims": [1, 3, 3, 3], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + }, + { + "data": [0.125, 0.25, 0.375], + "dims": [3], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125, + 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25, + 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375, + 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375 + ], + "dims": [1, 3, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose- group - C", + "operator": "ConvTranspose", + "attributes": [ + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "group", "data": 3, "type": "int" } + ], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, + 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0 + ], + "dims": [1, 3, 3, 4], + "type": "float32" + }, + { + "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0], + "dims": [3, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368, + 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146, + 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420 + ], + "dims": [1, 3, 4, 5], + "type": "float32" + } + ] + } + ] + }, + + { + "name": "ConvTranspose- pointwise", + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0], + "dims": [1, 2, 2, 2], + "type": "float32" + }, + { + "data": [0.0, 1.0, 2.0, 3.0], + "dims": [2, 2, 1, 1], + "type": "float32" + }, + { + "data": [1, 2], + "dims": [2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9, 11, 13, 15, 14, 18, 22, 26], + "dims": [1, 2, 2, 2], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index d39d8edf0b73a..022451c885dd8 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -7,7 +7,7 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime `_ or the `Github project `_. """ -__version__ = "1.16.0" +__version__ = "1.16.2" __author__ = "Microsoft" # we need to do device version validation (for example to check Cuda version for an onnxruntime-training package). diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index f1ab3e691b702..4c9c15d07a9b8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -37,6 +37,7 @@ enum AttentionKernelType { AttentionKernel_TrtFlashAttention, AttentionKernel_TrtFusedCrossAttention, AttentionKernel_CutlassMemoryEfficientAttention, + AttentionKernel_FlashAttention, AttentionKernel_Default }; @@ -98,8 +99,16 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI // Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled). constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION"; +// Environment variable to enable or disable flash attention. Default is 0 (enabled). +constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION"; + // Minimum sequence length to enable memory efficient attention in FP32. -constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256; +constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256; + +// Minimum sequence length to prefer flash attention when input format is packed QKV for MultiHeadAttention +constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV"; +// Default value for the above setting. +constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513; } // namespace attention diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index b8066567fc357..c911b6e76701c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -8,6 +8,7 @@ #include "contrib_ops/cuda/bert/attention.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -39,20 +40,36 @@ REGISTER_KERNEL_TYPED(MLFloat16) template Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) { - disable_fused_self_attention_ = sizeof(T) != 2 || - ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); + disable_fused_self_attention_ = + sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); - enable_trt_flash_attention_ = sizeof(T) == 2 && - !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); + enable_trt_flash_attention_ = + sizeof(T) == 2 && + !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); - enable_fused_causal_attention_ = sizeof(T) == 2 && - ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); + enable_fused_causal_attention_ = + sizeof(T) == 2 && + ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false); -#if USE_FLASH_ATTENTION - disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); +#if USE_MEMORY_EFFICIENT_ATTENTION + disable_memory_efficient_attention_ = + ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else disable_memory_efficient_attention_ = true; #endif + +#if USE_FLASH_ATTENTION + disable_flash_attention_ = + sizeof(T) != 2 || + onnxruntime::ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif } template @@ -100,71 +117,98 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { MHARunner* fused_runner = nullptr; // Check whether we can use fused kernel - int sm = device_prop.major * 10 + device_prop.minor; - bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); - } + const int sm = device_prop.major * 10 + device_prop.minor; + const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); - } - } else { // BERT - bool use_fused_runner = !disable_fused_self_attention_ && - (nullptr == mask_index || is_mask_1d_seq_len) && - nullptr == past && - nullptr == present && - nullptr == relative_position_bias && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, false); - - if (use_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + (nullptr == relative_position_bias) && + nullptr == past && + nullptr == present && + parameters.hidden_size == parameters.v_hidden_size && + nullptr == mask_index && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. + if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } +#else + constexpr bool use_flash_attention = false; +#endif + + if (!use_flash_attention) { + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); + } } + } else { // BERT + bool use_fused_runner = !disable_fused_self_attention_ && + (nullptr == mask_index || is_mask_1d_seq_len) && + nullptr == past && + nullptr == present && + nullptr == relative_position_bias && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, false); - // In case some kernel not loaded due to shared memory limit, we need to double check here. - const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); - if (fused_fp16_runner_->isValid(S)) { - fused_runner = fused_fp16_runner_.get(); + if (use_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // In case some kernel not loaded due to shared memory limit, we need to double check here. + const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length); + if (fused_fp16_runner_->isValid(S)) { + fused_runner = fused_fp16_runner_.get(); + } } } } -#if USE_FLASH_ATTENTION - bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = fused_runner == nullptr && - !disable_memory_efficient_attention_ && - (nullptr == mask_index || is_mask_1d_key_seq_len_start) && - nullptr == past && - nullptr == present && - (nullptr == relative_position_bias || is_good_for_rpb) && - (sizeof(T) == 2 || // sequence length threshold is 0 in FP16 - parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && - has_memory_efficient_attention(sm, sizeof(T) == 2); +#if USE_MEMORY_EFFICIENT_ATTENTION + bool use_memory_efficient_attention = + !use_flash_attention && + fused_runner == nullptr && + !disable_memory_efficient_attention_ && + nullptr == past && + nullptr == present && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && + (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + has_memory_efficient_attention(sm, sizeof(T) == 2); + + if (use_memory_efficient_attention) { + bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; + use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb); + } #else constexpr bool use_memory_efficient_attention = false; - ORT_UNUSED_PARAMETER(is_mask_1d_key_seq_len_start); #endif cublasHandle_t cublas = GetCublasHandle(context); @@ -199,6 +243,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -215,7 +260,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data()); data.past_key = nullptr; data.past_value = nullptr; - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) + ? nullptr + : reinterpret_cast(relative_position_bias->Data()); data.has_qkv_workspace = true; data.workspace = reinterpret_cast(work_space.get()); data.output = reinterpret_cast(output->MutableData()); @@ -224,6 +271,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { data.present_value = nullptr; data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = nullptr; + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = nullptr; data.cumulated_sequence_length_kv_cache = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h index ba7c56c04fdde..455e55ba05a66 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention.h @@ -21,10 +21,12 @@ class Attention final : public CudaKernel, public AttentionBase { Status ComputeInternal(OpKernelContext* context) const override; protected: + bool disable_flash_attention_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool enable_fused_causal_attention_; bool disable_memory_efficient_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 4d478ef158503..ae7696eb9fe0f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -42,6 +42,7 @@ limitations under the License. #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -64,7 +65,8 @@ size_t AlignSize(size_t bytes) { void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { if (this->sequence_length != sequence_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); - LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream); + LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, + this->max_batch_size, sequence_length, stream); this->sequence_length = sequence_length; } } @@ -114,6 +116,7 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention) { // Note that q, k and v might need alignment for fused attention kernels. @@ -121,6 +124,14 @@ size_t GetAttentionWorkspaceSize( ((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size); #if USE_FLASH_ATTENTION + if (use_flash_attention) { + return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + } +#else + ORT_UNUSED_PARAMETER(use_flash_attention); +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { @@ -276,333 +287,439 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, half* present); template -Status PrepareQkv(contrib::AttentionParameters& parameters, - AttentionData& data, - cudaStream_t stream, - int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { +Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; const int qk_head_size = parameters.head_size; const int v_head_size = parameters.v_head_size; const bool past_present_share_buffer = parameters.past_present_share_buffer; void* fused_runner = data.fused_runner; - bool use_memory_efficient_attention = data.use_memory_efficient_attention; + bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention; T* qkv = data.workspace; bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); - // Default format for memory efficient attention. - // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past. - DUMP_TENSOR_INIT(); - if (nullptr != data.gemm_buffer) { - if (data.bias == nullptr) { - assert(nullptr == fused_runner); - // For quantized attention, bias has been added so only need transpose here. - // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH - assert(qk_head_size == v_head_size); - int matrix_to_trans = (past_present_share_buffer ? 1 : 3); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.gemm_buffer, qkv, 3)); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } else { - // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) - // For memory efficient attention, transpose to 3xBxSxNxH (format 3) - // For unfused kernel, transpose to 3xBxNxSxH (format 1) - // For fused causal kernel, use format 1 since we need have K and V to update present state, - // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. - const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1)); - qkv_format = use_fused_kernel - ? AttentionQkvFormat::QKV_BSN3H - : (use_memory_efficient_attention - ? AttentionQkvFormat::Q_K_V_BSNH - : (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH)); - - // For fused causal, we will update gemm_buffer with bias directly. - T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; - - int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); - // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v - // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) - LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); - } + if (data.bias == nullptr) { + assert(nullptr == fused_runner); + // For quantized attention, bias has been added so only need transpose here. + // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH + assert(qk_head_size == v_head_size); + int matrix_to_trans = (past_present_share_buffer ? 1 : 3); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.gemm_buffer, qkv, 3)); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } else { + // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2) + // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3) + // For unfused kernel, transpose to 3xBxNxSxH (format 1) + // For fused causal kernel, use format 1 since we need have K and V to update present state, + // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel. + const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1)); + qkv_format = use_fused_kernel + ? AttentionQkvFormat::QKV_BSN3H + : (use_flash_or_efficient_attention + ? AttentionQkvFormat::Q_K_V_BSNH + : (use_fused_causal + ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH + : AttentionQkvFormat::Q_K_V_BNSH)); + + // For fused causal, we will update gemm_buffer with bias directly. + T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr; + + int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3); + // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v + // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H) + LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, + 3, parameters.do_rotary, parameters.past_sequence_length); } - // attention with past/present state - else if (data.past_key != nullptr || data.present_key != nullptr) { - // Below logic does not support memory efficient attention with past (like pass_past_in_kv) but without bias - if (data.bias == nullptr) { - // cross attention with past state - if (data.past_key != nullptr && data.present_key == nullptr) { - assert(data.past_value != nullptr); - assert(data.query != nullptr); - assert(data.key == nullptr); - assert(data.value == nullptr); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - } - // cross attention with present state or self attention with present state - else if (data.past_key == nullptr && data.present_key != nullptr) { - assert(data.past_value == nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - - // TODO: supporting packed kv for cross attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, data.present_key)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, data.present_value)); - } - // self attention with past and present state - else { - assert(data.past_key != nullptr); - assert(data.past_value != nullptr); - assert(data.present_key != nullptr); - assert(data.present_value != nullptr); - assert(data.query != nullptr); - assert(data.key != nullptr); - assert(data.value != nullptr); - // TODO: supporting packed qkv for self attention may benefit performance - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.query, q)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.key, k)); - ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.value, v)); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + return Status::OK(); +} + +// For MultiHeadAttention with past state +template +Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + DUMP_TENSOR_INIT(); + + if (data.bias == nullptr) { + // Below logic does not support fused attention with past without bias + // When there is past state, the format shall be BxNxSxH, so we disable fused attention when there is past. + + // cross attention with past state + if (data.past_key != nullptr && data.present_key == nullptr) { + assert(data.past_value != nullptr); + assert(data.query != nullptr); + assert(data.key == nullptr); + assert(data.value == nullptr); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); } -#if USE_FLASH_ATTENTION - // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value - else if (use_memory_efficient_attention && data.past_key != nullptr && data.past_value != nullptr && parameters.pass_past_in_kv) { - // Transpose past_key and past_value to use memory efficient attention - - // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_key, data.temp_k_workspace)); - // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) - ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.past_value, data.temp_v_workspace)); - - // query => q, temp_k_workspace => k, temp_v_workspace => v - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - - data.past_key = nullptr; - data.past_value = nullptr; + // cross attention with present state or self attention with present state + else if (data.past_key == nullptr && data.present_key != nullptr) { + assert(data.past_value == nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); + + // TODO: supporting packed kv for cross attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.key, data.present_key)); + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.value, data.present_value)); } - // When there is no past_key/past_value and there is present_key/present_value (e.g. get initial kv to use as past_kv in the next iteration) - else if (use_memory_efficient_attention && data.present_key != nullptr && data.present_value != nullptr) { - // Use memory efficient attention kernel - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); - - // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + // self attention with past and present state + else { + assert(data.past_key != nullptr); + assert(data.past_value != nullptr); + assert(data.present_key != nullptr); + assert(data.present_value != nullptr); + assert(data.query != nullptr); + assert(data.key != nullptr); + assert(data.value != nullptr); + // TODO: supporting packed qkv for self attention may benefit performance + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.query, q)); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, false, data.temp_k_workspace, data.present_key)); - - // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + max_threads_per_block, false, data.key, k)); ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, - max_threads_per_block, false, data.temp_v_workspace, data.present_value)); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + max_threads_per_block, false, data.value, v)); } + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.past_key != nullptr && + data.past_value != nullptr && + parameters.pass_past_in_kv) { + // Transpose past_key and past_value to use memory efficient attention + + // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_key, data.temp_k_workspace)); + // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v) + ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.past_value, data.temp_v_workspace)); + + // query => q, temp_k_workspace => k, temp_v_workspace => v + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + + data.past_key = nullptr; + data.past_value = nullptr; + } + // When there is no past_key/past_value and there is present_key/present_value + // (e.g. get initial kv to use as past_kv in the next iteration) + else if ((data.use_memory_efficient_attention || data.use_flash_attention) && + data.present_key != nullptr && + data.present_value != nullptr) { + // Use memory efficient attention kernel + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace); + + // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, false, data.temp_k_workspace, data.present_key)); + + // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v) + ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads, + max_threads_per_block, false, data.temp_v_workspace, data.present_value)); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } #endif - else { - // Use unfused kernel for Q, use unfused kernel for K and V if needed - constexpr int format = 0; - // Query (BxSxNxH) => Q (BxNxSxH) + else { + // Use unfused kernel for Q, use unfused kernel for K and V if needed + constexpr int format = 0; + // Query (BxSxNxH) => Q (BxNxSxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + if (!parameters.pass_past_in_kv) { + T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; + T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, data.bias + num_heads * qk_head_size, k_dest, + true, -1); + + // Value (BxLxNxH_v) => V (BxNxLxH_v) LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, true, -1); - if (!parameters.pass_past_in_kv) { - T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k; - T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v; - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, data.bias + num_heads * qk_head_size, k_dest, - true, -1); - - // Value (BxLxNxH_v) => V (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, data.bias + 2 * num_heads * qk_head_size, v_dest, - true, -1); - - DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size * num_heads, kv_sequence_length, v_head_size); - } - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size); } - } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} - DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); - - if (use_memory_efficient_attention) { - // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, qkv, - true, v_head_size, qkv_add_bias, 3); - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (!use_fused_kernel) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); - } +// For MultiHeadAttention without past state, with packed QKV inputs +template +Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); - qkv_format = AttentionQkvFormat::QKV_BSN3H; + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack qkv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, qkv, + true, v_head_size, qkv_add_bias, 3); + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (!use_fused_kernel) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed QKV format is not implemented for current GPU. Please disable it in fusion options."); } - } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv - // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. - // CheckInputs verified this constraint. - assert(data.bias == nullptr); - assert(qk_head_size == v_head_size); - DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); - - if (use_memory_efficient_attention) { - // unpack kv to BSNH. Note that there is no bias so we need not output query to q. - constexpr int format = 4; - T* qkv_add_bias = nullptr; - const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); - LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, kv_bias, k, - true, v_head_size, qkv_add_bias, 2); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } else { - if (data.fused_cross_attention_kernel == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); - } + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } + return Status::OK(); +} - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; +// For MultiHeadAttention without past state, with packed KV inputs +template +Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint. + // CheckInputs verified this constraint. + assert(data.bias == nullptr); + assert(qk_head_size == v_head_size); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size); + + if (data.use_memory_efficient_attention || data.use_flash_attention) { + // unpack kv to BSNH. Note that there is no bias so we need not output query to q. + constexpr int format = 4; + T* qkv_add_bias = nullptr; + const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size); + LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, kv_bias, k, + true, v_head_size, qkv_add_bias, 2); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } else { + if (data.fused_cross_attention_kernel == nullptr) { + return ORT_MAKE_STATUS( + ONNXRUNTIME, NOT_IMPLEMENTED, + "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options."); } - } else { // gemm_buffer == nullptr and not packed - assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); - DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } + return Status::OK(); +} + +// For MultiHeadAttention without past state, with Q, K and V inputs +template +Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.kv_sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + void* fused_runner = data.fused_runner; + + T* qkv = data.workspace; + + bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional); + bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional); + + // gemm_buffer == nullptr and not packed + assert(data.query != nullptr && data.key != nullptr && data.value != nullptr); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size); #if DUMP_TENSOR_LEVEL > 1 - if (data.bias != nullptr) { - DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); - DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); - DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); - } + if (data.bias != nullptr) { + DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size); + DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size); + DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size); + } #endif - if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { - DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, sequence_length, kv_sequence_length); - } + if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) { + DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, + num_heads, sequence_length, kv_sequence_length); + } - if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { - DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); - } + if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) { + DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1); + } - if (data.fused_cross_attention_kernel != nullptr) { - assert(qk_head_size == v_head_size); + if (data.fused_cross_attention_kernel != nullptr) { + assert(qk_head_size == v_head_size); - // For fused cross attention, besides adding bias, K and V needed to be packed: - // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); + // For fused cross attention, besides adding bias, K and V needed to be packed: + // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length); - qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; - } -#if USE_FLASH_ATTENTION - else if (use_memory_efficient_attention) { - LaunchAddBias(stream, max_threads_per_block, - batch_size, sequence_length, kv_sequence_length, - num_heads, qk_head_size, v_head_size, - data.bias, data.query, data.key, data.value, q, k, v); - - DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size); - DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BSNH; - } + qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H; + } +#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION + else if (data.use_memory_efficient_attention || data.use_flash_attention) { + LaunchAddBias(stream, max_threads_per_block, + batch_size, sequence_length, kv_sequence_length, + num_heads, qk_head_size, v_head_size, + data.bias, data.query, data.key, data.value, q, k, v); + + DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BSNH; + } #endif - else if (use_fused_kernel) { - assert(qk_head_size == v_head_size); - - // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) - LaunchAddBiasTransposeTrt( - stream, max_threads_per_block, - batch_size, sequence_length, - num_heads, qk_head_size, - data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); - DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); - - qkv_format = AttentionQkvFormat::QKV_BSN3H; - } else { // unfused kernel - ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); - - // Query (BxSxNxH) => Q (BxNxSxH) - constexpr int format = 0; - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, sequence_length, num_heads, qk_head_size, - data.query, data.bias, q, - true, -1); - - // Key (BxLxNxH) => K (BxNxLxH) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, qk_head_size, - data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, - true, -1); + else if (use_fused_kernel) { + assert(qk_head_size == v_head_size); - // Value (BxLxNxH_v) => K (BxNxLxH_v) - LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, - batch_size, kv_sequence_length, num_heads, v_head_size, - data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, - true, -1); + // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H) + LaunchAddBiasTransposeTrt( + stream, max_threads_per_block, + batch_size, sequence_length, + num_heads, qk_head_size, + data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length); + DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size); + + qkv_format = AttentionQkvFormat::QKV_BSN3H; + } else { // unfused kernel + ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal"); + + // Query (BxSxNxH) => Q (BxNxSxH) + constexpr int format = 0; + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, sequence_length, num_heads, qk_head_size, + data.query, data.bias, q, + true, -1); + + // Key (BxLxNxH) => K (BxNxLxH) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, qk_head_size, + data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k, + true, -1); + + // Value (BxLxNxH_v) => K (BxNxLxH_v) + LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, + batch_size, kv_sequence_length, num_heads, v_head_size, + data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v, + true, -1); + + DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size); + DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size); + qkv_format = AttentionQkvFormat::Q_K_V_BNSH; + } + return Status::OK(); +} - DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size); - DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size); - qkv_format = AttentionQkvFormat::Q_K_V_BNSH; - } +template +Status PrepareQkv(contrib::AttentionParameters& parameters, + AttentionData& data, + cudaStream_t stream, + int max_threads_per_block, + T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + if (nullptr != data.gemm_buffer) { // Attention operator + ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format)); + } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); + } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); + } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); + } else { // multihead attention operator, no past, separated Q/K/V inputs + ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format)); } CUDA_RETURN_IF_ERROR(cudaGetLastError()); @@ -631,7 +748,10 @@ Status QkvToContext( void* fused_runner = data.fused_runner; // At most one fused kernel is enabled. - assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1); + assert((int(data.use_flash_attention) + + int(data.use_memory_efficient_attention) + + int(fused_runner != nullptr) + + int(data.fused_cross_attention_kernel != nullptr)) <= 1); const int batches = batch_size * num_heads; @@ -673,8 +793,9 @@ Status QkvToContext( if (nullptr != data.present) { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); ORT_RETURN_IF_ERROR( - LaunchConcatPastToPresent(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, - max_threads_per_block, data.past, k, data.present)); + LaunchConcatPastToPresent( + stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + max_threads_per_block, data.past, k, data.present)); // Update pointers to present_k and present_v. k = data.present; @@ -708,22 +829,25 @@ Status QkvToContext( cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream); } else { ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads, + LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, + batch_size, qk_head_size, num_heads, max_threads_per_block, 1, data.past_key, k, data.present_key)); ORT_RETURN_IF_ERROR( - LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, v_head_size, num_heads, + LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, + batch_size, v_head_size, num_heads, max_threads_per_block, 1, data.past_value, v, data.present_value)); // Update pointers to present_k and present_v. k = data.present_key; v = data.present_value; } } - } else { + } else { // past_present_share_buffer assert(qk_head_size == v_head_size); assert(data.fused_cross_attention_kernel == nullptr); assert(!use_fused_kernel); assert(data.gemm_buffer != nullptr); assert(!data.use_memory_efficient_attention); + assert(!data.use_flash_attention); assert(data.has_qkv_workspace); if (nullptr != data.past_key || nullptr != data.present_key) { @@ -799,7 +923,7 @@ Status QkvToContext( kv_sequence_length, // sequence length of KV stream); - DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size); return Status::OK(); } @@ -836,11 +960,11 @@ Status QkvToContext( } fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream); - DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size); } else { assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH); fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream); - DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size); } return Status::OK(); } @@ -850,6 +974,37 @@ Status QkvToContext( : parameters.scale; #if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH); + assert(nullptr == data.mask_index); + assert(nullptr == data.relative_position_bias); + assert(parameters.head_size == parameters.v_head_size); + + void* query = reinterpret_cast(q); + void* key = reinterpret_cast(k); + void* value = reinterpret_cast(v); + // For packed KV, we can use query input directly. + if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) { + query = reinterpret_cast(const_cast(data.query)); + } + + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); + + constexpr bool is_causal = false; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd( + device_prop, stream, query, key, value, data.output, reinterpret_cast(scratch1), + parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size, + parameters.sequence_length, parameters.total_sequence_length, scale, is_causal)); + + DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); + + return Status::OK(); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { // We only enable fused cross attention when there is no key padding mask. // Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query. @@ -864,9 +1019,9 @@ Status QkvToContext( query = data.query; } - DUMP_TENSOR_D("attention q(BSNH)", q, batch_size * sequence_length, num_heads * qk_head_size); - DUMP_TENSOR_D("attention k(BSNH)", k, batch_size * sequence_length, num_heads * qk_head_size); - DUMP_TENSOR_D("attention v(BSNH)", v, batch_size * sequence_length, num_heads * v_head_size); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), batch_size, sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size); MemoryEfficientAttentionParams p; p.sm = device_prop.major * 10 + device_prop.minor; @@ -879,19 +1034,28 @@ Status QkvToContext( p.v_head_size = parameters.v_head_size; p.causal = parameters.is_unidirectional; p.scale = scale; - p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index)); - p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index + batch_size)); - p.seqstart_k_ptr = nullptr == data.mask_index ? nullptr : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); + p.seqlen_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index)); + p.seqstart_q_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index + batch_size)); + p.seqstart_k_ptr = nullptr == data.mask_index + ? nullptr + : const_cast(reinterpret_cast(data.mask_index + 2 * batch_size + 1)); p.query = query; p.key = key; p.value = value; p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias; p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias; p.output = data.output; - p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr; + p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) + ? scratch1 + : nullptr; p.stream = stream; run_memory_efficient_attention(p); - DUMP_TENSOR("attention cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size); + return Status::OK(); } #endif @@ -922,9 +1086,9 @@ Status QkvToContext( q, qk_head_size, sequence_length * qk_head_size, &zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop)); - DUMP_TENSOR_D("Q", q, batch_size * num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("K", k, batch_size * num_heads, qk_head_size, sequence_length); - DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length); + DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, total_sequence_length); @@ -940,11 +1104,12 @@ Status QkvToContext( T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax. ORT_RETURN_IF_ERROR( - ComputeSoftmaxWithRawMask(ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, - mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, - scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, - parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, - mask_filter_value)); + ComputeSoftmaxWithRawMask( + ort_stream, total_sequence_length, sequence_length, batch_size, num_heads, + mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias, + scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension, + parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace, + mask_filter_value)); } else if (nullptr != mask_index) { // 1d mask index assert(mask_index_dims.size() == 1); // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the later one has start postions. @@ -955,12 +1120,13 @@ Status QkvToContext( scratch1, scratch2, parameters.is_unidirectional)); } else { // no mask ORT_RETURN_IF_ERROR( - ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, - parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); + ComputeSoftmax( + stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias, + parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional)); } - DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length); - DUMP_TENSOR_D("V", v, batch_size * num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length); + DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = qkv; @@ -974,7 +1140,7 @@ Status QkvToContext( // Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads, max_threads_per_block, false, temp_output, data.output); - DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size); + DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size); return result; } @@ -1109,15 +1275,17 @@ Status DecoderQkvToContext( if (has_key_padding_mask) { constexpr int mask_dimension = 2; constexpr int max_sequence_length = 0; - ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(ort_stream, kv_sequence_length, sequence_length, batch_size, - num_heads, nullptr, key_padding_mask, add_before_softmax, - false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, - 1.0f, mask_dimension, max_sequence_length, false, nullptr, - mask_filter_value)); + ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask( + ort_stream, kv_sequence_length, sequence_length, batch_size, + num_heads, nullptr, key_padding_mask, add_before_softmax, + false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional, + 1.0f, mask_dimension, max_sequence_length, false, nullptr, + mask_filter_value)); } else { - ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads, - add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, - is_unidirectional)); + ORT_RETURN_IF_ERROR(ComputeSoftmax( + stream, kv_sequence_length, sequence_length, batch_size, num_heads, + add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2, + is_unidirectional)); } // compute P*V (as V*P), and store in scratch3: BxNxSxH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 5c63a8d8a80b6..af7373dd9fa1b 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -43,6 +43,7 @@ size_t GetAttentionWorkspaceSize( size_t kv_sequence_length, size_t total_sequence_length, void* fused_runner, + bool use_flash_attention, bool use_fused_cross_attention, bool use_memory_efficient_attention); @@ -74,6 +75,7 @@ struct AttentionData { void* fused_runner; const void* fused_cross_attention_kernel; + bool use_flash_attention; bool use_memory_efficient_attention; mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index 00fa265e117bc..ed330b0fca332 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #if defined(__GNUC__) #pragma GCC diagnostic push @@ -124,4 +124,4 @@ void DispatchBlockSize(const MemoryEfficientAttentionParams& params) { #pragma GCC diagnostic pop #endif -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu index 237f7ea8c9c42..540a2699587eb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu index 941ea87baa398..005425c56e0ae 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu index 5a0e7c9ed5b7a..955423b6c6762 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu index d0775a29c4cf1..0b54d90c4da30 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h" @@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu index 284211f96514d..750cace39ae39 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" @@ -27,4 +27,4 @@ void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h index 326ff451e600a..f725be8d7cf89 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h @@ -2,7 +2,7 @@ // Licensed under the MIT License. #pragma once -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/attention_common.h" @@ -58,4 +58,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p } // namespace contrib } // namespace onnxruntime -#endif // USE_FLASH_ATTENTION +#endif // USE_MEMORY_EFFICIENT_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h new file mode 100644 index 0000000000000..9db98061bbd66 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h @@ -0,0 +1,40 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +namespace onnxruntime { +namespace flash { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + template + __device__ BlockInfo(const Params& params, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), + sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]), + actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q), + actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) { + } + + template + inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h new file mode 100644 index 0000000000000..9394a19c9897a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h @@ -0,0 +1,85 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include +#include + +namespace onnxruntime { +namespace flash { + +constexpr int TOTAL_DIM = 0; +constexpr int H_DIM = 1; +constexpr int D_DIM = 2; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = uint32_t; + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + // The O matrix (output). + void* __restrict__ o_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void* __restrict__ p_ptr; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + + int* __restrict__ blockmask; + + bool is_bf16 = false; + bool is_causal; + + const cudaDeviceProp* dprops; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream); + +} // namespace flash +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc new file mode 100644 index 0000000000000..87831d1eddfe9 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -0,0 +1,198 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cuda/bert/flash_attention/flash.h" +#include "contrib_ops/cuda/bert/flash_attention/static_switch.h" + +namespace onnxruntime { +namespace flash { + +void set_params_fprop(Flash_fwd_params& params, + // sizes + size_t batch_size, + size_t seqlen_q, + size_t seqlen_k, + size_t seqlen_q_rounded, + size_t seqlen_k_rounded, + size_t num_heads, + size_t num_heads_k, + size_t head_size, + size_t head_size_rounded, + // device pointers + void* q, + void* k, + void* v, + void* out, + void* cu_seqlens_q_d, + void* cu_seqlens_k_d, + void* p_d, + void* softmax_lse_d, + float softmax_scale, + bool is_causal) { + // Set the pointers and strides. + params.q_ptr = q; + params.k_ptr = k; + params.v_ptr = v; + params.o_ptr = out; + + // All stride are in elements, not bytes. + params.q_row_stride = num_heads * head_size; + params.k_row_stride = num_heads_k * head_size; + params.v_row_stride = num_heads * head_size; + params.q_head_stride = head_size; + params.k_head_stride = head_size; + params.v_head_stride = head_size; + params.o_row_stride = num_heads * head_size; + params.o_head_stride = head_size; + params.is_bf16 = false; + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0) + params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0) + } else { + params.q_batch_stride = 0; + params.k_batch_stride = 0; + params.v_batch_stride = 0; + params.o_batch_stride = 0; + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.h_h_k_ratio = num_heads / num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + + params.is_causal = is_causal; +} + +size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) { + size_t bytes = sizeof(float) * batch_size * num_heads * seqlen; + return bytes; +} + +void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) { + FP16_SWITCH(!params.is_bf16, [&] { + FWD_HEADDIM_SWITCH(params.d, [&] { + run_mha_fwd_(params, stream); + }); + }); +} + +Status mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + Flash_fwd_params params; + params.dprops = &dprops; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + /*cu_seqlens_q*/ nullptr, + /*cu_seqlens_k*/ nullptr, + nullptr, + softmax_lse, + softmax_scale, + is_causal); + + run_mha_fwd(params, stream); + return Status::OK(); +} + +Status mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, head_size) + void* out, // half (total_q, num_heads, head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal) { + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + Flash_fwd_params params; + params.dprops = &dprops; + set_params_fprop(params, + batch_size, + max_seqlen_q, max_seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + cu_seqlens_q, + cu_seqlens_k, + nullptr, + softmax_lse, + softmax_scale, + is_causal); + run_mha_fwd(params, stream); + return Status::OK(); +} + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) { + bool is_sm8x = dprops.major == 8 && dprops.minor >= 0; + bool is_sm90 = dprops.major == 9 && dprops.minor == 0; + return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0); +} + +} // namespace flash +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h new file mode 100644 index 0000000000000..2ae46d34c373a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -0,0 +1,78 @@ +/****************************************************************************** + * Copyright (c) 2022, Tri Dao. + * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#if USE_FLASH_ATTENTION + +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace flash { +Status mha_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // batch_size x seqlen_q x num_heads x head_size + void* k, // batch_size x seqlen_k x num_heads_k x head_size + void* v, // batch_size x seqlen_k x num_heads_k x head_size + void* out, // batch_size x seqlen_q x num_heads x head_size + void* softmax_lse, // batch_size x num_heads x seqlen_q + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int seqlen_q, + int seqlen_k, + float softmax_scale, + bool is_causal); + +Status mha_varlen_fwd(const cudaDeviceProp& dprops, + cudaStream_t stream, + void* q, // half (total_q, num_heads, head_size) + void* k, // half (total_k, num_heads, head_size) + void* v, // half (total_k, num_heads, v_head_size) + void* out, // half (total_q, num_heads, v_head_size) + int* cu_seqlens_q, // int (batch_size + 1) + int* cu_seqlens_k, // int (batch_size + 1) + void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q) + int batch_size, + int num_heads, + int num_heads_k, + int head_size, + int max_seqlen_q, + int max_seqlen_k, + float softmax_scale, + bool is_causal); + +size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); + +bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k); + +} // namespace flash +} // namespace onnxruntime + +#endif // USE_FLASH_ATTENTION diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu new file mode 100644 index 0000000000000..44ea92e58c86e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu new file mode 100644 index 0000000000000..a2bf16bc74e72 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu new file mode 100644 index 0000000000000..56fc04126ab12 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu new file mode 100644 index 0000000000000..6fb24640710a3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu new file mode 100644 index 0000000000000..94d51e922d7cb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu new file mode 100644 index 0000000000000..d32eec27634ce --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu new file mode 100644 index 0000000000000..65a2e42192532 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu new file mode 100644 index 0000000000000..f37ee5005855a --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu @@ -0,0 +1,18 @@ +// Copyright (c) 2023, Tri Dao. + +// Splitting the different head dimensions to different files to speed up compilation. +#if USE_FLASH_ATTENTION + +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h" + +namespace onnxruntime { +namespace flash { + +template <> +void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} + +} // namespace flash +} // namespace onnxruntime +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h new file mode 100644 index 0000000000000..b5af31e432d42 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h @@ -0,0 +1,532 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#if defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wunused-variable" +#pragma GCC diagnostic ignored "-Wunused-but-set-variable" +#endif + +#include +#include +#include + +#include +#include +#include +#include + +#include "contrib_ops/cuda/bert/flash_attention/block_info.h" +#include "contrib_ops/cuda/bert/flash_attention/kernel_traits.h" +#include "contrib_ops/cuda/bert/flash_attention/utils.h" +#include "contrib_ops/cuda/bert/flash_attention/softmax.h" + +namespace onnxruntime { +namespace flash { +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE auto +make_tiled_copy_A_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(cute::Layout, cute::Int>, + cute::Stride<_1, cute::Int>>{}, + make_layout(cute::size<2>(TileShape_MNK{}))); + + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE auto +make_tiled_copy_C_warpcontiguousM(Copy_Atom const& copy_atom, + TiledMMA const& tiled_mma) { + using TileShape_MNK = typename TiledMMA::TiledShape_MNK; + using AtomShape_MNK = typename TiledMMA::AtomShape_MNK; + constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value; + constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M; + constexpr int MMAStride_M = MMA_M * AtomShape_M; + auto t = make_tile(cute::Layout, cute::Int>, + cute::Stride<_1, cute::Int>>{}, + // TODO: Shouldn't this be size<1>? + make_layout(cute::size<2>(TileShape_MNK{}))); + // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); } + return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum, + Tensor2& acc_o, float softmax_scale_log2) { + if (Is_first) { + flash::template reduce_max(scores, scores_max); + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + flash::reduce_sum(scores, scores_sum); + } else { + cute::Tensor scores_max_prev = make_fragment_like(scores_max); + copy(scores_max, scores_max_prev); + flash::template reduce_max(scores, scores_max); + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_max); ++mi) { + float scores_max_cur = !Check_inf + ? scores_max(mi) + : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi)); + float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + scores_sum(mi) *= scores_scale; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scores_scale; + } + } + flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2); + cute::Tensor scores_sum_cur = make_fragment_like(scores_sum); + flash::reduce_sum(scores, scores_sum_cur); +#pragma unroll + for (int mi = 0; mi < cute::size(scores_sum); ++mi) { + scores_sum(mi) += scores_sum_cur(mi); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void write_softmax_to_gmem( + cute::Tensor const& tOrP, cute::Tensor& tPgP, TiledCopy gmem_thr_copy_P) { + // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N) + cute::Layout l = tOrP.layout(); + cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l)))); + CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{}); + CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP)); +#pragma unroll + for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) { + copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + constexpr int kNWarps = Kernel_traits::kNWarps; + constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return; + + int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN); + if (Is_causal) { + n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN)); + } + + // We iterate over the blocks in reverse order. This is because the last block is the only one + // that needs masking when we read K and V from global memory. Moreover, iterating in reverse + // might save us 1 register (we just need n_block instead of both n_block and n_block_max). + + const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride; + // We move K and V to the last block. + const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride; + const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride; + const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN; + + cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), + cute::Shape, cute::Int>{}, + make_stride(params.q_row_stride, _1{})); + cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr) + row_offset_k), + cute::Shape, cute::Int>{}, + make_stride(params.k_row_stride, _1{})); + cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr) + row_offset_v), + cute::Shape, cute::Int>{}, + make_stride(params.v_row_stride, _1{})); + cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast(params.p_ptr) + row_offset_p), + cute::Shape, cute::Int>{}, + make_stride(params.seqlen_k_rounded, _1{})); + + cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutQ{}); + // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem; + cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)), + typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{}); + cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); + cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + + typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; + auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); + typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P; + auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx); + + cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ); + cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ); + cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K) + cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK); + cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K) + cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV); + cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP); + + typename Kernel_traits::TiledMma tiled_mma; + auto thr_mma = tiled_mma.get_thread_slice(tidx); + cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K) + cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K) + cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N) + + cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // MMA, MMA_M, MMA_K + + // + // Copy Atom retiling + // + + auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx); + cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ); + + auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma); + auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx); + cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK); + + auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma); + auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx); + cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt); + + // TODO: this might need to change if we change the mma instruction in SM70 + cute::Tensor scores_max = make_tensor(cute::Shape(acc_o)>>{}); + cute::Tensor scores_sum = make_fragment_like(scores_max); + + // + // PREDICATES + // + + // Construct identity layout for sQ and sK + cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Allocate predicate tensors for k + cute::Tensor tQpQ = make_tensor(make_shape(cute::size<2>(tQsQ))); + cute::Tensor tKVpKV = make_tensor(make_shape(cute::size<2>(tKsK))); + + // Set predicates for k bounds + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tQpQ); ++k) { + tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; + } +#pragma unroll + for (int k = 0; k < cute::size(tKVpKV); ++k) { + tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; + } + } + + // Prologue + + cute::Tensor tQrQ = make_fragment_like(tQgQ); + // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs + flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, + binfo.actual_seqlen_q - m_block * kBlockM); + if (Kernel_traits::Is_Q_in_regs) { + cute::cp_async_fence(); + } + + if (Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<0>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + __syncthreads(); + } + + int n_block = n_block_max - 1; + // We don't need to clear the sK smem tiles since we'll mask out the scores anyway. + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, + binfo.actual_seqlen_k - n_block * kBlockN); + cute::cp_async_fence(); + + if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) { + flash::cp_async_wait<1>(); + __syncthreads(); + cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ); + CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M + cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view); + } + + clear(acc_o); + + // For performance reason, we separate out two kinds of iterations: + // those that need masking on S, and those that don't. + // We need masking on S for the very last block when K and V has length not multiple of kBlockN. + // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks. + // We will have at least 1 "masking" iteration. + + // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to + // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1. + constexpr int n_masking_steps = !Is_causal + ? 1 + : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1); +#pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + + // Advance gV + if (masking_step > 0) { + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + } else { + // Clear the smem tiles to account for predicated off loads + flash::copy( + gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN); + } + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + // if (cute::thread0()) { print(acc_s); } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + + // We don't put the masking before the matmul S = Q K^T because we don't clear sK + // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul + // can produce Inf / NaN. + if (!Is_causal) { + if (!Is_even_MN) { + flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); + } + } else { + // I can't get the stride from idx_row + flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k, + // m_block * kBlockM + get<0>(idx_row(0)), + m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, + binfo.actual_seqlen_q, + kNWarps * 16); + } + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // TODO: when we have key_padding_mask we'll need to Check_inf + masking_step == 0 + ? softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2) + : softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + // Convert scores from fp32 to fp16/bf16 + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + + // This check is at the end of the loop since we always have at least 1 iteration + if (n_masking_steps > 1 && n_block <= 0) { + --n_block; + break; + } + } + + // These are the iterations where we don't need masking on S + for (; n_block >= 0; --n_block) { + cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape, cute::Int>{}); // (MMA=4, MMA_M, MMA_N) + clear(acc_s); + flash::cp_async_wait<0>(); + __syncthreads(); + // Advance gV + tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV); + cute::cp_async_fence(); + + flash::gemm( + acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, + smem_thr_copy_Q, smem_thr_copy_K); + + flash::cp_async_wait<0>(); + __syncthreads(); + if (n_block > 0) { + // Advance gK + tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride)); + flash::copy(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV); + // This cp_async_fence needs to be in the if block, otherwise the synchronization + // isn't right and we get race conditions. + cute::cp_async_fence(); + } + + // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) + cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + softmax_rescale_o(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2); + + cute::Tensor rP = flash::convert_type(scores); + // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) + // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. + cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs(rP.layout())); + // if (Return_softmax) { + // cute::Tensor tOrP_copy = make_fragment_like(tOrP); + // copy(tOrP, tOrP_copy); + // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P); + // tPgP.data() = tPgP.data() + (-kBlockN); + // } + + flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V); + } + + // Epilogue + + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + cute::Tensor lse = make_fragment_like(scores_sum); +#pragma unroll + for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) { + float sum = scores_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum; + lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum); + float scale = inv_sum; +#pragma unroll + for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) { + acc_o_rowcol(mi, ni) *= scale; + } + } + + // Convert acc_o from fp32 to fp16/bf16 + cute::Tensor rO = flash::convert_type(acc_o); + cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N) + // Partition sO to match the accumulator partitioning + auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx); + cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N) + cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // sO has the same size as sQ, so we don't need to sync here. + if (Kernel_traits::Share_Q_K_smem) { + __syncthreads(); + } + + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + cute::Shape, cute::Int>{}, + make_stride(params.o_row_stride, _1{})); + cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), + cute::Shape>{}, cute::Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); + cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N) + cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + + __syncthreads(); + + cute::Tensor tOrO = make_tensor(cute::shape(tOgO)); + cute::copy(gmem_tiled_copy_O, tOsO, tOrO); + + cute::Tensor caccO = make_identity_tensor(cute::Shape, cute::Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(cute::size<0>(taccOcO))::value == 4); + // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices. + cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0); + CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(0)) == 0) { +#pragma unroll + for (int mi = 0; mi < cute::size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < binfo.actual_seqlen_q - m_block * kBlockM) { + gLSE(row) = lse(mi); + } + } + } + + // Construct identity layout for sO + cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + cute::Tensor tOpO = make_tensor(make_shape(cute::size<2>(tOgO))); + if (!Is_even_K) { +#pragma unroll + for (int k = 0; k < cute::size(tOpO); ++k) { + tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; + } + } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void compute_attn(const Params& params) { + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + + // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting + // them to have the same number of threads or have to traverse the attention matrix + // in the same order. + // In the Philox RNG, we use the offset to store the batch, head, and the lane id + // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within + // the attention matrix. This way, as long as we have the batch, head, and the location of + // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. + + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace flash +} // namespace onnxruntime + +#if defined(__GNUC__) +#pragma GCC diagnostic pop +#endif diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h new file mode 100644 index 0000000000000..e633ef4d45fbb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -0,0 +1,210 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include "contrib_ops/cuda/bert/flash_attention/static_switch.h" +#include "contrib_ops/cuda/bert/flash_attention/flash.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h" + +namespace onnxruntime { +namespace flash { + +template +__global__ void flash_fwd_kernel(Flash_fwd_params params) { + flash::compute_attn(params); +} + +template +void run_flash_fwd(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t smem_size = Kernel_traits::kSmemSize; + + // Work-around for gcc 7. It doesn't like nested BOOL_SWITCH. + // https://github.com/kokkos/kokkos-kernels/issues/349 + // https://github.com/HazyResearch/flash-attention/issues/21 + + const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; + dim3 grid(num_m_block, params.b, params.h); + // We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check + // for cu_seqlens_q as well. + const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; + const bool is_even_K = params.d == Kernel_traits::kHeadDim; + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + // ORT_ENFORCE(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + }); + }); +} + +template +void run_mha_fwd_hdim32(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 32; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim96(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 96; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 128; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + }); +} + +template +void run_mha_fwd_hdim160(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 160; + const bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. + if (is_sm8x) { + if constexpr (!Is_causal) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim192(Flash_fwd_params& params, cudaStream_t stream) { + constexpr int Headdim = 192; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + }); +} + +template +void run_mha_fwd_hdim224(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t Headdim = 224; + constexpr size_t threshold = 2 * Headdim * (128 + 2 * 64); + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_block = %d\n", max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (max_smem_per_block >= threshold) { // 112 KB + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // run_flash_fwd, Is_causal>(params, stream); + // run_flash_fwd, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params& params, cudaStream_t stream) { + constexpr size_t Headdim = 256; + constexpr size_t min_threshold = 2 * Headdim * (128 + 2 * 64); + constexpr size_t max_threshold = 4 * Headdim * (64 + 2 * 64); + size_t max_smem_per_sm = params.dprops->sharedMemPerMultiprocessor; + size_t max_smem_per_block = params.dprops->sharedMemPerBlockOptin; + // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= min_threshold && max_smem_per_sm < max_threshold) { + run_flash_fwd, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_causal>(params, stream); + }); +} + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h new file mode 100644 index 0000000000000..0c967faa85c45 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/kernel_traits.h @@ -0,0 +1,351 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include +#include +#include + +using namespace cute; + +namespace onnxruntime { +namespace flash { + +template +struct Flash_kernel_traits { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using Element = elem_type; + static constexpr bool Has_cp_async = true; +#else + using Element = cutlass::half_t; + static constexpr bool Has_cp_async = false; +#endif + + using ElementAccum = float; + using index_t = uint32_t; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + using MMA_Atom_Arch = std::conditional_t< + std::is_same_v, + MMA_Atom, + MMA_Atom>; + using ValLayoutMNK = cute::Layout>; +#else + using MMA_Atom_Arch = MMA_Atom; + using ValLayoutMNK = cute::Layout>; +#endif + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#else + using SmemCopyAtom = Copy_Atom; + using SmemCopyAtomTransposed = Copy_Atom; +#endif +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template > +struct Flash_fwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Share_Q_K_smem = Share_Q_K_smem_; + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + using TiledMma = TiledMMA< + typename Base::MMA_Atom_Arch, + Layout, _1, _1>>, // 4x1x1 or 8x1x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQ = decltype(composition(Swizzle{}, + // This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128 + Layout>, + Stride, _1>>{})); + using SmemLayoutQ = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + using SmemLayoutKV = decltype(tile_to_shape( + SmemLayoutAtomQ{}, + Shape, Int>{})); + + // This has to be kBlockN and not 8, otherwise we get wrong results for d=128 + using SmemLayoutAtomVtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomVtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomVtransposedNoSwizzle{})); + using SmemLayoutVtransposed = decltype(tile_to_shape( + SmemLayoutAtomVtransposed{}, + Shape, Int>{})); + // Maybe the VtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomVtransposedNoSwizzle{}, + Shape, Int>{})); + // using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn()); + + using SmemLayoutAtomO = decltype(composition(Swizzle{}, + Layout, Int>, + Stride, _1>>{})); + using SmemLayoutO = decltype(tile_to_shape( + SmemLayoutAtomO{}, + Shape, Int>{})); + using SmemCopyAtomO = Copy_Atom; + + static constexpr int kSmemQCount = cute::size(SmemLayoutQ{}); + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemQSize = kSmemQCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts. + // For example, for d=128, smem is split into 2 "pages", each page takes care of columns + // 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem, + // thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page, + // to the same banks. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopyO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP"); + using GmemLayoutAtomP = Layout, Int>, + Stride, _1>>; + + using GmemTiledCopyP = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomP{}, + Layout>{})); // Val layout, 8 vals per store +}; + +// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue. +// No_double_buffer is another option to reduce smem usage, but will slow things down. +template > +struct Flash_bwd_kernel_traits : public Base { + using Element = typename Base::Element; + using ElementAccum = typename Base::ElementAccum; + using index_t = typename Base::index_t; + static constexpr bool Has_cp_async = Base::Has_cp_async; + using SmemCopyAtom = typename Base::SmemCopyAtom; + using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed; + + static constexpr bool Is_V_in_regs = Is_V_in_regs_; + static constexpr bool No_double_buffer = No_double_buffer_; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * 32; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32); + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + + static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_; + static_assert(kNWarps % AtomLayoutMSdP == 0); + static_assert(kNWarps % AtomLayoutNdKV == 0); + static_assert(kNWarps % AtomLayoutMdQ == 0); + + using TiledMmaSdP = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadKV = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using TiledMmadQ = TiledMMA< + typename Base::MMA_Atom_Arch, + cute::Layout, cute::Int, _1>>, // 2x4x1 or 4x2x1 thread group + typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM + + using SmemLayoutAtomQdO = decltype(composition(Swizzle{}, + cute::Layout>, + cute::Stride, _1>>{})); + using SmemLayoutQdO = decltype(tile_to_shape( + SmemLayoutAtomQdO{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKV = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutKV = decltype(tile_to_shape( + // SmemLayoutAtomQdO{}, + SmemLayoutAtomKV{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + + using SmemLayoutAtomKtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomKtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomKtransposedNoSwizzle{})); + using SmemLayoutKtransposed = decltype(tile_to_shape( + SmemLayoutAtomKtransposed{}, + make_shape(Int{}, Int{}))); + // Maybe the KtransposeNoSwizzle just needs to have the right shape + // And the strides don't matter? + using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomKtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn()); + + // TODO: generalize to other values of kBlockN + // TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2 + // static constexpr int kPBlockN = kBlockN; + static_assert(kBlockN >= 64); + // TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest. + static constexpr int kPBlockN = 64; + static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64); + // static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3); + static constexpr int kSwizzlePdS = 3; + using SmemLayoutAtomPdS = decltype(composition(Swizzle{}, + cute::Layout, cute::Int>, + cute::Stride, _1>>{})); + using SmemLayoutPdS = decltype(tile_to_shape( + SmemLayoutAtomPdS{}, + cute::make_shape(cute::Int{}, cute::Int{}))); + using SmemLayoutAtomPdStransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomPdStransposed = decltype(composition(Swizzle{}, SmemLayoutAtomPdStransposedNoSwizzle{})); + using SmemLayoutPdStransposed = decltype(tile_to_shape( + SmemLayoutAtomPdStransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomPdStransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn()); + using SmemCopyAtomPdS = Copy_Atom; + + using SmemLayoutAtomQdOtransposedNoSwizzle = Layout, Int>, + Stride<_1, Int>>; + using SmemLayoutAtomQdOtransposed = decltype(composition(Swizzle{}, SmemLayoutAtomQdOtransposedNoSwizzle{})); + using SmemLayoutQdOtransposed = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposed{}, + make_shape(Int{}, Int{}))); + using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape( + SmemLayoutAtomQdOtransposedNoSwizzle{}, + make_shape(Int{}, Int{}))); + // using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn()); + + using SmemLayoutAtomdKV = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdKV = Copy_Atom; + + using SmemLayoutAtomdQ = decltype(composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemCopyAtomdQ = Copy_Atom; + + static constexpr int kSmemQdOCount = cute::size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ + static constexpr int kSmemKVCount = cute::size(SmemLayoutKV{}) * 2; + static constexpr int kSmemdSCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemPCount = cute::size(SmemLayoutPdS{}); + static constexpr int kSmemdQCount = cute::size(SmemLayoutdQ{}); + static constexpr int kSmemdPsumCount = kBlockM; + static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element); + static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element); + static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element); + static constexpr int kSmemPSize = kSmemPCount * sizeof(Element); + static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element); + static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum); + static constexpr int kSmemSize = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize) + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize))); + static constexpr int kSmemSize1colblock = kSmemQdOSize + (!Is_V_in_regs + ? kSmemKVSize + kSmemdSSize + kSmemPSize + : std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize)); + static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3 + kSmemdSSize + kSmemPSize; + + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = cute::Layout, cute::Int>, + cute::Stride, _1>>; + + // We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading + // from the same address by the same threadblock. This is slightly faster. + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy>; + using GmemTiledCopyQKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per read + using GmemTiledCopydO = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydKV = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + cute::Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_8, _1>>, + cute::Layout, // Thread layout, 16 threads per row + cute::Stride<_16, _1>>>; + using GmemTiledCopydQaccum = decltype(make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + cute::Layout>{})); // Val layout, 4 vals per store + + using GmemTiledCopydQaccumAtomicAdd = decltype(make_tiled_copy(Copy_Atom{}, + cute::Layout, // Thread layout, 8 threads per row + cute::Stride<_32, _1>>{}, + cute::Layout>{})); // Val layout, 1 val per store +}; + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h new file mode 100644 index 0000000000000..842edf3a98a86 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/softmax.h @@ -0,0 +1,206 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include + +#include + +#include +#include + +#include "contrib_ops/cuda/bert/flash_attention/utils.h" + +namespace onnxruntime { +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ inline void thread_reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ inline void quad_allreduce_(Tensor& dst, Tensor& src, Operator& op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); +#pragma unroll + for (int i = 0; i < size(dst); i++) { + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ inline void reduce_(Tensor const& tensor, Tensor& summary, Operator& op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ inline void reduce_max(Tensor const& tensor, Tensor& max) { + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ inline void reduce_sum(Tensor const& tensor, Tensor& sum) { + SumOp sum_op; + reduce_(tensor, sum, sum_op); +} + +// Apply the exp to all the elements. +template +inline __device__ void scale_apply_exp2(Tensor& tensor, Tensor const& max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E)); +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +// Apply the exp to all the elements. +template +inline __device__ void max_scale_exp2_sum(Tensor& tensor, Tensor& max, Tensor& sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); +#pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; +#pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + SumOp sum_op; + sum(mi) = Allreduce<4>::run(sum(mi), sum_op); + } +} + +template +inline __device__ void apply_mask(Tensor& tensor, const int max_seqlen_k, + const int col_idx_offset_ = 0) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= max_seqlen_k) { +// Without the "make_coord" we get wrong results +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + tensor(mi, make_coord(j, nj)) = -INFINITY; + } + } + } + } +} + +template +inline __device__ void apply_mask_causal(Tensor& tensor, const int col_idx_offset_, + const int max_seqlen_k, const int row_idx_offset_, + const int max_seqlen_q, const int warp_row_stride) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout::rank == 2, "Only support 2D Tensor"); + const int lane_id = threadIdx.x % 32; + // const int row_idx_offset = row_idx_offset_ + lane_id / 4; + const int row_idx_offset = row_idx_offset_; + const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; +#pragma unroll + for (int mi = 0; mi < size<0, 1>(tensor); ++mi) { + const int row_idx_base = row_idx_offset + mi * warp_row_stride; +#pragma unroll + for (int i = 0; i < size<0, 0>(tensor); ++i) { + const int row_idx = row_idx_base + i * 8; + const int col_idx_limit = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q); +#pragma unroll + for (int nj = 0; nj < size<1, 1>(tensor); ++nj) { + const int col_idx_base = col_idx_offset + nj * 8; +#pragma unroll + for (int j = 0; j < size<1, 0>(tensor); ++j) { + const int col_idx = col_idx_base + j; + if (col_idx >= col_idx_limit) { + tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY; + } + } + } + // if (cute::thread0()) { + // printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k); + // print(tensor(make_coord(i, mi), _)); + // // print(tensor(_, j + nj * size<1, 0>(tensor))); + // } + } + } +} + +template +inline __device__ void apply_mask_causal_w_idx( + Tensor& tensor, Tensor const& idx_rowcol, + const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_) { + // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 2, "Only support 2D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); + CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol)); +#pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0))); +#pragma unroll + for (int ni = 0; ni < size<1, 1>(tensor); ++ni) { + if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) { + tensor(mi, ni) = -INFINITY; + } + } + // if (cute::thread0()) { + // printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k); + // print(tensor(_, make_coord(j, ni))); + // // print(tensor(_, j + ni * size<1, 0>(tensor))); + // } + } +} + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h new file mode 100644 index 0000000000000..05ac2476690c2 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/static_switch.h @@ -0,0 +1,60 @@ +// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define FP16_SWITCH(COND, ...) \ + [&] { \ + assert(COND); \ + using elem_type = cutlass::half_t; \ + return __VA_ARGS__(); \ + }() + +#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM <= 32) { \ + constexpr static int kHeadDim = 32; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 64) { \ + constexpr static int kHeadDim = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 96) { \ + constexpr static int kHeadDim = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 128) { \ + constexpr static int kHeadDim = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 160) { \ + constexpr static int kHeadDim = 160; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 192) { \ + constexpr static int kHeadDim = 192; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 224) { \ + constexpr static int kHeadDim = 224; \ + return __VA_ARGS__(); \ + } else if (HEADDIM <= 256) { \ + constexpr static int kHeadDim = 256; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h new file mode 100644 index 0000000000000..49ee687419d0e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/utils.h @@ -0,0 +1,371 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include +#endif + +#include +#include + +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// +namespace onnxruntime { +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ uint32_t relu2(const uint32_t x); + +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("max.f16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#else + asm volatile( + "{\n" + "\t .reg .f16x2 sela;\n" + "\t set.gtu.u32.f16x2 sela, %1, %2;\n" + "\t and.b32 %0, sela, %1;\n" + "}\n" + : "=r"(res) + : "r"(x), "r"(zero)); +#endif + return res; +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ uint32_t relu2(const uint32_t x) { + uint32_t res; + const uint32_t zero = 0u; + asm volatile("max.bf16x2 %0, %1, %2;\n" + : "=r"(res) + : "r"(x), "r"(zero)); + return res; +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +template +inline __device__ uint32_t convert_relu2(const float2 x); + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +template <> +inline __device__ uint32_t convert_relu2(const float2 x) { + uint32_t res; + const uint32_t a = reinterpret_cast(x.x); + const uint32_t b = reinterpret_cast(x.y); + asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" + : "=r"(res) + : "r"(b), "r"(a)); + return res; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float2 half2_unpack(uint32_t a); + +template <> +inline __device__ float2 half2_unpack<__half>(uint32_t a) { + return __half22float2(reinterpret_cast<__half2(&)>(a)); +} + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +template <> +inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) { + return __bfloat1622float2(reinterpret_cast<__nv_bfloat162(&)>(a)); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert two half2's or bf162's into float, then take their dot product. +template +inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) { + float2 af = flash::half2_unpack(a); + float2 bf = flash::half2_unpack(b); + return af.x * bf.x + af.y * bf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Converted two vectors of 8 half's or bf16's into float, then take their dot product. +template +inline __device__ float hmulsum8(const uint4 a, const uint4 b) { + float sum; + sum = flash::hfma2_to_float(a.x, b.x); + sum += flash::hfma2_to_float(a.y, b.y); + sum += flash::hfma2_to_float(a.z, b.z); + sum += flash::hfma2_to_float(a.w, b.w); + return sum; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MaxOp { + __device__ inline T operator()(T const& x, T const& y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { + // This is slightly faster + __device__ inline float operator()(float const& x, float const& y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { + __device__ inline T operator()(T const& x, T const& y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator& op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct Allreduce<2> { + template + static __device__ inline T run(T x, Operator& op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsA, + Tensor4 const& tCsB, TiledMma tiled_mma, + TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B, + ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); + } +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + if (!A_in_regs) { + cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); + } + if (!B_in_regs) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void gemm_A_in_regs(Tensor0& acc, Tensor1& tCrA, Tensor2& tCrB, Tensor3 const& tCsB, + TiledMma tiled_mma, TiledCopy smem_tiled_copy_B, + ThrCopy smem_thr_copy_B) { + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K + Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N + cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); +#pragma unroll + for (int i = 0; i < size<2>(tCrA); ++i) { + if (i < size<2>(tCrA) - 1) { + cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); + } + cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +template +inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2) +// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8. +template +inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) { + using X = Underscore; + static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2); + static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2; + auto l = logical_divide(rowcol_layout, Shape>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2))) + return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)), + get<0, 1>(l), + get<1, 1, 1>(l)); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ auto convert_type(Tensor const& tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast*>(tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void relu_(Tensor& tensor) { + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); + using value_t = typename Engine::value_type; + // HACK: this requires tensor to be "contiguous" + Tensor tensor_uint32 = recast(tensor); +#pragma unroll + for (int i = 0; i < size(tensor_uint32); ++i) { + tensor_uint32(i) = relu2(tensor_uint32(i)); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction +template +inline __device__ auto convert_type_relu(Tensor const& tensor) { + using From_type = typename Engine::value_type; + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + constexpr int numel = decltype(size(tensor))::value; + static_assert(numel % 2 == 0); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + // HACK: this requires tensor to be "contiguous" + Tensor tensor_float2 = recast(tensor); + Tensor out_uint32 = make_tensor(tensor_float2.layout()); +#pragma unroll + for (int i = 0; i < size(out_uint32); ++i) { + out_uint32(i) = convert_relu2(tensor_float2(i)); + } + Tensor out = make_tensor(make_rmem_ptr(out_uint32.data()), tensor.layout()); +#else + Tensor out = flash::convert_type(tensor); + flash::relu_(out); +#endif + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Blocks until all but N previous cp.async.commit_group operations have committed. +// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all +// (which is equivalent to commit_group then wait_group 0). +// Instead we just call cp.async.wait_group 0, which is slightly faster. +// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113 +template +CUTE_HOST_DEVICE void cp_async_wait() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void copy(TiledCopy thr_copy, Tensor const& S, + Tensor& D, Tensor const& identity_MN, + Tensor const& predicate_K, int max_MN = 0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); +#pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { +#pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + copy(thr_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index 15f0bc1a746d3..8f1252f863ef6 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -7,6 +7,7 @@ #include "contrib_ops/cuda/bert/multihead_attention.h" #include "contrib_ops/cpu/bert/multihead_attention_helper.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -51,6 +52,17 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false); #if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || + ParseEnvironmentVariableWithDefault(attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false); #else disable_memory_efficient_attention_ = true; @@ -118,9 +130,35 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { int sm = device_prop.major * 10 + device_prop.minor; bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; - bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; - bool use_fused_cross_attention = !disable_fused_cross_attention_ && + const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); + +#if USE_FLASH_ATTENTION || USE_MEMORY_EFFICIENT_ATTENTION + // Exclude this case since PrepareQkv will convert the format to BNSH. + bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; +#endif + +#if USE_FLASH_ATTENTION + bool use_flash_attention = !disable_flash_attention_ && + !past_no_bias && + nullptr == relative_position_bias && + nullptr == key_padding_mask && + parameters.head_size == parameters.v_head_size && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + // When input is packed QKV format, TensorRT kernel might be faster than flash attention when sequence length <= 512. + if (use_flash_attention && key == nullptr && value == nullptr && + parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } +#else + constexpr bool use_flash_attention = false; +#endif + + bool use_fused_cross_attention = !use_flash_attention && + !disable_fused_cross_attention_ && nullptr == key_padding_mask && nullptr == relative_position_bias && (nullptr == past_key && nullptr == past_value && !parameters.pass_past_in_kv) && @@ -141,7 +179,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - bool use_fused_runner = !disable_fused_self_attention_ && + bool use_fused_runner = !use_flash_attention && + !disable_fused_self_attention_ && fused_cross_attention_kernel == nullptr && nullptr == relative_position_bias && (value != nullptr || key == nullptr) && @@ -166,32 +205,30 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { } } - const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value); - -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16 - parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || - parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32; - - // Exclude this case since PrepareQkv will convert the format to BNSH. - bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr; + parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32 || + parameters.kv_sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32; bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0; - bool use_memory_efficient_attention = fused_runner == nullptr && + bool use_memory_efficient_attention = !use_flash_attention && + fused_runner == nullptr && fused_cross_attention_kernel == nullptr && !disable_memory_efficient_attention_ && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && is_long_sequence && !past_no_bias && (relative_position_bias == nullptr || is_good_for_rpb) && - (nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) && + (nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) && has_memory_efficient_attention(sm, sizeof(T) == 2); #else constexpr bool use_memory_efficient_attention = false; - ORT_UNUSED_PARAMETER(is_mask_1d_key_seq_len_start); #endif // When packed kv or packed qkv is used, there is no needed for add bias transpose thus no qkv workspace. + // TODO(tianleiwu): flash attention or memory efficient attention might not need qkv workspace sometime. bool no_qkv_workspace = nullptr == value && (use_fused_cross_attention || (nullptr != fused_runner && nullptr == key)) && nullptr == key_padding_mask && @@ -211,6 +248,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); } @@ -219,8 +257,9 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { const size_t past_k_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.head_size; const size_t past_v_bytes = element_size * parameters.batch_size * parameters.kv_sequence_length * parameters.num_heads * parameters.v_head_size; - auto temp_k_work_space = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; - auto temp_v_work_space = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; + const bool use_temp_k_v_workspace = parameters.pass_past_in_kv || use_memory_efficient_attention || use_flash_attention; + auto temp_k_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_k_bytes, context->GetComputeStream()) : nullptr; + auto temp_v_work_space = use_temp_k_v_workspace ? GetScratchBuffer(past_v_bytes, context->GetComputeStream()) : nullptr; typedef typename ToCudaType::MappedType CudaT; AttentionData data; @@ -241,14 +280,15 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); data.has_qkv_workspace = !no_qkv_workspace; data.workspace = reinterpret_cast(work_space.get()); - data.temp_k_workspace = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? reinterpret_cast(temp_k_work_space.get()) : nullptr; - data.temp_v_workspace = (parameters.pass_past_in_kv || use_memory_efficient_attention) ? reinterpret_cast(temp_v_work_space.get()) : nullptr; + data.temp_k_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_k_work_space.get()) : nullptr; + data.temp_v_workspace = use_temp_k_v_workspace ? reinterpret_cast(temp_v_work_space.get()) : nullptr; data.output = reinterpret_cast(output->MutableData()); data.present = nullptr; data.present_key = (nullptr == present_key) ? nullptr : reinterpret_cast(present_key->MutableData()); data.present_value = (nullptr == present_value) ? nullptr : reinterpret_cast(present_value->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); data.fused_cross_attention_kernel = fused_cross_attention_kernel; + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = &(this->cumulated_sequence_length_q_cache_); data.cumulated_sequence_length_kv_cache = &(this->cumulated_sequence_length_kv_cache_); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index af5045e70d3b4..33fa3d50e4564 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -28,7 +28,9 @@ class MultiHeadAttention final : public CudaKernel { bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; + bool disable_flash_attention_; bool disable_memory_efficient_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; mutable std::unique_ptr fused_fp16_runner_; mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_; mutable CumulatedSequenceLengthCache cumulated_sequence_length_q_cache_; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc index 1b2c5f6200839..ec8b1d051b3d9 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention.cc @@ -283,7 +283,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); bool use_memory_efficient_attention = false; -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (nullptr == fused_runner) { int sm = device_prop.major * 10 + device_prop.minor; bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; @@ -324,6 +324,7 @@ Status PackedAttention::ComputeInternal(OpKernelContext* context) const { parameters.v_head_size, parameters.sequence_length, fused_runner, + false, use_memory_efficient_attention, no_qkv_workspace); auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index 5a99a98ce86be..aba0efdbd7d5f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -16,6 +16,7 @@ #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/rotary_embedding_util.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -47,22 +48,32 @@ size_t GetAttentionWorkspaceSize( size_t v_head_size, size_t sequence_length, void* fused_runner, + bool use_flash_attention, bool use_memory_efficient_attention, bool no_qkv_workspace) { // Note that q, k and v might need alignment for fused attention kernels. const size_t qkv_bytes = no_qkv_workspace ? 0 : (element_size * batch_size * num_heads * sequence_length * (qk_head_size + qk_head_size + v_head_size)); +#if USE_FLASH_ATTENTION + // Use portion of workspace for softmax buffer. + if (use_flash_attention) { + size_t flash_buffer_bytes = onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads); + return qkv_bytes + flash_buffer_bytes; + } +#else + ORT_UNUSED_PARAMETER(use_flash_attention); +#endif + if (fused_runner != nullptr) { return qkv_bytes; } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (use_memory_efficient_attention) { size_t fmha_buffer_bytes = 0; if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) { fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float); } - return qkv_bytes + fmha_buffer_bytes; } #else @@ -455,7 +466,7 @@ Status FusedScaledDotProductAttention( return Status::OK(); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION template Status FusedScaledDotProductAttentionCutlass( const cudaDeviceProp& device_prop, @@ -635,7 +646,7 @@ Status QkvToContext( return FusedScaledDotProductAttention(device_prop, stream, parameters, data); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return FusedScaledDotProductAttentionCutlass(device_prop, stream, parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h index 9476bbed26e8d..629ca59c73f16 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.h @@ -25,6 +25,7 @@ size_t GetAttentionWorkspaceSize( size_t v_head_size, size_t sequence_length, void* fused_runner, + bool use_flash_attention, bool use_memory_efficient_attention, bool no_qkv_workspace); diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc index 8ffae86ae53cf..1b026e64778e3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.cc @@ -9,6 +9,7 @@ #include "contrib_ops/cuda/bert/packed_multihead_attention_impl.h" #include "contrib_ops/cuda/bert/bert_padding.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace ::onnxruntime::common; @@ -42,6 +43,17 @@ PackedMultiHeadAttention::PackedMultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION + disable_flash_attention_ = sizeof(T) != 2 || onnxruntime::ParseEnvironmentVariableWithDefault( + attention::kDisableFlashAttention, false); + min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault( + attention::kMinSeqLenForFlashAttentionPackedQKV, + attention::kDefaultMinSeqLenForFlashAttentionPackedQKV); +#else + disable_flash_attention_ = true; + min_seq_len_for_flash_attention_packed_qkv_ = 0; +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION disable_memory_efficient_attention_ = onnxruntime::ParseEnvironmentVariableWithDefault( attention::kDisableMemoryEfficientAttention, false); #else @@ -94,8 +106,9 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, int64_t v_hidden_size = hidden_size; if (query_dims.size() == 4) { if (key != nullptr || value != nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key' and 'value' is expected to be empty when 'query' has 4 dimensions in packing mode"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' is expected to be empty when 'query' has 4 dimensions in packing mode"); } } else { // query_dims.size() == 2 if (key == nullptr) { @@ -143,11 +156,12 @@ Status PackedMultiHeadAttention::CheckInputs(const TensorShape& query_shape, const auto& cu_seq_len_dims = cu_seq_len_shape.GetDims(); if (cu_seq_len_dims.size() != 1 || cu_seq_len_dims[0] != batch_size + 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); + return ORT_MAKE_STATUS( + ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cumulative_sequence_length' should have 1 dimension with size equal to batch_size + 1"); } - // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple operators. + // TODO(tianleiwu): move relative position bias shape checker to a helper function. It is shared by multiple ops. const int num_heads = this->GetNumHeads(); bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { @@ -227,19 +241,39 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co Tensor* output = context->Output(0, output_shape); auto& device_prop = this->GetDeviceProp(); - MHARunner* fused_runner = this->GetFusedRunner(device_prop, parameters); + + bool use_flash_attention = false; +#if USE_FLASH_ATTENTION + if (!disable_flash_attention_) { + use_flash_attention = !parameters.has_relative_position_bias && + parameters.head_size == parameters.v_head_size && + onnxruntime::flash::is_supported(device_prop, + parameters.head_size, + parameters.num_heads, + parameters.num_heads); + + // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512. + if (use_flash_attention && key == nullptr && value == nullptr && + parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) { + use_flash_attention = false; + } + } +#endif + + MHARunner* fused_runner = use_flash_attention ? nullptr : this->GetFusedRunner(device_prop, parameters); bool use_memory_efficient_attention = false; -#if USE_FLASH_ATTENTION - if (nullptr == fused_runner && !disable_memory_efficient_attention_) { +#if USE_MEMORY_EFFICIENT_ATTENTION + if (!use_flash_attention && nullptr == fused_runner && !disable_memory_efficient_attention_) { int sm = device_prop.major * 10 + device_prop.minor; bool is_good_for_rpb = !parameters.has_relative_position_bias || parameters.sequence_length % (4 * sizeof(T)) == 0; - use_memory_efficient_attention = is_good_for_rpb && - (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) && - (parameters.head_size & 7) == 0 && - (parameters.v_head_size & 7) == 0 && - has_memory_efficient_attention(sm, sizeof(T) == 2); + use_memory_efficient_attention = + is_good_for_rpb && + (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && + (parameters.head_size & 7) == 0 && + (parameters.v_head_size & 7) == 0 && + has_memory_efficient_attention(sm, sizeof(T) == 2); } #endif @@ -250,7 +284,9 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co constexpr size_t element_size = sizeof(T); // When the source and target format is same (like TN3H => TN3H, or TNH => TNH) and no bias, need not transpose qkv. const bool no_qkv_workspace = (fused_runner != nullptr && key == nullptr && bias == nullptr) || - (use_memory_efficient_attention && value != nullptr && bias == nullptr); + ((use_memory_efficient_attention || use_flash_attention) && + value != nullptr && + bias == nullptr); size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, parameters.batch_size, parameters.num_heads, @@ -258,6 +294,7 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co parameters.v_head_size, parameters.sequence_length, fused_runner, + use_flash_attention, use_memory_efficient_attention, no_qkv_workspace); auto work_space = this->GetScratchBuffer(workSpaceSize, context->GetComputeStream()); @@ -268,12 +305,15 @@ Status PackedMultiHeadAttention::ComputeInternal(OpKernelContext* context) co data.key = (key == nullptr) ? nullptr : reinterpret_cast(key->Data()); data.value = (value == nullptr) ? nullptr : reinterpret_cast(value->Data()); data.bias = (bias == nullptr) ? nullptr : reinterpret_cast(bias->Data()); - data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data()); + data.relative_position_bias = (nullptr == relative_position_bias) + ? nullptr + : reinterpret_cast(relative_position_bias->Data()); data.workspace = reinterpret_cast(work_space.get()); data.token_offset = token_offset->Data(); data.cumulative_sequence_length = cumulative_sequence_length->Data(); data.output = reinterpret_cast(output->MutableData()); data.fused_runner = reinterpret_cast(fused_runner); + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.no_qkv_workspace = no_qkv_workspace; data.source_qkv_format = (key == nullptr) ? AttentionQkvFormat::QKV_TN3H : AttentionQkvFormat::Q_K_V_TNH; diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h index b59463a7769fa..e30c603dc30aa 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention.h @@ -31,6 +31,8 @@ class PackedMultiHeadAttention final : public TrtFusedAttention, public CudaK float scale_; // the scale for softmax in memory efficient attention or unfused attention. bool disable_memory_efficient_attention_; + bool disable_flash_attention_; + int min_seq_len_for_flash_attention_packed_qkv_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index d27cf975cb2c8..e09fd9e6b36e5 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -17,6 +17,7 @@ #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h" #include "contrib_ops/cuda/bert/rotary_embedding_util.h" +#include "contrib_ops/cuda/bert/flash_attention/flash_api.h" using namespace onnxruntime::cuda; using namespace onnxruntime::contrib::attention_softmax_cuda; @@ -32,7 +33,6 @@ static constexpr int32_t kMAX_THREADS_PER_BLOCK = 256; #define ADD_BIAS(value, bias_value) (biases == nullptr) ? value : (value + bias_value) #define GET_BIAS(bias_value) (biases == nullptr) ? T{} : bias_value - // Grid: (S, B) // Block: 256 // For unfused PackedMultiHeadAttention @@ -208,7 +208,6 @@ __global__ void TransposeQKV_TNH_TN3H( } } - // Grid: (S, B) // Block: 256 // For unfused PackedMultiHeadAttention @@ -329,7 +328,6 @@ __global__ void TransposeQKV_TN3H_3TNH( } } - // Grid: (T) // Block: 256 // For TRT fused attention. @@ -378,7 +376,6 @@ __global__ void AddBias_TN3H_TN3H( } } - template void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, @@ -587,6 +584,77 @@ Status FusedAttentionTrt( #if USE_FLASH_ATTENTION template +Status FlashAttention( + const cudaDeviceProp& device_prop, + cudaStream_t stream, + PackedAttentionParameters& parameters, + PackedMultiHeadAttentionData& data) { + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int qk_head_size = parameters.head_size; + const int v_head_size = parameters.v_head_size; + + // Q, K and V pointers + const int model_dimension_qk = num_heads * qk_head_size; + const int model_dimension_v = num_heads * v_head_size; + const size_t elements_qk = static_cast(parameters.token_count) * static_cast(model_dimension_qk); + const size_t elements_v = static_cast(parameters.token_count) * static_cast(model_dimension_v); + + // When separated Q, K, V is used, we can directly use them in Cutlass FMHA. Otherwise, transpose BSN3H to 3BSNH + if (!data.no_qkv_workspace) { + LaunchTranspose(data.query, data.key, data.value, data.bias, data.workspace, + batch_size, sequence_length, + num_heads, qk_head_size, v_head_size, + data.source_qkv_format, AttentionQkvFormat::Q_K_V_TNH, + data.token_offset, parameters.token_count, stream); + } + + float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast(qk_head_size)) + : parameters.scale; + int32_t* cu_seqlens_q = const_cast(data.cumulative_sequence_length); + int32_t* cu_seqlens_k = const_cast(data.cumulative_sequence_length); + const void* query = data.no_qkv_workspace ? data.query : data.workspace; + const void* key = data.no_qkv_workspace ? data.key : (data.workspace + elements_qk); + const void* value = data.no_qkv_workspace ? data.value : (data.workspace + elements_qk + elements_qk); + void* softmax_lse_buffer = data.no_qkv_workspace + ? data.workspace + : (data.workspace + elements_qk + elements_qk + elements_v); + + ORT_RETURN_IF_ERROR( + onnxruntime::flash::mha_varlen_fwd( + device_prop, + stream, + const_cast(query), + const_cast(key), + const_cast(value), + data.output, + cu_seqlens_q, + cu_seqlens_k, + softmax_lse_buffer, + batch_size, + num_heads, + num_heads, // num_heads_k + qk_head_size, + sequence_length, + sequence_length, + scale, + false // is causal + )); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(query), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", reinterpret_cast(key), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", reinterpret_cast(value), parameters.token_count, num_heads, v_head_size); + DUMP_TENSOR_D("cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); + DUMP_TENSOR("PackedMHA flash output", data.output, parameters.token_count, num_heads, v_head_size); + + return Status::OK(); +} +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION +template Status FusedAttentionCutlass( const cudaDeviceProp& device_prop, cudaStream_t stream, @@ -641,10 +709,10 @@ Status FusedAttentionCutlass( run_memory_efficient_attention(p); DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("PackedMHA cutlass q(BSNH)", reinterpret_cast(p.query), parameters.token_count, num_heads * qk_head_size); - DUMP_TENSOR_D("PackedMHA cutlass k(BSNH)", reinterpret_cast(p.key), parameters.token_count, num_heads * qk_head_size); - DUMP_TENSOR_D("PackedMHA cutlass v(BSNH)", reinterpret_cast(p.value), parameters.token_count, num_heads * v_head_size); - DUMP_TENSOR_D("PackedMHA cutlass cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); + DUMP_TENSOR_D("q(BSNH)", reinterpret_cast(p.query), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("k(BSNH)", reinterpret_cast(p.key), parameters.token_count, num_heads, qk_head_size); + DUMP_TENSOR_D("v(BSNH)", reinterpret_cast(p.value), parameters.token_count, num_heads, v_head_size); + DUMP_TENSOR_D("cumulative_sequence_length", data.cumulative_sequence_length, 1, batch_size + 1); DUMP_TENSOR("PackedMHA cutlass output", data.output, parameters.token_count, num_heads, v_head_size); return Status::OK(); @@ -707,10 +775,10 @@ Status UnfusedAttention( // Q, K and V are ready now DUMP_TENSOR_INIT(); - DUMP_TENSOR_D("PackedMHA unfused q (BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("PackedMHA unfused k (BNSH)", k, batch_size, num_heads, sequence_length, qk_head_size); - DUMP_TENSOR_D("PackedMHA unfused v (BNSH)", v, batch_size, num_heads, sequence_length, v_head_size); - DUMP_TENSOR_D("PackedMHA unfused QK", scaled_qk, batch_size * num_heads, sequence_length, sequence_length); + DUMP_TENSOR_D("q (BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("k (BNSH)", k, batch_size, num_heads, sequence_length, qk_head_size); + DUMP_TENSOR_D("v (BNSH)", v, batch_size, num_heads, sequence_length, v_head_size); + DUMP_TENSOR_D("QK", scaled_qk, batch_size, num_heads, sequence_length, sequence_length); const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length); @@ -727,7 +795,7 @@ Status UnfusedAttention( num_heads, attention_score, stream)); - DUMP_TENSOR_D("PackedMHA unfused Softmax", attention_score, batch_size * num_heads, sequence_length, sequence_length); + DUMP_TENSOR_D("Softmax", attention_score, batch_size, num_heads, sequence_length, sequence_length); // compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v T* temp_output = qkv; @@ -762,6 +830,12 @@ Status QkvToContext( } #if USE_FLASH_ATTENTION + if (data.use_flash_attention) { + return FlashAttention(device_prop, stream, parameters, data); + } +#endif + +#if USE_MEMORY_EFFICIENT_ATTENTION if (data.use_memory_efficient_attention) { return FusedAttentionCutlass(device_prop, stream, parameters, data); } diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h index c7b72808787d7..eeca72f16e64e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.h @@ -29,6 +29,7 @@ struct PackedMultiHeadAttentionData { void* fused_runner; + bool use_flash_attention; bool use_memory_efficient_attention; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h index d61501f429329..ce42e33ba1bfd 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/fused_multihead_attention_v2.h @@ -855,6 +855,139 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2 { false, false}, + {DATA_TYPE_FP16, + 32, + 32, + kSM_80, + cubin_fmha_v2_fp16_32_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_32_32_sm80_cu_cubin_len, + "fmha_v2_fp16_32_32_sm80_kernel", + 8192, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 64, + 32, + kSM_80, + cubin_fmha_v2_fp16_64_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_64_32_sm80_cu_cubin_len, + "fmha_v2_fp16_64_32_sm80_kernel", + 16384, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 96, + 32, + kSM_80, + cubin_fmha_v2_fp16_96_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_96_32_sm80_cu_cubin_len, + "fmha_v2_fp16_96_32_sm80_kernel", + 24576, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 128, + 32, + kSM_80, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, + "fmha_v2_fp16_128_32_sm80_kernel", + 32768, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 128, + 32, + kSM_80, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_128_32_sm80_cu_cubin_len, + "fmha_v2_fp16_128_32_sm80_kernel_nl", + 20480, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 192, + 32, + kSM_80, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin_len, + "fmha_v2_fp16_192_32_sm80_kernel", + 16384, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 192, + 32, + kSM_80, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_192_32_sm80_cu_cubin_len, + "fmha_v2_fp16_192_32_sm80_kernel_nl", + 16384, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 256, + 32, + kSM_80, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, + "fmha_v2_fp16_256_32_sm80_kernel", + 20480, + 128, + 0, + false, + false}, + {DATA_TYPE_FP16, + 256, + 32, + kSM_80, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_256_32_sm80_cu_cubin_len, + "fmha_v2_fp16_256_32_sm80_kernel_nl", + 20480, + 128, + 32, + false, + false}, + {DATA_TYPE_FP16, + 384, + 32, + kSM_80, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin_len, + "fmha_v2_fp16_384_32_sm80_kernel", + 32768, + 256, + 0, + false, + false}, + {DATA_TYPE_FP16, + 384, + 32, + kSM_80, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin, + cubin_fmha_v2_fp16_384_32_sm80_cu_cubin_len, + "fmha_v2_fp16_384_32_sm80_kernel_nl", + 32768, + 256, + 32, + false, + false}, + // GA10x: sm86 uses sm80 kernels {DATA_TYPE_FP16, 32, diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc index f899a73ee0c81..b0556512de0b7 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -174,8 +174,9 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present = context->Output(1, present_shape); void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up - bool use_fused_cross_attention = false; - bool use_memory_efficient_attention = false; + constexpr bool use_fused_cross_attention = false; + constexpr bool use_memory_efficient_attention = false; + constexpr bool use_flash_attention = false; size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, parameters.num_heads, @@ -185,6 +186,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { parameters.kv_sequence_length, parameters.total_sequence_length, fused_runner, + use_flash_attention, use_fused_cross_attention, use_memory_efficient_attention); @@ -211,6 +213,7 @@ Status QAttention::ComputeInternal(OpKernelContext* context) const { data.present_value = nullptr; data.fused_runner = fused_runner; data.fused_cross_attention_kernel = nullptr; + data.use_flash_attention = use_flash_attention; data.use_memory_efficient_attention = use_memory_efficient_attention; data.cumulated_sequence_length_q_cache = nullptr; data.cumulated_sequence_length_kv_cache = nullptr; diff --git a/onnxruntime/core/framework/allocation_planner.cc b/onnxruntime/core/framework/allocation_planner.cc index 300db24a986f4..0bf27fdf5e5dc 100644 --- a/onnxruntime/core/framework/allocation_planner.cc +++ b/onnxruntime/core/framework/allocation_planner.cc @@ -1715,31 +1715,39 @@ class PlannerImpl { void PartitionIntoStreams(const logging::Logger& /*logger*/, const ExecutionProviders& /*execution_providers*/, const PathString& /*partition_config_file*/) { - stream_nodes_.push_back({}); - node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); - for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) { - stream_nodes_[0].push_back(node_index); - node_stream_map_[node_index] = 0; + if (graph_viewer_.NumberOfNodes() > 0) { + stream_nodes_.push_back({}); + node_stream_map_.resize(SafeInt(graph_viewer_.MaxNodeIndex()) + 1); + for (auto node_index : graph_viewer_.GetNodesInTopologicalOrder()) { + stream_nodes_[0].push_back(node_index); + node_stream_map_[node_index] = 0; + } + num_logic_streams_ = 1; } - num_logic_streams_ = 1; } Status BuildExecutionPlan(const ExecutionProviders& execution_providers) { // 1. create logic stream instance auto& execution_plan = plan_.execution_plan; - ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty()); - execution_plan.reserve(1); - auto first_node_index = stream_nodes_[0][0]; - auto* node = graph_viewer_.GetNode(first_node_index); - onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); - const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); - ORT_ENFORCE(ep); - auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault); - execution_plan.emplace_back(std::make_unique(node_device_mem_location)); - // 2. add steps to the execution plan - for (auto node_index : stream_nodes_[0]) { - execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); + + if (graph_viewer_.NumberOfNodes() > 0) { + ORT_ENFORCE(num_logic_streams_ == 1 && !stream_nodes_[0].empty()); + execution_plan.reserve(1); + auto first_node_index = stream_nodes_[0][0]; + auto* node = graph_viewer_.GetNode(first_node_index); + onnxruntime::ProviderType exec_provider_name = node->GetExecutionProviderType(); + const IExecutionProvider* ep = execution_providers.Get(exec_provider_name); + ORT_ENFORCE(ep); + auto node_device_mem_location = ep->GetOrtDeviceByMemType(OrtMemType::OrtMemTypeDefault); + execution_plan.emplace_back(std::make_unique(node_device_mem_location)); + // 2. add steps to the execution plan + for (auto node_index : stream_nodes_[0]) { + execution_plan[0]->steps_.emplace_back(std::make_unique(node_index)); + } + } else { + // graph with no nodes. e.g. subgraph of If might return the input as-is or a constant value from an initializer } + return Status::OK(); } diff --git a/onnxruntime/core/framework/session_state_utils.cc b/onnxruntime/core/framework/session_state_utils.cc index fc2f14263f7a7..df3a7afebc176 100644 --- a/onnxruntime/core/framework/session_state_utils.cc +++ b/onnxruntime/core/framework/session_state_utils.cc @@ -254,10 +254,11 @@ common::Status SaveInitializedTensors( auto initialized_tensors_to_allocate = id_to_initialized_tensor; for (int ort_value_index : initializer_allocation_order) { const auto entry = initialized_tensors_to_allocate.find(ort_value_index); + ORT_ENFORCE(entry != initialized_tensors_to_allocate.end(), + "OrtValue index: ", ort_value_index, " from initializer_allocation_order not found among initialized tensors"); if (!(utils::HasExternalData(*entry->second) && exec_plan.GetLocation(ort_value_index).Type() == OrtDevice::CPU)) { // can not trace string tensor - ORT_ENFORCE(entry != initialized_tensors_to_allocate.end() && - entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING); + ORT_ENFORCE(entry->second->data_type() != ONNX_NAMESPACE::TensorProto_DataType_STRING, "Can not trace string tensor"); ORT_RETURN_IF_ERROR(planner.Trace(entry->first, entry->second)); } initialized_tensors_to_allocate.erase(entry); diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 5a42f5d34b931..08ed811d9ac38 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1492,7 +1492,7 @@ Status UnpackInitializerData(const onnx::TensorProto& initializer, if (initializer.data_location() == TensorProto_DataLocation_EXTERNAL) { ORT_RETURN_IF_ERROR(ReadExternalDataForTensor( initializer, - model_path.IsEmpty() ? nullptr : model_path.ParentPath().ToPathString().c_str(), + (model_path.IsEmpty() || model_path.ParentPath().IsEmpty()) ? nullptr : model_path.ParentPath().ToPathString().c_str(), unpacked_tensor)); return Status::OK(); } diff --git a/onnxruntime/core/framework/utils.h b/onnxruntime/core/framework/utils.h index 56f41154b719c..ea6a629f87cb8 100644 --- a/onnxruntime/core/framework/utils.h +++ b/onnxruntime/core/framework/utils.h @@ -223,7 +223,7 @@ constexpr ONNXTensorElementDataType GetONNXTensorElementDataType int32_t ONNXTensorElementDataTypeToProtoTensorType(ONNXTensorElementDataType); -#ifdef ENABLE_TRAINING_CORE +#ifdef ENABLE_TRAINING common::Status VerifyInputTensorsAllocatedContiguously(OpKernelContext* context); #endif diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 4b7900194488d..aa0727e3750b0 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -344,10 +344,15 @@ std::unique_ptr CreateSchema(const std::string& functi std::unordered_map map_copy(model_local_functions.begin(), model_local_functions.end()); std::unordered_map empty_map; - ONNX_NAMESPACE::shape_inference::SymbolTableImpl symbolTable; + + // https://github.com/microsoft/onnxruntime/issues/17061 + // We are passing a nullptr for the symbol table, because symbol table must be global + // for all the shape inferencing to work correctly. Otherwise, unrelated shapes get + // the same symbolic shapes and are marked for memory re-use. This is a Temp fix. + constexpr ONNX_NAMESPACE::shape_inference::SymbolTableImpl* symbolTable = nullptr; ONNX_NAMESPACE::shape_inference::InferShapeForFunctionNode(*onnx_func_proto, func_domain_to_version, schema_registry, ctx, options, map_copy, - &symbolTable, &empty_map); + symbolTable, &empty_map); }); op_schema->Finalize(); diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc index a3ac4312053aa..dd38ee9b07ee6 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.cc @@ -462,6 +462,27 @@ bool LayerNormalizationGatherActor::PreCheck(const Graph& /* graph */, return true; } +bool LayerNormalizationGatherActor::PostProcess(Graph& /*graph*/, Node& current_node, + const SliceInfo& info_without_node, + const logging::Logger& /*logger*/, + const std::unordered_map& /*propagate_input_indices*/, + const std::unordered_map>& + /*all_input_cmp_rets*/, + const std::unordered_map& /*new_gather_infos*/) { + // Update LayerNormalization's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -479,6 +500,28 @@ bool SoftmaxGatherActor::PreCheck(const Graph& graph, const Node& current_node, propagate_input_indices, all_input_cmp_rets, shape_update_func); } +bool SoftmaxGatherActor::PostProcess(Graph& graph, Node& current_node, const SliceInfo& info_without_node, + const logging::Logger& logger, + const std::unordered_map& propagate_input_indices, + const std::unordered_map>& all_input_cmp_rets, + const std::unordered_map& new_gather_infos) { + SimplePointwiseGatherActor::PostProcess(graph, current_node, info_without_node, logger, + propagate_input_indices, all_input_cmp_rets, new_gather_infos); + + // Update Softmax's axis attribute if it is scalar slice. + if (info_without_node.is_scalar_slice) { + auto axis = static_cast(current_node.GetAttributes().at("axis").i()); + auto original_ln_input_rank = info_without_node.input_rank; + axis = axis < 0 ? axis + original_ln_input_rank : axis; + auto new_axis = axis - 1; + + auto& attributes = current_node.GetMutableAttributes(); + attributes["axis"] = ONNX_NAMESPACE::MakeAttribute("axis", static_cast(new_axis)); + } + + return true; +} + bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, const SliceInfo& info, const logging::Logger& logger, std::unordered_map& propagate_input_indices, @@ -566,6 +609,11 @@ bool ReshapeGatherActor::PreCheck(const Graph& graph, const Node& current_node, return true; } + LOG_DEBUG_INFO(logger, "Skip handle the Reshape, new_shape_const_values[info.non_negative_axis]:" + + std::to_string(new_shape_const_values[info.non_negative_axis]) + + ", info.output_dim_on_axis.has_dim_value(): " + + std::to_string(info.output_dim_on_axis.has_dim_value()) + "."); + return false; } @@ -604,11 +652,12 @@ bool ReshapeGatherActor::PostProcess( return true; } - // If it selected shape is a dim value, we can update the shape tensor directory. + // If the selected shape is a dim value, we can update the shape tensor directory. if (info_without_node.output_dim_on_axis.has_dim_value()) { new_shape_const_values[slice_axis] = info_without_node.output_dim_on_axis.dim_value(); auto new_shape_arg = - CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, new_shape_const_values, + CreateInitializerFromVector(graph, {static_cast(new_shape_const_values.size())}, + new_shape_const_values, graph.GenerateNodeArgName(current_node.MutableInputDefs()[1]->Name())); graph_utils::ReplaceNodeInput(current_node, 1, *new_shape_arg); return true; diff --git a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h index f6715e4bb1f32..0c21be1397636 100644 --- a/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h +++ b/onnxruntime/core/optimizer/compute_optimizer/upstream_gather_actors.h @@ -189,7 +189,7 @@ class LayerNormalizationGatherActor : public UpStreamGatherOperatorActorBase { const logging::Logger& /* logger */, const std::unordered_map& /* propagate_input_indices */, const std::unordered_map>& /* all_input_cmp_rets */, - const std::unordered_map& /* new_gather_infos */) override { return true; } + const std::unordered_map& /* new_gather_infos */) override; }; class SoftmaxGatherActor : public SimplePointwiseGatherActor { @@ -202,6 +202,12 @@ class SoftmaxGatherActor : public SimplePointwiseGatherActor { std::unordered_map& propagate_input_indices, std::unordered_map>& all_input_cmp_rets, std::function& shape_update_func) override; + + bool PostProcess(Graph& /* graph */, Node& /* current_node */, const SliceInfo& /* info_without_node */, + const logging::Logger& /* logger */, + const std::unordered_map& /* propagate_input_indices */, + const std::unordered_map>& /* all_input_cmp_rets */, + const std::unordered_map& /* new_gather_infos */) override; }; class ReshapeGatherActor : public UpStreamGatherOperatorActorBase { diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h index aba9a798cf786..b9f3050e59c5b 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_aggregator.h @@ -64,6 +64,18 @@ enum MissingTrack : uint8_t { kFalse = 0 }; +template +struct TreeNodeElement; + +template +union PtrOrWeight { + TreeNodeElement* ptr; + struct WeightData { + int32_t weight; + int32_t n_weights; + } weight_data; +}; + template struct TreeNodeElement { int feature_id; @@ -71,24 +83,19 @@ struct TreeNodeElement { // Stores the node threshold or the weights if the tree has one target. T value_or_unique_weight; - // onnx specification says hitrates is used to store information about the node, + // The onnx specification says hitrates is used to store information about the node, // but this information is not used for inference. // T hitrates; - // True node, false node are obtained by computing `this + truenode_inc_or_first_weight`, - // `this + falsenode_inc_or_n_weights` if the node is not a leaf. - // In case of a leaf, these attributes are used to indicate the position of the weight - // in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, - // the weight is also stored in `value_or_unique_weight`. - // This implementation assumes a tree has less than 2^31 nodes, - // and the total number of leave in the set of trees is below 2^31. - // A node cannot point to itself. - int32_t truenode_inc_or_first_weight; - // In case of a leaf, the following attribute indicates the number of weights - // in array `TreeEnsembleCommon::weights_`. If not a leaf, it indicates - // `this + falsenode_inc_or_n_weights` is the false node. - // A node cannot point to itself. - int32_t falsenode_inc_or_n_weights; + // PtrOrWeight acts as a tagged union, with the "tag" being whether the node is a leaf or not (see `is_not_leaf`). + + // If it is not a leaf, it is a pointer to the true child node when traversing the decision tree. The false branch is + // always 1 position away from the TreeNodeElement in practice in `TreeEnsembleCommon::nodes_` so it is not stored. + + // If it is a leaf, it contains `weight` and `n_weights` attributes which are used to indicate the position of the + // weight in array `TreeEnsembleCommon::weights_`. If the number of targets or classes is one, the weight is also + // stored in `value_or_unique_weight`. + PtrOrWeight truenode_or_weight; uint8_t flags; inline NODE_MODE mode() const { return NODE_MODE(flags & 0xF); } @@ -189,8 +196,8 @@ class TreeAggregatorSum : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { ORT_ENFORCE(it->i < (int64_t)predictions.size()); predictions[onnxruntime::narrow(it->i)].score += it->value; predictions[onnxruntime::narrow(it->i)].has_score = 1; @@ -292,8 +299,8 @@ class TreeAggregatorMin : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value < predictions[onnxruntime::narrow(it->i)].score) ? it->value @@ -349,8 +356,8 @@ class TreeAggregatorMax : public TreeAggregator>& predictions, const TreeNodeElement& root, gsl::span> weights) const { - auto it = weights.begin() + root.truenode_inc_or_first_weight; - for (int32_t i = 0; i < root.falsenode_inc_or_n_weights; ++i, ++it) { + auto it = weights.begin() + root.truenode_or_weight.weight_data.weight; + for (int32_t i = 0; i < root.truenode_or_weight.weight_data.n_weights; ++i, ++it) { predictions[onnxruntime::narrow(it->i)].score = (!predictions[onnxruntime::narrow(it->i)].has_score || it->value > predictions[onnxruntime::narrow(it->i)].score) ? it->value diff --git a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h index 161bb2b0820eb..8f847fe66aa73 100644 --- a/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h +++ b/onnxruntime/core/providers/cpu/ml/tree_ensemble_common.h @@ -85,6 +85,13 @@ class TreeEnsembleCommon : public TreeEnsembleCommonAttributes { template void ComputeAgg(concurrency::ThreadPool* ttp, const Tensor* X, Tensor* Y, Tensor* label, const AGG& agg) const; + + private: + size_t AddNodes(const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, + int64_t tree_id, const InlinedVector& node_tree_ids); }; template @@ -186,7 +193,7 @@ Status TreeEnsembleCommon::Init( max_tree_depth_ = 1000; ORT_ENFORCE(nodes_modes.size() < std::numeric_limits::max()); - // additional members + // Additional members size_t limit; uint32_t i; InlinedVector cmodes; @@ -195,18 +202,14 @@ Status TreeEnsembleCommon::Init( int fpos = -1; for (i = 0, limit = nodes_modes.size(); i < limit; ++i) { cmodes.push_back(MakeTreeNodeMode(nodes_modes[i])); - if (cmodes[i] == NODE_MODE::LEAF) - continue; + if (cmodes[i] == NODE_MODE::LEAF) continue; if (fpos == -1) { fpos = static_cast(i); continue; } - if (cmodes[i] != cmodes[fpos]) - same_mode_ = false; + if (cmodes[i] != cmodes[fpos]) same_mode_ = false; } - // filling nodes - n_nodes_ = nodes_treeids.size(); limit = static_cast(n_nodes_); InlinedVector node_tree_ids; @@ -214,156 +217,185 @@ Status TreeEnsembleCommon::Init( nodes_.clear(); nodes_.reserve(limit); roots_.clear(); - std::unordered_map idi; - idi.reserve(limit); + std::unordered_map node_tree_ids_map; + node_tree_ids_map.reserve(limit); + + InlinedVector truenode_ids, falsenode_ids; + truenode_ids.reserve(limit); + falsenode_ids.reserve(limit); max_feature_id_ = 0; + // Build node_tree_ids and node_tree_ids_map and truenode_ids and falsenode_ids for (i = 0; i < limit; ++i) { - TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), - static_cast(nodes_nodeids[i])}; - TreeNodeElement node; - node.feature_id = static_cast(nodes_featureids[i]); - if (node.feature_id > max_feature_id_) { - max_feature_id_ = node.feature_id; - } - node.value_or_unique_weight = nodes_values_as_tensor.empty() - ? static_cast(nodes_values[i]) - : nodes_values_as_tensor[i]; - - /* hitrates is not used for inference, they are ignored. - if (nodes_hitrates_as_tensor.empty()) { - node.hitrates = static_cast(i < nodes_hitrates.size() ? nodes_hitrates[i] : -1); - } else { - node.hitrates = i < nodes_hitrates_as_tensor.size() ? nodes_hitrates_as_tensor[i] : -1; - } */ - - node.flags = static_cast(cmodes[i]); - node.truenode_inc_or_first_weight = 0; // nodes_truenodeids[i] if not a leaf - node.falsenode_inc_or_n_weights = 0; // nodes_falsenodeids[i] if not a leaf - - if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { - node.flags |= static_cast(MissingTrack::kTrue); - } - auto p = idi.insert(std::pair(node_tree_id, i)); + TreeNodeElementId node_tree_id{static_cast(nodes_treeids[i]), static_cast(nodes_nodeids[i])}; + auto p = node_tree_ids_map.insert(std::pair(node_tree_id, i)); if (!p.second) { ORT_THROW("Node ", node_tree_id.node_id, " in tree ", node_tree_id.tree_id, " is already there."); } - nodes_.emplace_back(node); node_tree_ids.emplace_back(node_tree_id); } - InlinedVector truenode_ids, falsenode_ids; - truenode_ids.reserve(limit); - falsenode_ids.reserve(limit); TreeNodeElementId coor; - i = 0; - for (auto it = nodes_.begin(); it != nodes_.end(); ++it, ++i) { - if (!it->is_not_leaf()) { + for (i = 0; i < limit; ++i) { + if (cmodes[i] == NODE_MODE::LEAF) { truenode_ids.push_back(0); falsenode_ids.push_back(0); - continue; - } - - TreeNodeElementId& node_tree_id = node_tree_ids[i]; - coor.tree_id = node_tree_id.tree_id; - coor.node_id = static_cast(nodes_truenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + } else { + TreeNodeElementId& node_tree_id = node_tree_ids[i]; + coor.tree_id = node_tree_id.tree_id; + coor.node_id = static_cast(nodes_truenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + + auto found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); + } + if (found->second == truenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + } + truenode_ids.emplace_back(found->second); - auto found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (truenode)."); - } - if (found->second == truenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (truenode)."); + coor.node_id = static_cast(nodes_falsenodeids[i]); + ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); + found = node_tree_ids_map.find(coor); + if (found == node_tree_ids_map.end()) { + ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); + } + if (found->second == falsenode_ids.size()) { + ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + } + falsenode_ids.emplace_back(found->second); + // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). + // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - truenode_ids.emplace_back(found->second); + } - coor.node_id = static_cast(nodes_falsenodeids[i]); - ORT_ENFORCE((coor.node_id >= 0 && coor.node_id < n_nodes_)); - found = idi.find(coor); - if (found == idi.end()) { - ORT_THROW("Unable to find node ", coor.tree_id, "-", coor.node_id, " (falsenode)."); - } - if (found->second == falsenode_ids.size()) { - ORT_THROW("A node cannot point to itself: ", coor.tree_id, "-", node_tree_id.node_id, " (falsenode)."); + // Let's construct nodes_ such that the false branch is always the next element in nodes_. + // updated_mapping will translates the old position of each node to the new node position in nodes_. + std::vector updated_mapping(nodes_treeids.size(), 0); + int64_t previous_tree_id = -1; + for (i = 0; i < n_nodes_; ++i) { + if (previous_tree_id == -1 || (previous_tree_id != node_tree_ids[i].tree_id)) { + // New tree. + int64_t tree_id = node_tree_ids[i].tree_id; + size_t root_position = + AddNodes(i, cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, nodes_values, + nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + roots_.push_back(&nodes_[root_position]); + previous_tree_id = tree_id; } - falsenode_ids.emplace_back(found->second); - // We could also check that truenode_ids[truenode_ids.size() - 1] != falsenode_ids[falsenode_ids.size() - 1]). - // It is valid but no training algorithm would produce a tree where left and right nodes are the same. } - // sort targets + n_trees_ = roots_.size(); + if (((int64_t)nodes_.size()) != n_nodes_) { + ORT_THROW("Number of nodes in nodes_ (", nodes_.size(), ") is different from n_nodes (", n_nodes_, ")."); + } + + // Sort targets InlinedVector> indices; indices.reserve(target_class_nodeids.size()); for (i = 0, limit = target_class_nodeids.size(); i < limit; i++) { - indices.emplace_back(std::pair( - TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, - i)); + indices.emplace_back( + std::pair(TreeNodeElementId{target_class_treeids[i], target_class_nodeids[i]}, i)); } + std::sort(indices.begin(), indices.end()); - // Initialize the leaves. TreeNodeElementId ind; SparseValue w; size_t indi; for (indi = 0, limit = target_class_nodeids.size(); indi < limit; ++indi) { ind = indices[indi].first; i = indices[indi].second; - auto found = idi.find(ind); - if (found == idi.end()) { + auto found = node_tree_ids_map.find(ind); + if (found == node_tree_ids_map.end()) { ORT_THROW("Unable to find node ", ind.tree_id, "-", ind.node_id, " (weights)."); } - TreeNodeElement& leaf = nodes_[found->second]; + TreeNodeElement& leaf = nodes_[updated_mapping[found->second]]; if (leaf.is_not_leaf()) { // An exception should be raised in that case. But this case may happen in // models converted with an old version of onnxmltools. These weights are ignored. // ORT_THROW("Node ", ind.tree_id, "-", ind.node_id, " is not a leaf."); continue; } - w.i = target_class_ids[i]; - w.value = target_class_weights_as_tensor.empty() - ? static_cast(target_class_weights[i]) - : target_class_weights_as_tensor[i]; - if (leaf.falsenode_inc_or_n_weights == 0) { - leaf.truenode_inc_or_first_weight = static_cast(weights_.size()); + w.value = target_class_weights_as_tensor.empty() ? static_cast(target_class_weights[i]) + : target_class_weights_as_tensor[i]; + if (leaf.truenode_or_weight.weight_data.n_weights == 0) { + leaf.truenode_or_weight.weight_data.weight = static_cast(weights_.size()); leaf.value_or_unique_weight = w.value; } - ++leaf.falsenode_inc_or_n_weights; + ++leaf.truenode_or_weight.weight_data.n_weights; weights_.push_back(w); } - // Initialize all the nodes but the leaves. - int64_t previous = -1; - for (i = 0, limit = static_cast(n_nodes_); i < limit; ++i) { - if ((previous == -1) || (previous != node_tree_ids[i].tree_id)) - roots_.push_back(&(nodes_[idi[node_tree_ids[i]]])); - previous = node_tree_ids[i].tree_id; - if (!nodes_[i].is_not_leaf()) { - if (nodes_[i].falsenode_inc_or_n_weights == 0) { - ORT_THROW("Target is missing for leaf ", ind.tree_id, "-", ind.node_id, "."); - } - continue; - } - ORT_ENFORCE(truenode_ids[i] != i); // That would mean the left node is itself, leading to an infinite loop. - nodes_[i].truenode_inc_or_first_weight = static_cast(truenode_ids[i] - i); - ORT_ENFORCE(falsenode_ids[i] != i); // That would mean the right node is itself, leading to an infinite loop. - nodes_[i].falsenode_inc_or_n_weights = static_cast(falsenode_ids[i] - i); - } - - n_trees_ = roots_.size(); has_missing_tracks_ = false; - for (auto itm = nodes_missing_value_tracks_true.begin(); - itm != nodes_missing_value_tracks_true.end(); ++itm) { + for (auto itm = nodes_missing_value_tracks_true.begin(); itm != nodes_missing_value_tracks_true.end(); ++itm) { if (*itm) { has_missing_tracks_ = true; break; } } + return Status::OK(); } +template +size_t TreeEnsembleCommon::AddNodes( + const size_t i, const InlinedVector& cmodes, const InlinedVector& truenode_ids, + const InlinedVector& falsenode_ids, const std::vector& nodes_featureids, + const std::vector& nodes_values_as_tensor, const std::vector& node_values, + const std::vector& nodes_missing_value_tracks_true, std::vector& updated_mapping, int64_t tree_id, + const InlinedVector& node_tree_ids) { + // Validate this index maps to the same tree_id as the one we should be building. + if (node_tree_ids[i].tree_id != tree_id) { + ORT_THROW("Tree id mismatch. Expected ", tree_id, " but got ", node_tree_ids[i].tree_id, " at position ", i); + } + + if (updated_mapping[i] != 0) { + // In theory we should not accept any cycles, however in practice LGBM conversion implements set membership via a + // series of "Equals" nodes, with the true branches directed at the same child node (a cycle). + // We may instead seek to formalize set membership in the future. + return updated_mapping[i]; + } + + size_t node_pos = nodes_.size(); + updated_mapping[i] = node_pos; + + TreeNodeElement node; + node.flags = static_cast(cmodes[i]); + node.feature_id = static_cast(nodes_featureids[i]); + if (node.feature_id > max_feature_id_) { + max_feature_id_ = node.feature_id; + } + node.value_or_unique_weight = + nodes_values_as_tensor.empty() ? static_cast(node_values[i]) : nodes_values_as_tensor[i]; + if (i < static_cast(nodes_missing_value_tracks_true.size()) && nodes_missing_value_tracks_true[i] == 1) { + node.flags |= static_cast(MissingTrack::kTrue); + } + nodes_.push_back(std::move(node)); + if (nodes_[node_pos].is_not_leaf()) { + size_t false_branch = + AddNodes(falsenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + if (false_branch != node_pos + 1) { + ORT_THROW("False node must always be the next node, but it isn't at index ", node_pos, " with flags ", + static_cast(nodes_[node_pos].flags)); + } + size_t true_branch = + AddNodes(truenode_ids[i], cmodes, truenode_ids, falsenode_ids, nodes_featureids, nodes_values_as_tensor, + node_values, nodes_missing_value_tracks_true, updated_mapping, tree_id, node_tree_ids); + // We don't need to store the false branch pointer since we know it is always in the immediate next entry in nodes_. + // nodes_[node_pos].falsenode_inc_or_n_weights.ptr = &nodes_[false_branch]; + nodes_[node_pos].truenode_or_weight.ptr = &nodes_[true_branch]; + } else { + nodes_[node_pos].truenode_or_weight.weight_data.weight = 0; + nodes_[node_pos].truenode_or_weight.weight_data.n_weights = 0; + } + return node_pos; +} + template Status TreeEnsembleCommon::compute(OpKernelContext* ctx, const Tensor* X, @@ -637,22 +669,19 @@ void TreeEnsembleCommon::ComputeAgg(concur } } // namespace detail -#define TREE_FIND_VALUE(CMP) \ - if (has_missing_tracks_) { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += (val CMP root->value_or_unique_weight || \ - (root->is_missing_track_true() && _isnan_(val))) \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ - } else { \ - while (root->is_not_leaf()) { \ - val = x_data[root->feature_id]; \ - root += val CMP root->value_or_unique_weight \ - ? root->truenode_inc_or_first_weight \ - : root->falsenode_inc_or_n_weights; \ - } \ +#define TREE_FIND_VALUE(CMP) \ + if (has_missing_tracks_) { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = (val CMP root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) \ + ? root->truenode_or_weight.ptr \ + : root + 1; \ + } \ + } else { \ + while (root->is_not_leaf()) { \ + val = x_data[root->feature_id]; \ + root = val CMP root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; \ + } \ } inline bool _isnan_(float x) { return std::isnan(x); } @@ -671,15 +700,14 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( if (has_missing_tracks_) { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += (val <= root->value_or_unique_weight || - (root->is_missing_track_true() && _isnan_(val))) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = (val <= root->value_or_unique_weight || (root->is_missing_track_true() && _isnan_(val))) + ? root->truenode_or_weight.ptr + : root + 1; } } else { while (root->is_not_leaf()) { val = x_data[root->feature_id]; - root += val <= root->value_or_unique_weight ? root->truenode_inc_or_first_weight : root->falsenode_inc_or_n_weights; + root = val <= root->value_or_unique_weight ? root->truenode_or_weight.ptr : root + 1; } } break; @@ -703,42 +731,36 @@ TreeEnsembleCommon::ProcessTreeNodeLeave( } } else { // Different rules to compare to node thresholds. ThresholdType threshold; - while (root->is_not_leaf()) { + while (1) { val = x_data[root->feature_id]; threshold = root->value_or_unique_weight; switch (root->mode()) { case NODE_MODE::BRANCH_LEQ: - root += val <= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val <= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_LT: - root += val < threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val < threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GTE: - root += val >= threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val >= threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_GT: - root += val > threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val > threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_EQ: - root += val == threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val == threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::BRANCH_NEQ: - root += val != threshold || (root->is_missing_track_true() && _isnan_(val)) - ? root->truenode_inc_or_first_weight - : root->falsenode_inc_or_n_weights; + root = val != threshold || (root->is_missing_track_true() && _isnan_(val)) ? root->truenode_or_weight.ptr + : root + 1; break; case NODE_MODE::LEAF: - break; + return root; } } } diff --git a/onnxruntime/core/providers/cpu/nn/batch_norm.h b/onnxruntime/core/providers/cpu/nn/batch_norm.h index a5c68eebb25ad..be9bc3368ea41 100644 --- a/onnxruntime/core/providers/cpu/nn/batch_norm.h +++ b/onnxruntime/core/providers/cpu/nn/batch_norm.h @@ -29,6 +29,10 @@ namespace onnxruntime { +#if !defined(ORT_MINIMAL_BUILD) || defined(ENABLE_TRAINING_OPS) +#define BATCHNORM_INCLUDE_TRAINING_SUPPORT +#endif + template class BatchNorm : public OpKernel { public: @@ -47,7 +51,7 @@ class BatchNorm : public OpKernel { } if (is_train_) { -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) momentum_ = op_kernel_info.GetAttrOrDefault("momentum", 0.9f); ORT_ENFORCE(is_spatial_, "Training mode only supports spatial BN"); #else @@ -84,7 +88,7 @@ class BatchNorm : public OpKernel { // calculate sample_size (including all channels) size_t sample_size_incl_all_channels = sample_size * C; -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) AllocatorPtr alloc; ORT_RETURN_IF_ERROR(p_op_kernel_context->GetTempSpaceAllocator(&alloc)); @@ -111,7 +115,7 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap scale_arr(scale->Data(), is_spatial_ ? C : sample_size_incl_all_channels); ConstEigenVectorArrayMap bias_arr(B->Data(), is_spatial_ ? C : sample_size_incl_all_channels); -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) // Note that we only support spatial BN for training if (is_train_) { EigenVectorArrayMap saved_mean_arr(saved_mean->MutableData(), C); @@ -162,7 +166,7 @@ class BatchNorm : public OpKernel { ConstEigenVectorArrayMap var_arr(var->Data(), is_spatial_ ? C : sample_size_incl_all_channels); inv_std = (var_arr + epsilon_).sqrt().inverse(); } else { -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) EigenVectorArrayMap saved_inv_std_arr(saved_inv_std->MutableData(), C); saved_inv_std_arr = (saved_inv_std_arr + epsilon_).inverse().sqrt(); inv_std = saved_inv_std_arr; @@ -171,7 +175,7 @@ class BatchNorm : public OpKernel { // If we're training, do batch normalization based on computation from this batch ConstEigenVectorArrayMap mean_arr( -#ifdef ENABLE_TRAINING_OPS +#if defined(BATCHNORM_INCLUDE_TRAINING_SUPPORT) !is_train_ ? mean->Data() : saved_mean->Data(), #else mean->Data(), diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index e9fc8d857b831..21a256eee6f14 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel { W_zero_point_value = W_zero_point_data[0]; for (int64_t i = 1; i < W_zero_point_size; i++) { ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value, - "QLinearConv : zero point of per-channel filter must be same"); + "QLinearConv : zero point of per-channel filter must be same. " + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index d9c6126d4bf36..3fe1980141ca5 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); @@ -464,7 +465,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/operators/gemm.cc b/onnxruntime/core/providers/js/operators/gemm.cc index f579d62bdfb5f..04700d0f54705 100644 --- a/onnxruntime/core/providers/js/operators/gemm.cc +++ b/onnxruntime/core/providers/js/operators/gemm.cc @@ -12,7 +12,15 @@ namespace js { ONNX_OPERATOR_TYPED_KERNEL_EX( \ Gemm, \ kOnnxDomain, \ - 11, \ + 13, \ + T, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + Gemm); \ + ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + Gemm, \ + kOnnxDomain, \ + 11, 12, \ T, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 691af48711a56..cfacc1aa6a363 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -25,8 +25,9 @@ class Split : public JsKernel, public SplitBase { if (num_outputs_ < 0) { num_outputs_ = split_sizes.size(); } - } else if (split_sizes_.size() == 0) { - // Compute split_sizes from input shape and num_outputs + } else if (split_sizes_.size() == 0 && info.GetInputCount() < 2) { + // Compute split_sizes from input shape and num_outputs. + // TODO: Shape might not be known at this point, better to handle this in javascript auto total_split_size = info.node().InputDefs()[0]->Shape()->dim(gsl::narrow_cast(axis_)).dim_value(); int64_t split_size_sum = 0; if (num_outputs_ < 0) { @@ -44,6 +45,7 @@ class Split : public JsKernel, public SplitBase { ORT_ENFORCE(split_size_sum == total_split_size, "Sum of split sizes (", split_size_sum, ") does not match input size (", total_split_size, ")"); } + // else: let javascript handle all other cases, ie. split_sizes come as input[1] JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc index 5b5ff0f2873fd..7797e0a47caaf 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/base_op_builder.cc @@ -112,6 +112,9 @@ bool BaseOpBuilder::HasSupportedInputOutputs(const InitializedTensorSet& initial }; for (const auto& input : node_unit.Inputs()) { + if (!input.node_arg.Exists()) { + continue; + } if (!has_supported_shape(input.node_arg, node_unit.Name(), node_unit.OpType())) return false; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc index 618779f6d2166..8d0347673ba56 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/reduction_op_builder.cc @@ -51,10 +51,11 @@ void ReductionOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, cons Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const NodeUnit& node_unit) const { const auto& op_type(node_unit.OpType()); const auto& inputs = node_unit.Inputs(); + const auto& input = node_unit.Inputs()[0].node_arg.Name(); const auto& output = node_unit.Outputs()[0].node_arg.Name(); auto& shaper(model_builder.GetShaper()); - const auto input_shape = shaper[inputs[0].node_arg.Name()]; + const auto input_shape = shaper[input]; const auto& operand_indices(model_builder.GetOperandIndices()); const auto& operand_types(model_builder.GetOperandTypes()); @@ -99,10 +100,10 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co } // Add ReduceMean operation - InlinedVector input_indices; - input_indices.push_back(operand_indices.at(inputs[0].node_arg.Name())); // data - if (!axes.empty()) { + InlinedVector input_indices; + input_indices.push_back(operand_indices.at(input)); // data + const auto axes_name = model_builder.GetUniqueName(node_unit.Name() + inputs[0].node_arg.Name() + "_axes"); Shape axes_dimen = {static_cast(axes.size())}; const OperandType axes_operand_type(Type::TENSOR_INT32, axes_dimen); @@ -110,17 +111,17 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co input_indices.push_back(operand_indices.at(axes_name)); // axes - int32_t input_size = static_cast(input_shape.size()); + int32_t input_rank = static_cast(input_shape.size()); // Make output dimensions InlinedVector output_dimen; if (keepdims) { - output_dimen.reserve(input_size); + output_dimen.reserve(input_rank); } else { - output_dimen.reserve(input_size - axes.size()); + output_dimen.reserve(input_rank - axes.size()); } - for (int32_t i = 0; i < input_size; i++) { + for (int32_t i = 0; i < input_rank; i++) { if (std::find(axes.begin(), axes.end(), i) == axes.end()) { output_dimen.push_back(input_shape[i]); } else { @@ -143,10 +144,14 @@ Status ReductionOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, co ORT_RETURN_IF_ERROR(model_builder.AddOperation(op_code, input_indices, {output}, {output_operand_type})); } else { - // If `axes` is still empty at this point, meaning that it's ReduceMean-18 and attribute `noop_with_empty_axes` specifies as 1, - // treat as an Identity op here. - const OperandType output_operand_type(operand_types.at(inputs[0].node_arg.Name()).type, input_shape); - model_builder.RegisterOperand(output, operand_indices.at(inputs[0].node_arg.Name()), output_operand_type); + // Note: If `axes` is still empty at this point, meaning it's ReduceMean-18 and attribute `noop_with_empty_axes` + // specifies as 1. We treat this case as an Identity op in NNAPI EP. + // However, we hit an issue while adding no-ops operation in NNAPI because it doesn't allow adding an operand both as + // an input and output. + // Currently, we return not supported in NNAPI EP when `noop_with_empty_axes` is true. + + // const OperandType output_operand_type(operand_types.at(input).type, input_shape); + // model_builder.RegisterOperand(output, operand_indices.at(input), output_operand_type); } return Status::OK(); @@ -169,6 +174,8 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ const auto& inputs = node_unit.Inputs(); const auto& op(node_unit.OpType()); + NodeAttrHelper helper(node_unit); + Shape input_shape; if (!GetShape(inputs[0].node_arg, input_shape)) return false; @@ -180,6 +187,7 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ } if (op == "ReduceMean") { + const bool noop_with_empty_axes = helper.Get("noop_with_empty_axes", 0) != 0; if (inputs.size() > 1 && inputs[1].node_arg.Exists()) { const auto& axes_name = inputs[1].node_arg.Name(); if (!Contains(initializers, axes_name)) { @@ -187,6 +195,15 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return false; } } + // Note: For the case - ReduceMean 18+ with noop_with_empty_axes attribute set as 1, + // currently we hit an issue in NNAPI where it does not allow adding an operand as both an input and output. + // This issue may arise from handling no-ops like Identity and ReduceX with noop_with_empty_axes set. + // TODO: Support the case when a more complete solution is available. + if (node_unit.SinceVersion() >= 18 && noop_with_empty_axes) { + LOGS_DEFAULT(VERBOSE) + << "ReduceMean 18+ with noop_with_empty_axes attribute set as 1 is not supported for now."; + return false; + } } return true; diff --git a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc index 01e348caf16cd..cdaa1c8fac76c 100644 --- a/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/nnapi/nnapi_builtin/builders/impl/resize_op_builder.cc @@ -153,10 +153,10 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers if (!GetShape(node_unit.Inputs()[0].node_arg, input_shape)) return false; - const auto input_size = input_shape.size(); - if (input_size != 4) { + const auto input_rank = input_shape.size(); + if (input_rank != 4) { LOGS_DEFAULT(VERBOSE) << "Resize only support 4d shape, input is " - << input_size << "d shape"; + << input_rank << "d shape"; return false; } @@ -206,6 +206,26 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } } + + // The new feature - antialiasing introduced since opset 18 doesn't have a NNAPI mapping support yet. + // And a few other new attributes are currently not handled by NNAPI EP, can add support in the future if needed. + if (node_unit.SinceVersion() >= 18) { + const auto antialias = helper.Get("antialias", 0); + const auto axes = helper.Get("axes", std::vector{}); + const auto keep_aspect_ratio_policy = helper.Get("keep_aspect_ratio_policy", "stretch"); + if (antialias != 0) { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ antialias feature is not currently supported by NNAPI."; + return false; + } + if (!axes.empty()) { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ axes attribute is not currently supported by NNAPI EP."; + return false; + } + if (keep_aspect_ratio_policy != "stretch") { + LOGS_DEFAULT(VERBOSE) << "Resize 18+ keep_aspect_ratio_policy attribute is not currently supported by NNAPI EP."; + return false; + } + } } { // scales and sizes (if present) must be initializers @@ -216,20 +236,22 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers } // scales - if (inputs.size() == 3 && !Contains(initializers, inputs[2].node_arg.Name())) { + bool using_scales = (inputs.size() > 2 && inputs[2].node_arg.Exists()); + if (using_scales && !Contains(initializers, inputs[2].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input scales of Resize must be known"; return false; } // sizes - if (inputs.size() > 3 && !Contains(initializers, inputs[3].node_arg.Name())) { + bool using_sizes = inputs.size() > 3 && inputs[3].node_arg.Exists(); + if (using_sizes && !Contains(initializers, inputs[3].node_arg.Name())) { LOGS_DEFAULT(VERBOSE) << "Input sizes of Resize must be known"; return false; } bool input_is_nchw = false; // haven't a good solution to check layout when scale is 1.0F // We want to check if the scales or sizes are not trying to resize on N/C channels here - if (inputs.size() == 3) { // we are using scales + if (using_scales) { // we are using scales const auto& scales_tensor = *initializers.at(inputs[2].node_arg.Name()); Initializer const unpacked_tensor(scales_tensor); auto scales_data = unpacked_tensor.DataAsSpan(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 3f0cfdac8a2a0..88a576f3ffa73 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1034,10 +1034,6 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv int8_calibration_cache_available_ = !int8_calibration_cache_name_.empty(); } - if (cuda_graph_enable_) { - GetPerThreadContext().InitCUDAGraph(); - } - /* * Parse explicit min/max/opt profile shapes from provider options. * @@ -1142,50 +1138,35 @@ bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); -} - -Status TensorrtExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); -} - -void TensorrtExecutionProvider::PerThreadContext::InitCUDAGraph() { - cuda_graph_ = std::make_unique(); -} - -void TensorrtExecutionProvider::PerThreadContext::SetGraphStream(cudaStream_t stream) { - cuda_graph_->SetStream(stream); -} - -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::PerThreadContext::CaptureBegin() { - cuda_graph_->Reset(); - cuda_graph_->CaptureBegin(); +void TensorrtExecutionProvider::CaptureBegin() { + cuda_graph_.Reset(); + cuda_graph_.CaptureBegin(); } -void TensorrtExecutionProvider::PerThreadContext::CaptureEnd() { - cuda_graph_->CaptureEnd(); +void TensorrtExecutionProvider::CaptureEnd() { + cuda_graph_.CaptureEnd(); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured() const { return is_graph_captured_; } -Status TensorrtExecutionProvider::PerThreadContext::ReplayGraph() { +Status TensorrtExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); // Please note that CUDAGraph::Replay() is not thread safe. - // The cuda graph object is maintained by a per thread basis, + // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_->Replay(); + return cuda_graph_.Replay(); } -void TensorrtExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - // The cuda graph object is maintained by a per thread basis, +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } @@ -1216,18 +1197,6 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } - - // The reason of !IsGraphCaptureEnabled(): - // If cuda graph is enabled, the per thread context will not be released - // because the per thread cuda graph needs to be maintained and replayed for - // the next run. - // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to release. - if (!IsGraphCaptureEnabled() && - PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { - ReleasePerThreadContext(); - } return Status::OK(); } @@ -2381,6 +2350,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &builders_[context->node_name], + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2430,6 +2392,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(state); + + // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine, + // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading. + // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + std::lock_guard lock(*(trt_state->tensorrt_mu_ptr)); const std::unordered_map& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; @@ -2437,6 +2404,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorinput_shape_ranges; auto trt_builder = trt_state->builder->get(); auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); @@ -2471,260 +2439,238 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector lock(*(trt_state->tensorrt_mu_ptr)); - - // Load serialized engine - if (trt_state->engine_cache_enable && trt_engine == nullptr) { - std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); - std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); - if (engine_file && profile_file) { - // Deserialize profile - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - - // Prepare buffer - engine_file.seekg(0, std::ios::end); - size_t engine_size = engine_file.tellg(); - engine_file.seekg(0, std::ios::beg); - std::unique_ptr engine_buf{new char[engine_size]}; - engine_file.read((char*)engine_buf.get(), engine_size); - - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr( - trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { - shape_ranges = DeserializeProfileV2(profile_file); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Decrypt engine - size_t engine_size = 0; - if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not get engine buffer size"); - } - std::unique_ptr engine_buf{new char[engine_size]}; - if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine decryption function decrypt"); - } - // Deserialize engine - // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - trt_state->engine->reset(); - *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); - } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; - trt_engine = trt_state->engine->get(); - context_update = true; - } - } + // Load serialized engine + if (trt_state->engine_cache_enable && trt_engine == nullptr) { + std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in); + std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in); + if (engine_file && profile_file) { + // Deserialize profile + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; - // Check and update shape ranges for dynamic shape inputs. - for (int i = 0, end = num_inputs; i < end; ++i) { - auto input = trt_state->network->get()->getInput(i); - const std::string& input_name = input->getName(); - input_names.insert(input_name); + // Prepare buffer + engine_file.seekg(0, std::ios::end); + size_t engine_size = engine_file.tellg(); + engine_file.seekg(0, std::ios::beg); + std::unique_ptr engine_buf{new char[engine_size]}; + engine_file.read((char*)engine_buf.get(), engine_size); - // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. - // TRT EP will help determine the min/max/opt profile values based on current input tensor value. - if (shape_ranges.find(input_name) != shape_ranges.end()) { - auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); - if (status != Status::OK()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); - } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading + trt_state->engine->reset(); + *(trt_state->engine) = std::unique_ptr( + trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } - } - - // Regenerate engine - if (engine_update) { - // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. - if (GetPerThreadContext().IsTensorRTContextInMap(fused_node_name)) { - GetPerThreadContext().ResetTensorRTContext(fused_node_name); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } else if (trt_state->engine_decryption_enable && !engine_file && profile_file) { + shape_ranges = DeserializeProfileV2(profile_file); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path; + // Decrypt engine + size_t engine_size = 0; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), nullptr, &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not get engine buffer size"); } - + std::unique_ptr engine_buf{new char[engine_size]}; + if (!trt_state->engine_decryption(engine_cache_path.c_str(), &engine_buf[0], &engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine decryption function decrypt"); + } + // Deserialize engine + // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc + // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading trt_state->engine->reset(); - auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); - trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); - for (auto trt_profile : trt_profiles) { - trt_config->addOptimizationProfile(trt_profile); + *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not deserialize engine from encrypted cache: " + engine_cache_path); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; + trt_engine = trt_state->engine->get(); + context_update = true; + } + } - // Set INT8 Per Tensor Dynamic range - if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { - trt_config->setInt8Calibrator(nullptr); - if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); - } - } + // Check and update shape ranges for dynamic shape inputs. + for (int i = 0, end = num_inputs; i < end; ++i) { + auto input = trt_state->network->get()->getInput(i); + const std::string& input_name = input->getName(); + input_names.insert(input_name); - // Set precision - if (trt_state->fp16_enable && trt_state->int8_enable) { - trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); - } else if (trt_state->fp16_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); - } else if (trt_state->int8_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved. + // TRT EP will help determine the min/max/opt profile values based on current input tensor value. + if (shape_ranges.find(input_name) != shape_ranges.end()) { + auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update); + if (status != Status::OK()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles."); } + } + } - // Set DLA (DLA can only run with FP16 or INT8) - if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; - trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); - trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); - trt_config->setDLACore(trt_state->dla_core); - } + // Regenerate engine + if (engine_update) { + // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior. + trt_state->context->reset(); + trt_state->engine->reset(); + auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); + trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); + for (auto trt_profile : trt_profiles) { + trt_config->addOptimizationProfile(trt_profile); + } - // enable sparse weights - if (trt_state->sparsity_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + // Set INT8 Per Tensor Dynamic range + if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) { + trt_config->setInt8Calibrator(nullptr); + if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range."); } + } - // enable builder heuristics - if (trt_state->build_heuristics_enable) { - trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; - } + // Set precision + if (trt_state->fp16_enable && trt_state->int8_enable) { + trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8)); + } else if (trt_state->fp16_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kFP16); + } else if (trt_state->int8_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kINT8); + } + + // Set DLA (DLA can only run with FP16 or INT8) + if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core; + trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK); + trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA); + trt_config->setDLACore(trt_state->dla_core); + } + + // enable sparse weights + if (trt_state->sparsity_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed"; + } + + // enable builder heuristics + if (trt_state->build_heuristics_enable) { + trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled"; + } #if NV_TENSORRT_MINOR > 5 && NV_TENSORRT_MAJOR >= 8 - // switch optimizaion level - if (trt_state->builder_optimization_level != 3) { - trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; - } + // switch optimizaion level + if (trt_state->builder_optimization_level != 3) { + trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_; + } - // limit auxiliary streams - if (trt_state->auxiliary_streams >= 0) { - trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; - } + // limit auxiliary streams + if (trt_state->auxiliary_streams >= 0) { + trt_config->setMaxAuxStreams(trt_state->auxiliary_streams); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams; + } #else - if (trt_state->builder_optimization_level != 3) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; - } - if (trt_state->auxiliary_streams >= 0) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; - } + if (trt_state->builder_optimization_level != 3) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!"; + } + if (trt_state->auxiliary_streams >= 0) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!"; + } #endif - // limit used tactic sources - if (trt_state->filter_tactic_sources) { - nvinfer1::TacticSources tactics = trt_config->getTacticSources(); - tactics |= trt_state->tactic_sources; - trt_config->setTacticSources(tactics); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + // limit used tactic sources + if (trt_state->filter_tactic_sources) { + nvinfer1::TacticSources tactics = trt_config->getTacticSources(); + tactics |= trt_state->tactic_sources; + trt_config->setTacticSources(tactics); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics; + } + + // Load timing cache from file. Create a fresh cache if the file doesn't exist + std::unique_ptr timing_cache = nullptr; + if (trt_state->timing_cache_enable) { + std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); + timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); + if (timing_cache == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not create timing cache: " + timing_cache_path); } - - // Load timing cache from file. Create a fresh cache if the file doesn't exist - std::unique_ptr timing_cache = nullptr; - if (trt_state->timing_cache_enable) { - std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path); - timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size())); - if (timing_cache == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not create timing cache: " + timing_cache_path); - } - trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; - } + trt_config->setTimingCache(*timing_cache, force_timing_cache_match_); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path; } + } - // Build engine - { - auto lock = GetApiLock(); - std::chrono::steady_clock::time_point engine_build_start; - if (detailed_build_log_) { - engine_build_start = std::chrono::steady_clock::now(); - } - *(trt_state->engine) = std::unique_ptr( - trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); - if (detailed_build_log_) { - auto engine_build_stop = std::chrono::steady_clock::now(); - LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; - } + // Build engine + { + auto lock = GetApiLock(); + std::chrono::steady_clock::time_point engine_build_start; + if (detailed_build_log_) { + engine_build_start = std::chrono::steady_clock::now(); } - if (*(trt_state->engine) == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + *(trt_state->engine) = std::unique_ptr( + trt_builder->buildEngineWithConfig(*trt_state->network->get(), *trt_config)); + if (detailed_build_log_) { + auto engine_build_stop = std::chrono::steady_clock::now(); + LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } - trt_engine = trt_state->engine->get(); - if (trt_state->engine_cache_enable) { - // Serialize engine profile - SerializeProfileV2(profile_cache_path, shape_ranges); - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; + } + if (!(*(trt_state->engine))) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); + } + trt_engine = trt_state->engine->get(); + if (trt_state->engine_cache_enable) { + // Serialize engine profile + SerializeProfileV2(profile_cache_path, shape_ranges); + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path; - // Serialize engine - std::unique_ptr serializedModel(trt_engine->serialize()); - size_t engine_size = serializedModel->size(); - if (trt_state->engine_decryption_enable) { - // Encrypt engine - if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not call engine encryption function encrypt"); - } - } else { - std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); - file.write(reinterpret_cast(serializedModel->data()), engine_size); + // Serialize engine + std::unique_ptr serializedModel(trt_engine->serialize()); + size_t engine_size = serializedModel->size(); + if (trt_state->engine_decryption_enable) { + // Encrypt engine + if (!trt_state->engine_encryption(engine_cache_path.c_str(), reinterpret_cast(serializedModel->data()), engine_size)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not call engine encryption function encrypt"); } - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } else { + std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out); + file.write(reinterpret_cast(serializedModel->data()), engine_size); } + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path; + } - // serialize and save timing cache - if (trt_state->timing_cache_enable) { - auto timing_cache = trt_config->getTimingCache(); - std::unique_ptr timingCacheHostData{timing_cache->serialize()}; - if (timingCacheHostData == nullptr) { - return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, - "TensorRT EP could not serialize timing cache: " + timing_cache_path); - } - saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); - if (detailed_build_log_) { - LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; - } + // serialize and save timing cache + if (trt_state->timing_cache_enable) { + auto timing_cache = trt_config->getTimingCache(); + std::unique_ptr timingCacheHostData{timing_cache->serialize()}; + if (timingCacheHostData == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, + "TensorRT EP could not serialize timing cache: " + timing_cache_path); + } + saveTimingCacheFile(timing_cache_path, timingCacheHostData.get()); + if (detailed_build_log_) { + LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path; } - context_update = true; } + context_update = true; } - // Build execution context if either of the following conditions is true: - // (1) The engine is built or updated by this thread. - // (2) The first inference run for this thread where there is no IExecutionContext object yet. - // (3) The engine is updated by another thread. (We compare the profile shapes maintained by the PerThreadContext to the profile shapes maintained by TRT EP) - // - // Note: Creating an execution context from an engine is thread safe per TRT doc - // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading - if (context_update || - !GetPerThreadContext().IsTensorRTContextInMap(fused_node_name) || - GetPerThreadContext().CompareProfileShapes(fused_node_name, shape_ranges)) { - std::unique_ptr new_context; + if (context_update) { if (trt_state->context_memory_sharing_enable) { - new_context.reset(trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); } else { - new_context.reset(trt_state->engine->get()->createExecutionContext()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); } - auto context_status = GetPerThreadContext().UpdateTensorRTContext(fused_node_name, std::move(new_context)); - if (!context_status) { + if (!(*(trt_state->context))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } - GetPerThreadContext().UpdateProfileShapes(fused_node_name, shape_ranges); + trt_context = trt_state->context->get(); } - // Get the reference to the IExecutionContext object that is maintained on a per thread basis. - nvinfer1::IExecutionContext& trt_context = GetPerThreadContext().GetTensorRTContext(fused_node_name); - // Get input and output binding names int total_bindings = trt_engine->getNbBindings(); std::vector buffers(total_bindings); @@ -2760,12 +2706,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorisShapeBinding(binding_index)) { - trt_context.setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); } else { for (int j = 0, end = nb_dims; j < end; ++j) { dimensions.d[j] = static_cast(tensor_shapes[j]); } - const bool status = trt_context.setBindingDimensions(binding_index, dimensions); + const bool status = trt_context->setBindingDimensions(binding_index, dimensions); if (!status) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP cannot set the dynamic dimensions of a binding")); @@ -2904,7 +2850,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - nvinfer1::Dims dimensions = trt_context.getBindingDimensions(static_cast(binding_index)); + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); int nb_dims = dimensions.nbDims; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { @@ -3038,20 +2984,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } - trt_context.setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().SetGraphStream(stream); - GetPerThreadContext().CaptureBegin(); + cuda_graph_.SetStream(stream); + CaptureBegin(); } // Run TRT inference - if (!trt_context.enqueueV2(&buffers[0], stream, nullptr)) { + if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } @@ -3082,14 +3028,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector* parser = nullptr; std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; @@ -246,6 +247,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. std::unordered_map> parsers_; std::unordered_map> engines_; + std::unordered_map> contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -256,6 +258,21 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + cudnnHandle_t external_cudnn_handle_ = nullptr; + cublasHandle_t external_cublas_handle_ = nullptr; + + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + // [Note] We don't use PerThreadContext for now since it has issue with multithreading + // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. // For example, TensorRT execution context and CUDA graph are the ones to be put here. class PerThreadContext final { @@ -306,15 +323,15 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // Cuda graph with multi threads will be supported in the future, so cuda_graph_ is put under PerThreadContext. - // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph pointer is enough (no need to maintain one CUDAGraph pointer per TRT subgraph) - std::unique_ptr cuda_graph_; + // ORT TRT only supports CUDA graph when whole model is supported by TRT, so simply maintaining a CUDAGraph instance is enough (no need to maintain one CUDAGraph instance per TRT subgraph) + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; - int regular_run_count_before_graph_capture_ = -1; + int regular_run_count_before_graph_capture_ = 0; // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: // (1) memory pattern is enabled. (2) arena allocation for stream. // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs // to allocate enough memory in Arena before graph capturing. - const int min_num_runs_before_cuda_graph_capture_ = 0; // required min regular runs before graph capture for the necessary memory allocations. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 4df11c2224e27..b5f45b15a5992 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -138,6 +138,7 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri model_proto = model.ToProto(); } else { model_proto = model.ToGraphProtoWithExternalInitializers(filename_dat, + ToPathString(filename), initializer_size_threshold); } auto& metadata = model.MetaData(); diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 15fe5acfe0fd2..70d2d0fe5d511 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2711,9 +2711,8 @@ static constexpr OrtApi ort_api_1_to_16 = { &OrtApis::GetTensorRTProviderOptionsByName, &OrtApis::UpdateCUDAProviderOptionsWithValue, &OrtApis::GetCUDAProviderOptionsByName, - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::KernelContext_GetResource, + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2742,10 +2741,10 @@ static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change"); static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); -static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: -static_assert(std::string_view(ORT_VERSION) == "1.16.0", +static_assert(std::string_view(ORT_VERSION) == "1.16.2", "ORT_Version change detected, please follow below steps to ensure OrtApi is updated properly"); // 1. Update the hardcoded version string in above static_assert to silence it // 2. If there were any APIs added to ort_api_1_to_16 above: diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index b73fcbbff5456..c86469918cca5 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -446,14 +446,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options = check_and_normalize_provider_args( providers, provider_options, available_providers ) - if not providers and len(available_providers) > 1: - self.disable_fallback() - raise ValueError( - f"This ORT build has {available_providers} enabled. " - "Since ORT 1.9, you are required to explicitly set " - "the providers parameter when instantiating InferenceSession. For example, " - f"onnxruntime.InferenceSession(..., providers={available_providers}, ...)" - ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 5ac20739c486e..82d119894a5d8 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -53,6 +53,7 @@ namespace onnxruntime { #endif // _MSC_VER #include +#include #if defined(_MSC_VER) #pragma warning(disable : 4267 4996 4503 4003) @@ -85,7 +86,7 @@ struct AsyncResource { std::vector feed_names; std::vector feed_names_raw; - std::vector fetches_raw; + std::vector fetches_raw; // will be released during destruction std::vector fetch_names; std::vector fetch_names_raw; @@ -106,6 +107,15 @@ struct AsyncResource { fetch_names.reserve(sz); fetch_names_raw.reserve(sz); } + + ~AsyncResource() { + std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) { + if (fetch) { + std::unique_ptr fetch_recycler(fetch); + } + }); + fetches_raw.clear(); + } }; void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) { diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index bdf00f21100bf..26e74a6dfbac9 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -22,7 +22,7 @@ class TensorData: - _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"]) + _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]] self.data[k] = TensorData(lowest=v[0], highest=v[1]) continue if len(v) == 4: - self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3]) + self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3]) continue raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.") if not isinstance(v, TensorData): diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 924d4c72b6390..bb968d660c30c 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -112,8 +112,8 @@ def __init__( False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] ) - self.activation_qType = activation_qType.tensor_type - self.weight_qType = weight_qType.tensor_type + self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) + self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType) """ Dictionary specifying the min and max values for tensors. It has following format: { diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index d23459b478e6a..23f9eaf4b0e0b 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -157,7 +157,7 @@ def quantize(self): nodes, ) = self.quantizer.quantize_activation(node, [0]) quant_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 7e91f9b76ca36..90a52cb528b32 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -47,10 +47,10 @@ def quantize(self): R.dims[0] = R_num_dir * R_4_hidden_size quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[2], onnx_proto.TensorProto.INT8, 0 + node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806 diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index f87a9d8228bac..a03f6431fbd08 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -266,7 +266,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, + # Quantization type is forced to be TensorProto.INT8. + # when the expected value would be (see below) + # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. + # QLinearConv expects to have a unique value for all channels. + # This code does not enforce that but it is necessarily the case when the + # quantization is symmetric (as for INT8). + onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight, ) diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index eff3dc0bcdc35..b7d4726610387 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -99,7 +99,10 @@ def quant_pre_process( sess_option = onnxruntime.SessionOptions() sess_option.optimized_model_filepath = opt_model_path sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + # Close the session to avoid the cleanup error on Windows for temp folders + # https://github.com/microsoft/onnxruntime/issues/17627 + del sess except Exception: logger.error( "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'." diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 4b771c5bee3b1..1ec5edf686c63 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -38,17 +38,17 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): # In some models there is input_ids->gather->add->LayerNorm and one of input of the # add node is initializer with fixed shape which should not be fused into SkipLayerNorm - if add is None: + if add is None or add.op_type != "Add": + return + + # The number of inputs of add should be 2 + if len(add.input) != 2: return for add_input in add.input: if self.model.get_initializer(add_input) is not None: return - # The number of input node of add should be 2 - if len(self.model.get_parents(add)) != 2: - return - # To avoid an Add node have two children of LayerNormalization, we shall only fuse one SkipLayerNormalization if add in self.nodes_to_remove: return @@ -57,6 +57,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): simplified = node.op_type == "SimplifiedLayerNormalization" if self.shape_infer_helper is not None: + # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size) if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same", @@ -73,15 +74,14 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None: return - residual_add_has_multiple_consumers = False - add_children = self.model.get_children(add, input_name_to_nodes) - # This means that the residual Add before the LayerNormalization produces an output - # that is consumed by some other nodes other than the LayerNormalization itself + # that is consumed by some other nodes or graph output other than the LayerNormalization itself # We can still go ahead with the SkipLayerNormalization fusion but we need to # preserve the output of Add and that needs to be produced by SkipLayerNormalization. - if len(add_children) != 1: - residual_add_has_multiple_consumers = True + add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None + residual_add_has_multiple_consumers = ( + add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1 + ) outputs_to_keep = node.output @@ -94,11 +94,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if residual_add_has_multiple_consumers: outputs.extend(["", "", add.output[0]]) - if ( - add is not None - and add.op_type == "Add" - and self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node) - ): + if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend([add, node]) inputs = ( @@ -136,32 +132,33 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return return_indice = [] - nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], None, return_indice) - if nodes is None: + nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice) + if nodes is not None: + (add, _matmul) = nodes + else: # In case of fp16, we could have a Cast between the MatMul and the bias Add + return_indice = [] nodes = self.model.match_parent_path( - node, ["Add", "Cast", "MatMul"], [None, None, None], None, return_indice + node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice ) - if nodes is None: + if nodes is not None: + (add, _cast, _matmul) = nodes + else: return assert len(return_indice) == 2 or len(return_indice) == 3 add_input_index = return_indice[0] if add_input_index >= 2: return - - (add, matmul) = nodes + sln_input = add.input[return_indice[1]] + bias_input = add.input[1 - return_indice[1]] + skip_input = node.input[1 - add_input_index] # bias should be one dimension - bias_index = -1 - bias_weight = None - for i, input in enumerate(add.input): - initializer = self.model.get_initializer(input) - if initializer is None: - continue - bias_index = i - bias_weight = NumpyHelper.to_array(initializer) - break + initializer = self.model.get_initializer(bias_input) + if initializer is None: + return + bias_weight = NumpyHelper.to_array(initializer) if bias_weight is None: logger.debug("Bias weight not found") return @@ -176,11 +173,11 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend(subgraph_nodes) inputs = [ - node.input[1 - add_input_index], - matmul.output[0], + sln_input, + skip_input, node.input[2], node.input[3], - add.input[bias_index], + bias_input, ] new_node = helper.make_node( "SkipLayerNormalization", diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 0715395268ee8..71c1a21d8f768 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,5 +1,6 @@ import logging -from typing import Dict, List +from collections import OrderedDict +from typing import Any, Dict, List import numpy import torch @@ -205,3 +206,112 @@ def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shape else: ort_outputs.append(copy_tensor) return ort_outputs + + +class CudaSession: + """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" + + def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False): + self.ort_session = ort_session + self.input_names = [input.name for input in self.ort_session.get_inputs()] + self.output_names = [output.name for output in self.ort_session.get_outputs()] + self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) + self.io_binding = self.ort_session.io_binding() + self.enable_cuda_graph = enable_cuda_graph + + self.input_tensors = OrderedDict() + self.output_tensors = OrderedDict() + self.device = device + + def __del__(self): + del self.input_tensors + del self.output_tensors + del self.io_binding + del self.ort_session + + def allocate_buffers(self, shape_dict: Dict[str, tuple]): + """Allocate tensors for I/O Binding""" + if self.enable_cuda_graph: + for name, shape in shape_dict.items(): + if name in self.input_names: + # Reuse allocated buffer when the shape is same + if name in self.input_tensors: + if tuple(self.input_tensors[name].shape) == tuple(shape): + continue + raise RuntimeError("Expect static input shape for cuda graph") + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.input_tensors[name] = tensor + + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + for name, shape in shape_dict.items(): + if name in self.output_names: + # Reuse allocated buffer when the shape is same + if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): + continue + + numpy_dtype = self.io_name_to_numpy_type[name] + tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( + device=self.device + ) + self.output_tensors[name] = tensor + + self.io_binding.bind_output( + name, + tensor.device.type, + tensor.device.index, + numpy_dtype, + list(tensor.size()), + tensor.data_ptr(), + ) + + def infer(self, feed_dict: Dict[str, torch.Tensor]): + """Bind input tensors and run inference""" + for name, tensor in feed_dict.items(): + assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() + if name in self.input_names: + if self.enable_cuda_graph: + assert self.input_tensors[name].nelement() == tensor.nelement() + assert tensor.device.type == "cuda" + # Please install cuda-python package with a version corresponding to CUDA in your machine. + from cuda import cudart + + # Update input tensor inplace since cuda graph requires input and output has fixed memory address. + cudart.cudaMemcpy( + self.input_tensors[name].data_ptr(), + tensor.data_ptr(), + tensor.element_size() * tensor.nelement(), + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + ) + else: + self.io_binding.bind_input( + name, + tensor.device.type, + tensor.device.index, + TypeHelper.torch_type_to_numpy_type(tensor.dtype), + [1] if len(tensor.shape) == 0 else list(tensor.shape), + tensor.data_ptr(), + ) + + self.ort_session.run_with_iobinding(self.io_binding) + + return self.output_tensors + + @staticmethod + def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: + return { + "device_id": device_id, + "arena_extend_strategy": "kSameAsRequested", + "enable_cuda_graph": enable_cuda_graph, + } diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 0e66a22e59b9a..b652e0723f5aa 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -2112,9 +2112,11 @@ static void RunModelWithRandomInput( constexpr int hidden_size = 768; constexpr int num_heads = 12; + const float min_value = is_float16 ? -0.001f : -1.0f; + const float max_value = is_float16 ? 0.001f : 1.0f; std::vector batch_input_dims{1, sequence_length, hidden_size}; - std::vector batch_input_data = random.Uniform(batch_input_dims, -1.0f, 1.0f); + std::vector batch_input_data = random.Uniform(batch_input_dims, min_value, max_value); std::vector input_dims{batch_size, sequence_length, hidden_size}; std::vector input_data; @@ -2123,12 +2125,12 @@ static void RunModelWithRandomInput( } std::vector weight_dims{hidden_size, 3 * hidden_size}; - std::vector weight_data = random.Uniform(weight_dims, -1.0f, 1.0f); + std::vector weight_data = random.Uniform(weight_dims, min_value, max_value); std::vector bias_dims{3 * hidden_size}; - std::vector bias_data = random.Uniform(bias_dims, -1.0f, 1.0f); + std::vector bias_data = random.Uniform(bias_dims, min_value, max_value); - float gpu_threshold = is_float16 ? static_cast(sequence_length) / 32.0f : 0.005f; + float gpu_threshold = is_float16 ? 0.5f : 0.005f; constexpr float cpu_threshold = 0.002f; bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()); @@ -2146,7 +2148,10 @@ static void RunModelWithRandomInput( test.AddInput("weight", weight_dims, weight_data); test.AddInput("bias", bias_dims, bias_data); } - test.AddInput("mask_index", mask_index_dims, mask_index_data); + if (mask_index_data.size() > 0) { + test.AddInput("mask_index", mask_index_dims, mask_index_data); + } + test.AddReferenceOutputs(onnx_model, gpu_threshold); std::vector> execution_providers; if (enable_cuda) { @@ -2216,6 +2221,25 @@ TEST(AttentionTest, Attention_Mask1D_Fp32_B2_S64) { false); } +// This case can be used to test flash attention using Ampere GPU +TEST(AttentionTest, Attention_NoMask_Fp16) { + constexpr int batch_size = 2; + std::vector sequence_lengths{1, 7, 8}; + for (const auto& sequence_length : sequence_lengths) { + std::vector mask_index_dims{}; + std::vector mask_index_data{}; + std::string onnx_model = "testdata/attention_no_mask_fp16.onnx"; + + RunModelWithRandomInput( + batch_size, + sequence_length, + mask_index_dims, + mask_index_data, + onnx_model, + true); + } +} + // This test is disabled since it is flaky. TEST(AttentionTest, DISABLED_Attention_Mask1D_Fp16_B2_FusedNoPadding) { constexpr int batch_size = 2; diff --git a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc index c2230501b081d..49b338d8329ee 100644 --- a/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/multihead_attention_op_test.cc @@ -300,6 +300,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_Default) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -315,6 +316,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -330,6 +332,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}, @@ -342,10 +345,11 @@ static void RunMultiHeadAttentionKernel( return; } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -362,6 +366,7 @@ static void RunMultiHeadAttentionKernel( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -388,9 +393,9 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); } -#if USE_FLASH_ATTENTION - if (data.sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 || - data.kv_sequence_length >= contrib::attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) { +#if USE_MEMORY_EFFICIENT_ATTENTION + if (data.sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32 || + data.kv_sequence_length >= contrib::attention::kMinSeqLenForMemoryEfficientAttentionFp32) { kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( @@ -434,7 +439,7 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data, bool disable_cpu data.v_hidden_size, kernel_type, use_float16, data.is_static_kv, disable_cpu, disable_cuda); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention; if (!SkipAttentionKernel(data, kernel_type)) { RunMultiHeadAttentionKernel( diff --git a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc index dd9224df8f380..09baf8def05f6 100644 --- a/onnxruntime/test/contrib_ops/packed_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_attention_op_test.cc @@ -433,7 +433,8 @@ static void RunModelWithRandomInput( std::vector token_offset_dims{batch_size, sequence_length}; std::vector cum_seq_len_dims{batch_size + 1}; - float gpu_threshold = is_float16 ? 0.1f : 0.005f; + float gpu_threshold = is_float16 ? 0.15f : 0.005f; + gpu_threshold *= sequence_length > 1024 ? 4.0f : 1.0f; // threshold should increase with sequence length bool enable_cuda = HasCudaEnvironment(is_float16 ? 530 : 0); if (enable_cuda) { OpTester test("PackedAttention", 1, onnxruntime::kMSDomain); diff --git a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc index fc2b58680c84f..22253955566f2 100644 --- a/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/packed_multihead_attention_op_test.cc @@ -160,6 +160,7 @@ static void RunPackedMultiHeadAttentionTest( if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "0"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -168,10 +169,11 @@ static void RunPackedMultiHeadAttentionTest( InvokePackedMultiHeadAttentionTest(true, false); } -#if USE_FLASH_ATTENTION +#if USE_MEMORY_EFFICIENT_ATTENTION if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -182,9 +184,20 @@ static void RunPackedMultiHeadAttentionTest( } #endif +#if USE_FLASH_ATTENTION + if (kernel_type == AttentionKernelType::AttentionKernel_FlashAttention) { + ScopedEnvironmentVariables scoped_env_vars{ + EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "0"}, + {onnxruntime::contrib::attention::kMinSeqLenForFlashAttentionPackedQKV, "0"}}}; + InvokePackedMultiHeadAttentionTest(true, true); + } +#endif + if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) { ScopedEnvironmentVariables scoped_env_vars{ EnvVarMap{ + {onnxruntime::contrib::attention::kDisableFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedSelfAttention, "1"}, {onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}, @@ -389,6 +402,32 @@ TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_cutlass) { AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention); } +#if USE_FLASH_ATTENTION +TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_FlashAttention) { + if (HasCudaEnvironment(800)) { + PackedAttentionTestData data; + GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); + std::vector empty_data = {}; + + RunPackedMultiHeadAttentionTest( + data.qkv_data, + empty_data, + empty_data, + empty_data, + data.token_offset, + data.cumulative_sequence_length, + data.fp16_output_data, + data.batch_size, + data.sequence_length, + data.hidden_size, + data.v_hidden_size, + data.num_heads, + data.token_count, + AttentionKernelType::AttentionKernel_FlashAttention); + } +} +#endif + TEST(PackedMultiHeadAttentionTest, PackedQKV_Padding_NoBias_unfused) { PackedAttentionTestData data; GetPackedMultiHeadAttentionData_Batch2_HeadSize32_NoRelPosBias(data); diff --git a/onnxruntime/test/framework/float_8_test.cc b/onnxruntime/test/framework/float_8_test.cc new file mode 100644 index 0000000000000..520d44f33809b --- /dev/null +++ b/onnxruntime/test/framework/float_8_test.cc @@ -0,0 +1,94 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include + +#include "core/framework/float8.h" +#include "test/framework/test_utils.h" +#include "test/util/include/test/capturing_sink.h" +#include "test/util/include/test/test_environment.h" + +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +TEST(Float8_Tests, CastE4M3FN) { + std::vector> cases{ + std::pair(0.00439453125, 0.00390625), + std::pair(0.005859375, 0.005859375), + std::pair(0.005759375, 0.005859375), + std::pair(0.0046875, 0.00390625), + std::pair(0.001953125, 0.001953125), + std::pair(0.0029296875, 0.00390625), + std::pair(0.002053125, 0.001953125), + std::pair(0.00234375, 0.001953125), + std::pair(0.0087890625, 0.0078125), + std::pair(0.001171875, 0.001953125), + std::pair(1.8131605, 1.875)}; + for (auto it : cases) { + auto f8 = onnxruntime::Float8E4M3FN(it.first); + auto f8_32 = f8.ToFloat(); + EXPECT_EQ(it.second, f8_32); + } +} + +union float_bits { + uint32_t bits; + float val; +}; + +TEST(Float8_Tests, NanE4M3FN) { + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val).val, static_cast(0x7E)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val).val, static_cast(0xFE)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800000}).val, false).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800000}).val, false).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FN((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE4M3FNUZ) { + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E4M3FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +TEST(Float8_Tests, NanE5M2) { + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val).val, static_cast(0x7B)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val).val, static_cast(0xFB)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800000}).val, false).val, static_cast(0x7C)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800000}).val, false).val, static_cast(0xFC)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7F800001}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFF800001}).val).val, static_cast(0xFF)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0x7FC00000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2((float_bits{0xFFC00000}).val).val, static_cast(0xFF)); +} + +TEST(Float8_Tests, NanE5M2FNUZ) { + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val).val, static_cast(0x7F)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val).val, static_cast(0xFF)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800000}).val, false).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7F800001}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFF800001}).val).val, static_cast(0x80)); + // 0x7FC00000 is the value used by numpy. + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0x7FC00000}).val).val, static_cast(0x80)); + EXPECT_EQ(onnxruntime::Float8E5M2FNUZ((float_bits{0xFFC00000}).val).val, static_cast(0x80)); +} + +} // namespace test +} // namespace onnxruntime + +#endif // DISABLE_FLOAT8_TYPES diff --git a/onnxruntime/test/framework/ort_model_only_test.cc b/onnxruntime/test/framework/ort_model_only_test.cc index f8da4e895913a..e2cb82e47f32b 100644 --- a/onnxruntime/test/framework/ort_model_only_test.cc +++ b/onnxruntime/test/framework/ort_model_only_test.cc @@ -4,6 +4,7 @@ #include "core/flatbuffers/schema/ort.fbs.h" #include "core/framework/data_types.h" #include "core/framework/tensorprotoutils.h" +#include "core/framework/TensorSeq.h" #include "core/graph/model.h" #include "core/graph/onnx_protobuf.h" #include "core/session/onnxruntime_cxx_api.h" @@ -556,6 +557,41 @@ TEST(OrtModelOnlyTests, LoadOrtFormatModelFromBufferNoCopyInitializersUseBuffer) RunOrtModel(test_info); } +// regression test for 2 issues covered by PR #17000 (internally reported issue). +// 1) allocation planner broke in minimal build when subgraph had no nodes. +// 2) usage of a sequence data type caused an exception due to IsSparseTensor() throwing +// instead of allowing the calling code to have #ifdef'd code to handle when IsSparseTensor +// returned true and sparse tensors were disabled. +TEST(OrtModelOnlyTests, GithubIssue17000) { + // need to run the model to + auto model_uri = ORT_TSTR("testdata/ort_github_issue_17000.ort"); + + auto allocator = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; + + OrtValue item0, item1; + CreateMLValue(allocator, {1}, {1.f}, &item0); + CreateMLValue(allocator, {2}, {2.f, 3.f}, &item1); + + auto elem_type = DataTypeImpl::GetType(); + auto tensor_seq = std::make_unique(elem_type); + tensor_seq->SetElements({item0, item1}); + + auto mltype = DataTypeImpl::GetType(); + OrtValue value(tensor_seq.release(), mltype, mltype->GetDeleteFunc()); + + OrtModelTestInfo test_info; + test_info.model_filename = model_uri; + test_info.inputs.insert(std::make_pair("seq_in", value)); + test_info.output_names = {"still_has_elements"}; + test_info.output_verifier = [](const std::vector& fetches) { + const auto& output = fetches[0].Get(); + ASSERT_EQ(output.Shape().Size(), 1); + ASSERT_EQ(output.Data()[0], true); // removed one item from seq so should still have elements + }; + + RunOrtModel(test_info); +} + #if !defined(DISABLE_ML_OPS) // test that we can deserialize and run a previously saved ORT format model // for a model with sequence and map outputs diff --git a/onnxruntime/test/optimizer/compute_optimizer_test.cc b/onnxruntime/test/optimizer/compute_optimizer_test.cc index 01016774288e4..a03d0da2538d4 100644 --- a/onnxruntime/test/optimizer/compute_optimizer_test.cc +++ b/onnxruntime/test/optimizer/compute_optimizer_test.cc @@ -638,7 +638,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_ScalarSlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -737,7 +738,8 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -826,6 +828,345 @@ TEST(ComputeOptimizerTests, GatherMatMul_SlicingOnSecondLastDim) { } } +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 32, 256] (float) + | + LayerNormalization[axis=-1 (as example)] + | + [2, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. +*/ +TEST(ComputeOptimizerTests, GatherLayerNormalization) { + std::vector> test_config_pairs{ + // { + // is_scalar_slice, + // ln_axis_before_propagation, + // expected_ln_axis_after_propagation, + // expected to propagate + // } + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, -3, -3, false}, + {true, -2, -2, false}, + {true, -1, 1, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, -3, -3, false}, + {false, -2, -2, false}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t ln_axis_before = std::get<1>(p); + int64_t ln_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, ln_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["LayerNormalization"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + values, require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == ln_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, ln_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 32, 256}}); + auto* input2_arg = builder.MakeInput({{256}}); + auto* input3_arg = builder.MakeInput({{256}}); + auto* ln_out = builder.MakeIntermediate(); + builder.AddNode("LayerNormalization", {input1_arg, input2_arg, input3_arg}, {ln_out}) + .AddAttribute("axis", ln_axis_before); + + std::vector slice_inputs; + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {ln_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + +/* +Test graph includes multiple equivalent subgraphs as below. + graph input [2, 4, 32, 256] (float) + | + Softmax[axis=3 (as example)] + | + [2, 4, 32, 256] + | + | 0 (scalar) + | / + Gather[axis=1] + | + Identity + | + graph output [2, 32, 256] (float) + +Add an Identity node because currently, we don't allow Gather generates graph output. +*/ +TEST(ComputeOptimizerTests, GatherSoftmax) { + std::vector> test_config_pairs{ + // {is_scalar_slice, softmax_axis_before_propagation, + // expected_softmax_axis_after_propagation, expected to propagate} + {true, 0, 0, false}, + {true, 1, 1, false}, + {true, 2, 1, true}, + {true, 3, 2, true}, + {true, -4, -4, false}, + {true, -3, -3, false}, + {true, -2, 1, true}, + {true, -1, 2, true}, + {false, 0, 0, false}, + {false, 1, 1, false}, + {false, 2, 2, true}, + {false, 3, 3, true}, + {false, -4, -4, false}, + {false, -3, -3, false}, + {false, -2, -2, true}, + {false, -1, -1, true}, + }; + + constexpr static int64_t gather_axis = 1; + constexpr static int64_t slice_data_value = 0; + + for (auto p : test_config_pairs) { + bool is_scalar_slice = std::get<0>(p); + int64_t softmax_axis_before = std::get<1>(p); + int64_t softmax_axis_after = std::get<2>(p); + bool expected_to_propagate = std::get<3>(p); + + const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); + + InlinedVector indices; + auto pre_graph_checker = [&indices](Graph& graph) -> Status { + auto op_count_pre = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_count_pre.size() == 3U); + TEST_RETURN_IF_NOT(op_count_pre["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_pre["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Gather") { + TEST_RETURN_IF_NOT(indices.empty()); + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(node.InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, + indices, require_constant)); + } + } + return Status::OK(); + }; + + auto post_graph_checker = [is_scalar_slice, softmax_axis_after, + &indices, expected_to_propagate](Graph& graph) { + auto op_count_post = CountOpsInGraph(graph); + + TEST_RETURN_IF_NOT(op_count_post.size() == 3U); + TEST_RETURN_IF_NOT(op_count_post["Softmax"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Gather"] == 1); + TEST_RETURN_IF_NOT(op_count_post["Identity"] == 1); + + for (Node& node : graph.Nodes()) { + if (node.OpType() == "Softmax") { + const auto& input_defs = node.InputDefs(); + + auto producer_node = graph.GetProducerNode(input_defs[0]->Name()); + if (expected_to_propagate) { + TEST_RETURN_IF_NOT(producer_node != nullptr); + TEST_RETURN_IF_NOT(producer_node->OpType() == "Gather"); + + InlinedVector values; + constexpr bool require_constant = true; + NodeArg* initializer_node_arg = graph.GetNodeArg(producer_node->InputDefs()[1]->Name()); + TEST_RETURN_IF_NOT(optimizer_utils::AppendTensorFromInitializer(graph, *initializer_node_arg, values, + require_constant)); + for (size_t i = 0; i < values.size(); i++) { + TEST_RETURN_IF_NOT(values[i] == indices[i]); + } + + const ONNX_NAMESPACE::TensorShapeProto* slice_out_shape = producer_node->OutputDefs()[0]->Shape(); + TEST_RETURN_IF_NOT(slice_out_shape != nullptr); + + auto& attrs = node.GetAttributes(); + TEST_RETURN_IF_NOT(attrs.find("axis") != attrs.end()); + + auto& axis_attr = attrs.at("axis"); + auto axis_value = (int)axis_attr.i(); + TEST_RETURN_IF_NOT(axis_value == softmax_axis_after); + + if (is_scalar_slice) { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 3); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 256); + } else { + TEST_RETURN_IF_NOT(slice_out_shape->dim_size() == 4); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(0)) && + slice_out_shape->dim(0).dim_value() == 2); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(1)) && + slice_out_shape->dim(1).dim_value() == 1); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(2)) && + slice_out_shape->dim(2).dim_value() == 32); + TEST_RETURN_IF_NOT(utils::HasDimValue(slice_out_shape->dim(3)) && + slice_out_shape->dim(3).dim_value() == 256); + } + + } else { + TEST_RETURN_IF_NOT(producer_node == nullptr); + } + } + } + + return Status::OK(); + }; + + auto build_test_case = [is_scalar_slice, softmax_axis_before](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput({{2, 4, 32, 256}}); + auto* softmax_out = builder.MakeIntermediate(); + builder.AddNode("Softmax", {input1_arg}, {softmax_out}) + .AddAttribute("axis", softmax_axis_before); + + std::vector slice_inputs; + + NodeArg* indices_initializer = nullptr; + + if (is_scalar_slice) { + indices_initializer = builder.MakeScalarInitializer(slice_data_value); + } else { + indices_initializer = builder.MakeInitializer({1}, {slice_data_value}); + } + + slice_inputs = {softmax_out, indices_initializer}; + + auto* gather_out = builder.MakeIntermediate(); + builder.AddNode("Gather", slice_inputs, + {gather_out}) + .AddAttribute("axis", gather_axis); + + auto* identity_out = builder.MakeOutput(); + builder.AddNode("Identity", {gather_out}, {identity_out}); + }; + + std::unique_ptr transformer = std::make_unique(); + ASSERT_STATUS_OK(TestGraphTransformer(build_test_case, 14, *logger, std::move(transformer), + TransformerLevel::Level1, + 1, pre_graph_checker, post_graph_checker)); + } +} + TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { const logging::Logger* logger = &logging::LoggingManager::DefaultLogger(); auto model_uri = MODEL_FOLDER "computation_reduction/gather/gather_reshape_scalar_batch_dim.onnx"; @@ -835,7 +1176,8 @@ TEST(ComputeOptimizerTests, GatherReshape_ScalarSlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); @@ -928,7 +1270,8 @@ TEST(ComputeOptimizerTests, GatherReshape_SlicingOnBatchDim) { std::map op_to_count = CountOpsInGraph(graph); onnxruntime::GraphTransformerManager graph_transformation_mgr{1}; - ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), + TransformerLevel::Level1)); ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger)); GraphViewer graph_viewer(graph); diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index 6fd1f6081cf05..85ccb8f175f62 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -202,19 +202,19 @@ struct TensorCheck { // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { - ASSERT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(cur_expected[i])) { // Test infinity for equality - ASSERT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; + EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { if (!has_abs_err && !has_rel_err) { // the default for existing tests - ASSERT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; } else { if (has_abs_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; } if (has_rel_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) << "i:" << i; } } @@ -256,20 +256,20 @@ void InternalNumericalCheck(const Tensor& expected, // NOTE: Check isnan first to work around MSVC linker bug when /LTCG:incremental is specified. // If the isinf check is first the isnan check and branch gets omitted if (std::isnan(cur_expected[i])) { - ASSERT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(cur_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(cur_expected[i])) { // Test infinity for equality - ASSERT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; + EXPECT_EQ(cur_expected[i], cur_actual[i]) << "Expected infinity. i:" << i; } else { if (!has_abs_err && !has_rel_err) { // the default for existing tests - ASSERT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; + EXPECT_NEAR(cur_expected[i], cur_actual[i], threshold) << "i:" << i; } else { if (has_abs_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.absolute_error)) << "i:" << i; } if (has_rel_err) { - ASSERT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) + EXPECT_NEAR(cur_expected[i], cur_actual[i], *(params.relative_error) * std::abs(cur_expected[i])) << "i:" << i; } } diff --git a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc index 8126990df57cc..ee18cf2cea6cb 100644 --- a/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/batch_norm_op_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/framework/tensor.h" +#include "core/providers/cpu/nn/batch_norm.h" // for BATCHNORM_INCLUDE_TRAINING_SUPPORT #include "core/session/inference_session.h" #include "test/common/dnnl_op_test_utils.h" #include "test/providers/provider_test_utils.h" @@ -846,7 +847,7 @@ TEST(BatchNormTest, BatchNorm2d_bfloat16) { #endif // USE_DNNL // TODO fix flaky test for CUDA -#ifdef ENABLE_TRAINING_OPS +#ifdef BATCHNORM_INCLUDE_TRAINING_SUPPORT TEST(BatchNormTest, ForwardTrainingTestWithSavedOutputsOpset9) { // TODO: Unskip when fixed #41968513 if (DefaultDmlExecutionProvider().get() != nullptr) { @@ -936,7 +937,7 @@ TEST(BatchNormTest, ForwardTrainingTestOpset15) { {kCudaExecutionProvider, kRocmExecutionProvider, kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kDnnlExecutionProvider}); } -#endif // ENABLE_TRAINING_OPS +#endif // BATCHNORM_INCLUDE_TRAINING_SUPPORT } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/python/onnxruntime_test_float8.py b/onnxruntime/test/python/onnxruntime_test_float8.py index 3f3180230f853..76ca5d9538374 100644 --- a/onnxruntime/test/python/onnxruntime_test_float8.py +++ b/onnxruntime/test/python/onnxruntime_test_float8.py @@ -8,9 +8,11 @@ import unittest import numpy as np +import packaging.version as pv import parameterized from numpy.testing import assert_allclose from onnx import TensorProto +from onnx import __version__ as onnx_version from onnx.checker import check_model from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor, make_tensor_value_info from onnx.reference import ReferenceEvaluator @@ -37,7 +39,7 @@ class TestInferenceSession(unittest.TestCase): `_. """ - dtypes = frozenset({"FLOAT": np.float32, "FLOAT16": np.float16}) + dtypes = {"FLOAT": np.float32, "FLOAT16": np.float16} # noqa: RUF012 x = np.array( [0.4068359375, 352, 416, 336, 304, 272, -248, -100, 1e-4, 1e-2, 416, 432, 1e5, np.inf, -np.inf, np.nan], dtype=np.float32, @@ -76,7 +78,7 @@ class TestInferenceSession(unittest.TestCase): 240.0, 240.0, -240.0, - -104.0, + -96.0, 0.0, 0.009765625, 240.0, @@ -113,7 +115,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -121,7 +123,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, 57344.0, 57344.0, @@ -167,7 +169,7 @@ class TestInferenceSession(unittest.TestCase): np.nan, np.nan, np.nan, - -104.0, + -96.0, 0.0, 0.009765625, np.nan, @@ -204,7 +206,7 @@ class TestInferenceSession(unittest.TestCase): [ 0.4375, 384.0, - 448.0, + 384.0, 320.0, 320.0, 256.0, @@ -212,7 +214,7 @@ class TestInferenceSession(unittest.TestCase): -96.0, 0.0001068115234375, 0.009765625, - 448.0, + 384.0, 448.0, np.nan, np.nan, @@ -245,6 +247,7 @@ def model_cast_cast_f16_float(self, to, saturate, rev=False): check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -429,6 +432,7 @@ def model_qdq(self, to, float_name, saturate, castq=False, castdq=False, like=Fa check_model(onnx_model) return onnx_model + @unittest.skipIf(pv.Version(onnx_version) < pv.Version("1.15.0"), reason="needs onnx>=1.15.0") @parameterized.parameterized.expand( [ ("FLOAT8E4M3FN", "FLOAT", 1), @@ -689,6 +693,18 @@ def test_model_qdq_cuda_ortvalue(self, name: str, float_name: str, saturate: int self.assertEqual(expect.shape, y.shape) self.assertEqual(expect.dtype, y.dtype) + @unittest.skipIf("CUDAExecutionProvider" not in available_providers, reason="Not running on CUDA.") + def test_compare_cpu_cuda_e4m3fn(self): + folder = os.path.join(os.path.dirname(__file__), "..", "testdata", "float8") + model = os.path.join(folder, "te.cast_fp8_1_fp32.onnx") + data = np.load(os.path.join(folder, "te.cast_fp8_1_fp32_input.npy")) + + sess_cpu = onnxruntime.InferenceSession(model, providers=["CPUExecutionProvider"]) + sess_cuda = onnxruntime.InferenceSession(model, providers=["CUDAExecutionProvider"]) + cpu_res = sess_cpu.run(None, {"input": data})[0] + cuda_res = sess_cuda.run(None, {"input": data})[0] + self.assertEqual(cuda_res.tolist(), cpu_res.tolist()) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e554d418667a1..86577eaf2d7df 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -80,11 +80,7 @@ def test_model_serialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession( - get_name("mul_1.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) os.remove(so.optimized_model_filepath) except Fail as onnxruntime_error: @@ -179,6 +175,62 @@ def test_model_serialization_with_original_external_initializers_to_directory(se else: raise onnxruntime_error + def test_model_serialization_with_original_external_initializers_to_current_directory(self): + optimized_model_filepath = "model_opt_with_ext_data_1.onnx" + external_initializers_file = "model_opt_with_ext_data_1.bin" + optimized_model_filepath_2 = "model_opt_with_ext_data_2.onnx" + external_initializers_file_2 = "model_opt_with_ext_data_2.bin" + + so = onnxrt.SessionOptions() + so.log_severity_level = 1 + so.logid = "TestModelSerializationWithOriginalExternalInitializersToCurrentDirectory" + so.optimized_model_filepath = optimized_model_filepath + + so.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", external_initializers_file + ) + + # TODO(anyone): Set this to 100 will cause test error since some tensor below the threshold + # still refers to the original external data file. We shall fix this issue so that the + # optimized model only refers to one external data file. + so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") + session1 = onnxrt.InferenceSession( + get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] + ) + del session1 + self.assertTrue(os.path.isfile(optimized_model_filepath)) + self.assertTrue(os.path.isfile(external_initializers_file)) + + so2 = onnxrt.SessionOptions() + so2.log_severity_level = 1 + so2.logid = "TestModelSerializationWithExternalInitializersInCurrentDirectory" + so2.optimized_model_filepath = optimized_model_filepath_2 + so2.add_session_config_entry( + "session.optimized_model_external_initializers_file_name", external_initializers_file_2 + ) + so2.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") + + # verify that we can load the optimized model with external data in current directory and save + # optimized model with external data to current directory. + session2 = onnxrt.InferenceSession( + optimized_model_filepath, sess_options=so2, providers=["CPUExecutionProvider"] + ) + del session2 + self.assertTrue(os.path.isfile(optimized_model_filepath_2)) + self.assertTrue(os.path.isfile(external_initializers_file_2)) + + # Remove model 1 to make sure optimized model 2 can be loaded independently from model 1 + os.remove(optimized_model_filepath) + os.remove(external_initializers_file) + + session3 = onnxrt.InferenceSession( + optimized_model_filepath_2, sess_options=onnxrt.SessionOptions(), providers=["CPUExecutionProvider"] + ) + del session3 + + os.remove(optimized_model_filepath_2) + os.remove(external_initializers_file_2) + def test_get_providers(self): self.assertTrue("CPUExecutionProvider" in onnxrt.get_available_providers()) # get_all_providers() returns the default EP order from highest to lowest. diff --git a/onnxruntime/test/python/quantization/resnet_code.py b/onnxruntime/test/python/quantization/resnet_code.py new file mode 100644 index 0000000000000..74e3652673c99 --- /dev/null +++ b/onnxruntime/test/python/quantization/resnet_code.py @@ -0,0 +1,13763 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import numpy +from onnx import numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, set_model_props + + +def create_model(): + initializers = [] + nodes = [] + inputs = [] + outputs = [] + functions = [] + + # opsets + opsets = {"": 13} + + # initializers + + list_value = [ + -0.013732648454606533, + -0.005861935671418905, + 0.06889285147190094, + -0.1172710582613945, + 0.08841240406036377, + -0.03748627379536629, + 0.016256270930171013, + -0.1059316024184227, + 0.08246039599180222, + 0.14295539259910583, + -0.32958757877349854, + 0.1631188541650772, + 0.05412565544247627, + -0.10758306831121445, + 0.12607362866401672, + -0.4987836182117462, + 0.7441706657409668, + -0.24774713814258575, + -0.30415549874305725, + 0.4033295810222626, + -0.13447114825248718, + 0.04623159021139145, + 0.2380414456129074, + -1.226112723350525, + 2.150630235671997, + -1.702580213546753, + 0.5305419564247131, + -0.06836353242397308, + -0.20055373013019562, + 0.7035881280899048, + -0.8389442563056946, + -0.1904432326555252, + 1.2609282732009888, + -1.0670661926269531, + 0.4142579436302185, + 0.04739700257778168, + -0.3265092074871063, + 1.1873037815093994, + -1.6817731857299805, + 0.9709527492523193, + -0.09095840901136398, + -0.12556785345077515, + 0.0835147574543953, + -0.24109329283237457, + 0.032948240637779236, + 0.46304041147232056, + -0.6594106554985046, + 0.349990576505661, + -0.04113377630710602, + 0.016451245173811913, + 0.008994563482701778, + -0.028321878984570503, + -0.05336569994688034, + 0.16036668419837952, + -0.12088149785995483, + 0.031160499900579453, + -0.0618649423122406, + 0.07205374538898468, + 0.15965768694877625, + -0.3389044404029846, + 0.21603335440158844, + 0.04029613360762596, + -0.0813034325838089, + 0.1019665077328682, + -0.4873599112033844, + 0.7873126268386841, + -0.2951086163520813, + -0.43754327297210693, + 0.5905176401138306, + -0.21821773052215576, + 0.06022067740559578, + 0.26326146721839905, + -1.6453089714050293, + 2.606400728225708, + -1.8939754962921143, + 0.5196341276168823, + 0.0055860355496406555, + -0.2335057258605957, + 0.9807199239730835, + -1.2137882709503174, + -0.2699125409126282, + 1.7379733324050903, + -1.4401814937591553, + 0.435971736907959, + -0.04829222336411476, + -0.24543480575084686, + 1.3292583227157593, + -2.0375823974609375, + 1.2458536624908447, + -0.08251484483480453, + -0.14181238412857056, + 0.10612589120864868, + -0.21671657264232635, + 0.1129523366689682, + 0.3666985034942627, + -0.7546612024307251, + 0.42979565262794495, + -0.0976259633898735, + -0.0008812264422886074, + 0.02994859404861927, + -0.07027778774499893, + 0.01393035613000393, + 0.07363647222518921, + -0.10249849408864975, + 0.06602989137172699, + -0.012129798531532288, + 0.10730132460594177, + -0.04546127840876579, + -0.16065146028995514, + 0.14788293838500977, + -0.05488971993327141, + 0.03601694852113724, + 0.07513345777988434, + -0.23953600227832794, + 0.48062530159950256, + -0.42057543992996216, + -0.02402813360095024, + 0.17920851707458496, + -0.10703158378601074, + -0.028666120022535324, + 0.2815375030040741, + -0.860264241695404, + 1.4422725439071655, + -1.2058128118515015, + 0.5272247791290283, + -0.06504356116056442, + -0.20021803677082062, + 0.44968947768211365, + -0.3856053650379181, + -0.1589551419019699, + 0.7579770684242249, + -0.8349987268447876, + 0.3225692808628082, + 0.08153475821018219, + -0.43163740634918213, + 0.8742384910583496, + -0.9722443222999573, + 0.579015851020813, + -0.06688100844621658, + -0.12384293973445892, + 0.08289378881454468, + -0.10082041472196579, + -0.11204896867275238, + 0.3934254050254822, + -0.4511864185333252, + 0.32745760679244995, + -0.06534548103809357, + -0.028830429539084435, + 0.021844232454895973, + 0.01775779016315937, + -0.004250001162290573, + 0.013087524101138115, + -0.001250433037057519, + -0.040545206516981125, + -0.014049320481717587, + -0.024194253608584404, + -0.023865194991230965, + -0.0038033330347388983, + 0.00920871365815401, + -0.006582418456673622, + 0.0032474950421601534, + -0.0369916632771492, + -0.16640843451023102, + -0.28968843817710876, + -0.3531132638454437, + -0.26307201385498047, + -0.13392697274684906, + -0.03747623786330223, + 0.08083077520132065, + 0.2026241272687912, + 0.25018608570098877, + 0.2529378831386566, + 0.2307336926460266, + 0.13928599655628204, + 0.08631229400634766, + 0.13893137872219086, + 0.4867081344127655, + 0.7170669436454773, + 0.8331555724143982, + 0.6734364032745361, + 0.3549460768699646, + 0.16798041760921478, + -0.14487245678901672, + -0.47733625769615173, + -0.7670150995254517, + -0.875726580619812, + -0.6291986703872681, + -0.2910463213920593, + -0.09991979598999023, + -0.009158087894320488, + 0.018850643187761307, + 0.02646111696958542, + -0.009077857248485088, + 0.029430989176034927, + -0.03707962855696678, + -0.05111744999885559, + -0.02076525054872036, + 0.011828843504190445, + 0.017857171595096588, + 0.02548048458993435, + -0.009077494964003563, + 0.0022066361270844936, + -0.02064262516796589, + -0.008582246489822865, + -0.022748643532395363, + -0.03038850985467434, + 0.0006585497176274657, + -0.0016039719339460135, + -0.01612498238682747, + 0.013966801576316357, + -0.05851661041378975, + -0.21422894299030304, + -0.33863192796707153, + -0.3720807433128357, + -0.3030800521373749, + -0.1737397164106369, + -0.05903157964348793, + 0.15018144249916077, + 0.27454254031181335, + 0.31182464957237244, + 0.30118387937545776, + 0.24605700373649597, + 0.14123573899269104, + 0.14992672204971313, + 0.20660799741744995, + 0.5046274662017822, + 0.7706091403961182, + 0.8978630900382996, + 0.7368614673614502, + 0.3929724097251892, + 0.23079657554626465, + -0.21169082820415497, + -0.5920398235321045, + -0.893406867980957, + -0.9499238729476929, + -0.730407178401947, + -0.3615736961364746, + -0.15422092378139496, + -0.024615347385406494, + 0.005115498788654804, + 0.024657316505908966, + 0.028517475351691246, + 0.027910854667425156, + -0.009482389315962791, + -0.042242538183927536, + -0.017875321209430695, + 0.00430292496457696, + 0.015949612483382225, + 0.003636278910562396, + -0.018156034871935844, + -0.0009349065367132425, + -0.0010362856555730104, + -0.013051170855760574, + -0.009141271002590656, + -8.714485738892108e-05, + 0.02399279735982418, + 0.01753612607717514, + -0.013710699044167995, + -0.014245252124965191, + -0.0028008236549794674, + -0.08206935226917267, + -0.1098734438419342, + -0.10250325500965118, + -0.08874496072530746, + -0.031079040840268135, + 0.004536658991128206, + 0.03923843801021576, + 0.08478657901287079, + 0.07715648412704468, + 0.018803801387548447, + 0.013921198435127735, + 0.015864359214901924, + 0.04947463795542717, + 0.039856068789958954, + 0.1712094396352768, + 0.362756609916687, + 0.4192918539047241, + 0.2668488621711731, + 0.11430513113737106, + 0.06648365408182144, + -0.058979276567697525, + -0.24177154898643494, + -0.3709423542022705, + -0.3979431986808777, + -0.29706764221191406, + -0.11569518595933914, + -0.01848490908741951, + -0.015523962676525116, + 0.05081642046570778, + 0.09057094901800156, + 0.08520761132240295, + 0.04497350752353668, + -0.019453801214694977, + -0.06109466031193733, + 0.011463015340268612, + -0.008522219955921173, + -0.005283404141664505, + -0.017313135787844658, + -0.0015744483098387718, + -0.011845857836306095, + -0.016727561131119728, + -0.006708915811032057, + 0.0008860539528541267, + -0.010050912387669086, + -0.028460539877414703, + -0.0165643822401762, + -0.016545938327908516, + -0.00567589420825243, + -0.0032017906196415424, + -0.0130555285140872, + -0.026848897337913513, + -0.02615198865532875, + 0.002669057110324502, + -0.027966763824224472, + -0.03851256147027016, + -0.014509409666061401, + -0.029059220105409622, + -0.007284109480679035, + 0.04045313969254494, + 0.10005538910627365, + 0.014574537053704262, + -0.044292762875556946, + -0.01750861294567585, + -0.02231375314295292, + -0.004432118032127619, + 0.10051869601011276, + 0.1443023532629013, + 0.0508832149207592, + -0.04350621998310089, + -0.0025447055231779814, + -0.014583000913262367, + -0.02153291553258896, + 0.018860718235373497, + 0.03618147224187851, + 0.007304056081920862, + -0.029104959219694138, + 0.00576505484059453, + -0.016025763005018234, + -0.025094063952565193, + -0.05296780914068222, + -0.037012189626693726, + -0.04414081946015358, + -0.053135257214307785, + -0.028890708461403847, + -0.010220452211797237, + -0.027575822547078133, + -0.01087758969515562, + -0.027209162712097168, + -0.030827227979898453, + -0.007646164856851101, + -0.016133273020386696, + 0.000639698002487421, + -0.0034172122832387686, + 0.03914793208241463, + 0.030786357820034027, + 0.005965455900877714, + 0.020923329517245293, + -0.03435938432812691, + -0.0026781477499753237, + 0.04278327897191048, + 0.20045910775661469, + 0.21770593523979187, + 0.09422573447227478, + 0.03198440372943878, + -0.021056609228253365, + 0.028007682412862778, + 0.19196027517318726, + 0.4791645109653473, + 0.5333831906318665, + 0.3014310598373413, + 0.103666290640831, + -0.03651479259133339, + 0.027079502120614052, + 0.19239209592342377, + 0.5168290138244629, + 0.5564895868301392, + 0.2977963089942932, + 0.07770062237977982, + -0.042239490896463394, + -0.017265107482671738, + 0.08760321140289307, + 0.2775075435638428, + 0.312491774559021, + 0.12284757196903229, + 0.019664151594042778, + -0.026643047109246254, + 0.0009152573184110224, + 0.016156431287527084, + 0.09042830765247345, + 0.08991760015487671, + 0.013326293788850307, + 0.02613811008632183, + 0.021025240421295166, + 0.0198842640966177, + 0.03375901281833649, + 0.028616728261113167, + 0.026605166494846344, + 0.04126269370317459, + 0.029309948906302452, + 0.01408455427736044, + -0.003831037785857916, + 0.01922326348721981, + -0.018229445442557335, + -0.013015883974730968, + 0.017597628757357597, + -0.007964612916111946, + 0.045263469219207764, + 0.0184696726500988, + -0.001163159729912877, + -0.1809321641921997, + -0.22486254572868347, + -0.08606110513210297, + 0.001087217591702938, + 0.037091098725795746, + -0.013625397346913815, + -0.178089901804924, + -0.5483279824256897, + -0.612791895866394, + -0.32531827688217163, + -0.06506585329771042, + 0.05076128616929054, + -0.007585812360048294, + -0.20981833338737488, + -0.6155760884284973, + -0.7119701504707336, + -0.354442298412323, + -0.04236743599176407, + 0.045713260769844055, + 0.03192479908466339, + -0.07216271013021469, + -0.310979425907135, + -0.3656359910964966, + -0.13522450625896454, + 0.008291869424283504, + 0.03362602740526199, + -0.0009240762447007, + 0.01604474149644375, + -0.055634208023548126, + -0.06180194392800331, + 0.0222025066614151, + 0.027704820036888123, + -0.034385330975055695, + -0.07050742954015732, + -0.06287489086389542, + 0.03521641716361046, + -0.00020920530369039625, + 0.05458284169435501, + 0.058752644807100296, + -0.08097169548273087, + -0.01668735221028328, + 0.18557283282279968, + 0.26208117604255676, + 0.1253771185874939, + 0.07758381962776184, + -0.022084739059209824, + 0.016727397218346596, + 0.23247942328453064, + 0.35444316267967224, + 0.21802566945552826, + -0.04409221559762955, + -0.08573070168495178, + -0.0994141548871994, + 0.07754423469305038, + 0.14311672747135162, + 0.04036660119891167, + -0.29222917556762695, + -0.38828015327453613, + -0.26185816526412964, + -0.12845511734485626, + 0.04763585329055786, + -0.017382778227329254, + -0.16010743379592896, + -0.2395028918981552, + -0.2049665004014969, + -0.041346337646245956, + 0.091490738093853, + -0.005191737785935402, + -0.07687077671289444, + -0.08105621486902237, + -0.05329642817378044, + -0.03404862806200981, + 0.11478845030069351, + 0.13328343629837036, + -0.037197597324848175, + -0.01787363924086094, + -0.016605347394943237, + 0.007853846065700054, + 0.029950136318802834, + 0.10808859020471573, + 0.02873288467526436, + -0.1766187697649002, + -0.17560969293117523, + -0.03922238200902939, + 0.14447443187236786, + 0.1534212827682495, + 0.11272227019071579, + 0.008810695260763168, + -0.1485181748867035, + 0.07839693129062653, + 0.43013128638267517, + 0.4898712635040283, + 0.26522761583328247, + 0.10202436149120331, + -0.07163076847791672, + 0.09933187812566757, + 0.47377726435661316, + 0.6340300440788269, + 0.36741772294044495, + -0.04812543839216232, + -0.17370514571666718, + -0.17513291537761688, + 0.22105705738067627, + 0.3226463794708252, + 0.09850790351629257, + -0.4044247269630432, + -0.6237908601760864, + -0.4679968059062958, + -0.1954391747713089, + 0.09878316521644592, + -0.004430827684700489, + -0.31550562381744385, + -0.5235733985900879, + -0.4510284662246704, + -0.13843706250190735, + 0.10064390301704407, + -0.006748788990080357, + -0.12714813649654388, + -0.2107744812965393, + -0.18755048513412476, + -0.05646044388413429, + 0.12781813740730286, + 0.18928050994873047, + -0.04337320104241371, + -0.04973407834768295, + -0.04690375551581383, + 0.0245530866086483, + 0.10698680579662323, + 0.1646823137998581, + 0.081840381026268, + -0.01471243891865015, + -0.03138890117406845, + -0.04195617139339447, + 0.012708203867077827, + 0.033312954008579254, + 0.02409377694129944, + -0.0036440726835280657, + -0.06239784508943558, + 0.0037516560405492783, + 0.11261500418186188, + 0.13069754838943481, + 0.05901307612657547, + 0.048614490777254105, + -0.027712708339095116, + 0.027247682213783264, + 0.19195327162742615, + 0.2688453793525696, + 0.1509387195110321, + 0.020540937781333923, + -0.004100556951016188, + -0.012650247663259506, + 0.039176344871520996, + 0.09037251025438309, + -0.004689970053732395, + -0.23859903216362, + -0.2364242821931839, + -0.15189304947853088, + -0.0761493444442749, + -0.0028172829188406467, + -0.04328106716275215, + -0.16187387704849243, + -0.21743592619895935, + -0.1282283067703247, + -0.024501819163560867, + 0.04029383510351181, + -0.027387680485844612, + -0.05414740741252899, + -0.08344019204378128, + -0.06591048091650009, + 0.012637111358344555, + 0.06905930489301682, + 0.08426016569137573, + -0.0030199100729078054, + 0.034059297293424606, + 0.01111840270459652, + 0.013492933474481106, + 0.0674189031124115, + 0.08242739737033844, + 0.006129032466560602, + -0.07763395458459854, + -0.03002289868891239, + -0.055725954473018646, + 0.008795201778411865, + 0.02994825504720211, + -0.06114519387483597, + -0.0560108907520771, + -0.008179228752851486, + -0.07149285078048706, + -0.02700420655310154, + -0.01306728646159172, + 0.06276566535234451, + 0.007125973701477051, + -0.03540417551994324, + -0.039717916399240494, + 0.009147526696324348, + -0.06517947465181351, + 0.0720859095454216, + -0.05035398155450821, + 0.06659520417451859, + -0.01841895841062069, + 0.004233633633702993, + -0.020911216735839844, + -0.004646372981369495, + 1.6690073013305664, + 0.4517613649368286, + -0.07667035609483719, + 0.005556757096201181, + -0.02638973295688629, + 0.044588603079319, + -0.020916732028126717, + 0.2571280598640442, + -0.009559552185237408, + -0.043380800634622574, + 0.03196016326546669, + -0.03783237189054489, + -0.03076902963221073, + 0.03180111199617386, + 0.06352709978818893, + 0.020281998440623283, + -0.00741154421120882, + -0.0009214285528287292, + -0.0476187989115715, + -0.07208544760942459, + -0.05323023349046707, + -0.011103631928563118, + 0.02877136506140232, + -0.05324484035372734, + -0.10076326876878738, + 0.026193000376224518, + 0.03536469116806984, + 0.045722659677267075, + -0.03756006807088852, + 0.022998394444584846, + 0.0019359687576070428, + 0.01654801517724991, + 0.047304198145866394, + -0.08431598544120789, + -0.0645647644996643, + -0.17326746881008148, + -0.10692577064037323, + -0.08416426181793213, + -0.04107839986681938, + -0.0012680464424192905, + -0.02600814774632454, + -0.014215772971510887, + 0.2114446610212326, + -0.040954578667879105, + -0.05050172284245491, + 0.004194092936813831, + -0.0025900816544890404, + -0.1359374076128006, + 0.03946976363658905, + 2.3023669719696045, + 0.7484877109527588, + -0.1994970589876175, + -0.06490366160869598, + 0.007983183488249779, + -0.017937449738383293, + -0.12516839802265167, + 0.3313288688659668, + 0.11946671456098557, + -0.16942338645458221, + -0.007721045054495335, + 0.02824605070054531, + -0.05310647189617157, + -0.1122083067893982, + -0.17094524204730988, + -0.08465421944856644, + -0.09679102897644043, + -0.03848385065793991, + 0.040121182799339294, + -0.06661732494831085, + 0.0005764663219451904, + -0.05729356408119202, + -0.04778655245900154, + -0.034835152328014374, + -0.07634143531322479, + -0.05054831504821777, + 0.00597620103508234, + 0.04499154910445213, + -0.03308190405368805, + -0.04915233701467514, + -0.05842791870236397, + 0.003590918146073818, + 0.055837079882621765, + -0.02547842636704445, + -0.018847621977329254, + -0.2073899656534195, + -0.14987564086914062, + -0.03971748799085617, + 0.05886378139257431, + 0.020922083407640457, + -0.039155181497335434, + -0.028855402022600174, + 0.08688661456108093, + -0.1402827501296997, + -0.05810496211051941, + 0.037841811776161194, + -0.04082907736301422, + -0.1191127747297287, + -0.10852136462926865, + 1.6274418830871582, + 0.3678200840950012, + -0.2865799367427826, + -0.05291350558400154, + 0.023858532309532166, + -0.046683818101882935, + -0.2307816743850708, + -0.001670230645686388, + -0.17716962099075317, + -0.16724731028079987, + 0.040194038301706314, + -0.023075448349118233, + -0.01538322027772665, + -0.07914327085018158, + -0.19621343910694122, + -0.11628971993923187, + -0.05851752683520317, + 0.06313594430685043, + 0.017808571457862854, + 0.02447943389415741, + 0.048611078411340714, + -0.009247995913028717, + 0.00789090245962143, + 0.06673033535480499, + 0.0661577433347702, + 0.019111329689621925, + 0.038164373487234116, + 0.029342610388994217, + -0.03547409921884537, + -0.11017149686813354, + -0.11077891290187836, + 0.001108204829506576, + -0.0330691784620285, + -0.05039837956428528, + 0.017638904973864555, + 0.277705579996109, + 0.5606598258018494, + 0.5469182133674622, + 0.13591277599334717, + 0.012421006336808205, + 0.046348799020051956, + -0.02721901424229145, + -0.5645118355751038, + -1.072814702987671, + -0.9852984547615051, + -0.3608386516571045, + -0.010197073221206665, + -0.09785731136798859, + -0.02597353421151638, + 0.4627133309841156, + 1.1483618021011353, + 0.9505703449249268, + 0.17471027374267578, + -0.016467586159706116, + 0.026623696088790894, + 0.04765752702951431, + -0.4000166058540344, + -0.8956774473190308, + -0.6268588304519653, + -0.09439487755298615, + 0.02861764468252659, + -0.004155704285949469, + 0.08989865332841873, + 0.27384331822395325, + 0.6518518328666687, + 0.4184596836566925, + 0.13106893002986908, + 0.0050344159826636314, + 0.007061495911329985, + -0.016157688573002815, + -0.1364346295595169, + -0.27324289083480835, + -0.14245718717575073, + -0.04623992741107941, + -0.015541884116828442, + 0.030779436230659485, + 0.03756715729832649, + 0.01957445964217186, + -0.04964561015367508, + -0.0211405660957098, + 0.044496409595012665, + -0.026335055008530617, + -0.11620140820741653, + -0.11803250014781952, + 0.18242181837558746, + 0.5057784914970398, + 0.5045838952064514, + 0.03748183697462082, + 0.05692485347390175, + 0.1608155369758606, + 0.02245517633855343, + -0.7651812434196472, + -1.5504053831100464, + -1.3563542366027832, + -0.4314505457878113, + -0.028384560719132423, + -0.12238024920225143, + 0.106974296271801, + 1.11427903175354, + 2.173083543777466, + 1.747692346572876, + 0.5455064177513123, + 0.03363418206572533, + 0.11388687789440155, + -0.05905687436461449, + -0.8059568405151367, + -1.6196117401123047, + -1.1898213624954224, + -0.2654758095741272, + -0.004251840524375439, + -0.0916782096028328, + -0.024067873135209084, + 0.22692462801933289, + 0.6695711612701416, + 0.3673460781574249, + -0.017016466706991196, + -0.029604146257042885, + 0.020365707576274872, + 0.03215239942073822, + 0.0070981839671730995, + -0.14026938378810883, + -0.02425236999988556, + 0.059152450412511826, + -0.006319367326796055, + 0.003989882301539183, + 0.048541076481342316, + 0.003988460637629032, + -0.03105335496366024, + -0.08329232037067413, + 0.03226872906088829, + 0.02119620516896248, + -0.0953872874379158, + -0.15174035727977753, + 0.07963212579488754, + 0.29094186425209045, + 0.2690921127796173, + -0.020104877650737762, + 0.024988379329442978, + 0.15326620638370514, + 0.1256464123725891, + -0.40941280126571655, + -0.946648120880127, + -0.8358487486839294, + -0.14284957945346832, + -0.07980851829051971, + -0.1435413807630539, + 0.038134895265102386, + 0.8021518588066101, + 1.552701473236084, + 1.2496209144592285, + 0.38152581453323364, + 0.07136060297489166, + 0.14329172670841217, + -0.06546801328659058, + -0.5923707485198975, + -1.253793478012085, + -0.9458200335502625, + -0.156633198261261, + -0.04217473417520523, + -0.11199303716421127, + -0.07520301640033722, + 0.15331010520458221, + 0.4794600307941437, + 0.2449675053358078, + -0.10396319627761841, + 0.0034801275469362736, + 0.04475663974881172, + 0.024035215377807617, + 0.056806568056344986, + -0.07363307476043701, + -0.001563104335218668, + 0.05157755687832832, + 0.043718185275793076, + 0.02102719619870186, + 0.11859089881181717, + 0.08675580471754074, + -0.13180124759674072, + -0.15522590279579163, + 0.03273458778858185, + -0.0019622649997472763, + 0.1011638194322586, + -0.10800585150718689, + -0.6884365677833557, + -0.5495791435241699, + 0.0780424103140831, + 0.33674973249435425, + -0.21274283528327942, + -0.4183696210384369, + -0.8053947687149048, + 0.03347628563642502, + 1.3938312530517578, + 0.9454176425933838, + -0.012210174463689327, + 0.04924672842025757, + 0.16284359991550446, + 1.1340152025222778, + 2.0020322799682617, + 0.2796843647956848, + -0.968036413192749, + -0.5768532752990723, + 0.17757350206375122, + 0.37485063076019287, + 0.11534234136343002, + -1.2916942834854126, + -1.692176103591919, + -0.30523377656936646, + 0.14307916164398193, + 0.03928302228450775, + -0.19196964800357819, + -0.4533900022506714, + -0.3294944167137146, + 0.5480389595031738, + 0.4497548043727875, + 0.2170887440443039, + -0.05817069113254547, + -0.06957870721817017, + 0.03169052675366402, + 0.23751793801784515, + 0.0823391005396843, + -0.04811413958668709, + -0.051265716552734375, + -0.0395645909011364, + -0.03849785774946213, + 0.04607917368412018, + 0.09946659207344055, + -0.029992828145623207, + -0.05369366332888603, + -0.005230880342423916, + 0.012808755040168762, + 0.1821947544813156, + 0.05478882044553757, + -0.47736144065856934, + -0.44480830430984497, + -0.036321353167295456, + 0.13646431267261505, + -0.04045571759343147, + -0.21837295591831207, + -0.6888197660446167, + -0.08431777358055115, + 0.96018385887146, + 0.6788493990898132, + 0.011028020642697811, + 0.05917810648679733, + 0.02488739602267742, + 0.6898419857025146, + 1.4259209632873535, + 0.13193827867507935, + -0.8078985810279846, + -0.31056249141693115, + 0.018122224137187004, + 0.137860506772995, + 0.051947757601737976, + -0.9757952094078064, + -1.1060559749603271, + 0.06675099581480026, + 0.2091575562953949, + -0.029623042792081833, + -0.0705878809094429, + -0.18514159321784973, + -0.07947035878896713, + 0.5719470381736755, + 0.2286168485879898, + -0.03433626517653465, + 0.0036030709743499756, + 0.006251791957765818, + 0.04144154116511345, + 0.08598234504461288, + -0.050599172711372375, + -0.10440917313098907, + -0.02927244082093239, + -0.04102599248290062, + -0.07101748138666153, + -0.03579306975007057, + 0.03586365282535553, + 0.06752362847328186, + 0.048901572823524475, + -0.020898710936307907, + -0.009411930106580257, + 0.10169848799705505, + 0.1812015175819397, + -0.014482695609331131, + -0.12548771500587463, + -0.060731250792741776, + -0.034499138593673706, + 0.0829617902636528, + 0.04616715386509895, + -0.20867496728897095, + -0.1990129053592682, + 0.1773940473794937, + 0.13156233727931976, + -0.03437860682606697, + 0.04012921825051308, + -0.11132699251174927, + -0.023460939526557922, + 0.2713286876678467, + -0.06662362813949585, + -0.2709292471408844, + -0.0030232456047087908, + -0.10379529744386673, + -0.07136038690805435, + 0.03757762163877487, + -0.20515622198581696, + -0.1231834888458252, + 0.26915228366851807, + 0.0998353362083435, + -0.031466737389564514, + 0.04657471179962158, + 0.07664929330348969, + 0.10308870673179626, + 0.23429608345031738, + -0.06942534446716309, + -0.09051290899515152, + 0.03243685141205788, + 0.04053235426545143, + -0.021392958238720894, + -0.05330868810415268, + -0.11525140702724457, + -0.03889385238289833, + 0.01636480540037155, + -0.009352890774607658, + 0.13151532411575317, + -0.14738643169403076, + -0.18289834260940552, + 0.15955400466918945, + -0.001023759599775076, + 0.028809679672122, + 0.012261062860488892, + 0.29654747247695923, + -0.285063236951828, + -0.40187928080558777, + 0.3713407516479492, + 0.009383893571794033, + -0.023022817447781563, + -0.003799814498052001, + 0.48470190167427063, + -0.43402406573295593, + -0.5858806371688843, + 0.5751441717147827, + 0.05045031011104584, + -0.05559438094496727, + -0.02045449987053871, + 0.5281224250793457, + -0.5058223605155945, + -0.5950849056243896, + 0.6492323279380798, + 0.013408469036221504, + -0.05940670147538185, + -0.0044364179484546185, + 0.3112560212612152, + -0.34908774495124817, + -0.42427319288253784, + 0.43349501490592957, + 0.03724945709109306, + -0.05263671651482582, + -0.010485195554792881, + 0.1261255145072937, + -0.1349790245294571, + -0.2524855136871338, + 0.24608080089092255, + 0.036001257598400116, + -0.028843939304351807, + 0.0056989979930222034, + 0.04458172619342804, + -0.06122935935854912, + -0.166972354054451, + 0.14557687938213348, + 0.018050044775009155, + 0.032598987221717834, + -0.0055792503990232944, + 0.24355076253414154, + -0.21433626115322113, + -0.29646870493888855, + 0.1958809792995453, + 0.015435033477842808, + 0.05235098674893379, + 0.010786890983581543, + 0.47903597354888916, + -0.4127257168292999, + -0.6203306317329407, + 0.47024452686309814, + 0.0823090448975563, + -0.04538045823574066, + -0.004072466865181923, + 0.7509317994117737, + -0.6508772969245911, + -0.8481631278991699, + 0.7875698208808899, + 0.0966777428984642, + -0.10461349785327911, + 0.0063789174892008305, + 0.7535857558250427, + -0.8082649111747742, + -0.8165622353553772, + 0.9064085483551025, + 0.04986630380153656, + -0.10200339555740356, + 0.0314355194568634, + 0.46324053406715393, + -0.5523763298988342, + -0.5632953643798828, + 0.6378755569458008, + 0.07833302766084671, + -0.07979781180620193, + 0.031164664775133133, + 0.1967470794916153, + -0.21681970357894897, + -0.29283079504966736, + 0.3367702066898346, + 0.034929461777210236, + -0.047199901193380356, + -0.0033645557705312967, + 0.05454660952091217, + -0.11264829337596893, + -0.190998375415802, + 0.17961400747299194, + 0.0009085010970011353, + -0.0001827089727157727, + 0.04841821268200874, + 0.019923821091651917, + -0.07004066556692123, + -0.10590090602636337, + 0.054114967584609985, + 0.04302384704351425, + 0.00462615629658103, + 0.022948985919356346, + 0.1673787385225296, + -0.1319379210472107, + -0.2711219787597656, + 0.2387620061635971, + 0.05667697265744209, + -0.018639734014868736, + -0.07672597467899323, + 0.3503187298774719, + -0.2981504797935486, + -0.38647517561912537, + 0.4072522521018982, + 0.010913677513599396, + -0.05246961489319801, + -0.04058554396033287, + 0.39216771721839905, + -0.3605193495750427, + -0.34857264161109924, + 0.46899959444999695, + -0.03358001261949539, + -0.05188553035259247, + -0.023204902186989784, + 0.17140533030033112, + -0.2120431810617447, + -0.2144550085067749, + 0.2837989032268524, + -0.0191226527094841, + -0.020922169089317322, + 0.004324179142713547, + 0.038136694580316544, + -0.042803723365068436, + -0.11487454175949097, + 0.11820490658283234, + 0.003412557765841484, + 0.0035020115319639444, + 0.03646541014313698, + -0.010104459710419178, + -0.010897459462285042, + -0.09292570501565933, + 0.06823977828025818, + 0.02677192911505699, + 0.020071662962436676, + 0.005776307079941034, + 0.02613351307809353, + 0.017107944935560226, + -0.0002623539185151458, + -0.039298396557569504, + -0.0314190648496151, + -0.019773684442043304, + -0.01924789510667324, + 0.04253160580992699, + 0.09694722294807434, + 0.1925637573003769, + 0.1901547759771347, + 0.09470294415950775, + -0.00296174269169569, + -0.03602522239089012, + 0.03572473302483559, + 0.08787581324577332, + 0.1773553043603897, + 0.20970025658607483, + 0.14899243414402008, + 0.05427362397313118, + -0.032429151237010956, + 0.023915717378258705, + 0.06557436287403107, + 0.13488733768463135, + 0.17550915479660034, + 0.17485061287879944, + 0.10260436683893204, + -0.005381361581385136, + -0.05573735386133194, + -0.09410752356052399, + -0.07940010726451874, + -0.03424998000264168, + 0.007975265383720398, + 0.028827181085944176, + 0.023788832128047943, + -0.02962818741798401, + -0.13474339246749878, + -0.22529757022857666, + -0.20413516461849213, + -0.14711618423461914, + -0.05960607901215553, + 0.04579121991991997, + 0.005325576290488243, + -0.11592217534780502, + -0.2260522097349167, + -0.2467145025730133, + -0.22054187953472137, + -0.13919179141521454, + 0.0016459478065371513, + 0.0515579916536808, + 0.060555730015039444, + 0.040788713842630386, + -0.017907800152897835, + -0.026459651067852974, + -0.02488812990486622, + 0.015644825994968414, + 0.10543125867843628, + 0.19312354922294617, + 0.28380078077316284, + 0.28878358006477356, + 0.16968156397342682, + 0.04848042502999306, + -0.00986899808049202, + 0.06337545067071915, + 0.16356752812862396, + 0.2444516271352768, + 0.29273414611816406, + 0.2314801961183548, + 0.12695762515068054, + -0.022283215075731277, + 0.018402203917503357, + 0.07152476161718369, + 0.14247483015060425, + 0.18759845197200775, + 0.20828258991241455, + 0.14114585518836975, + -0.047197990119457245, + -0.13794781267642975, + -0.17509934306144714, + -0.1696663200855255, + -0.1206701323390007, + -0.036128126084804535, + 0.007180679589509964, + 0.006984225939959288, + -0.09600912779569626, + -0.22975720465183258, + -0.33287662267684937, + -0.2942708134651184, + -0.20305578410625458, + -0.08411446958780289, + 0.042896877974271774, + -0.020053744316101074, + -0.16365791857242584, + -0.3145587742328644, + -0.3321540057659149, + -0.2667454183101654, + -0.1542910486459732, + -0.006954069249331951, + 0.020191870629787445, + 0.014010002836585045, + 0.0016916356980800629, + -0.04649524390697479, + -0.014931428246200085, + -0.017954425886273384, + -0.020003901794552803, + 0.03831968829035759, + 0.08447518199682236, + 0.14068123698234558, + 0.13400419056415558, + 0.08205568045377731, + -0.0004489773709792644, + -0.019211264327168465, + 0.023363608866930008, + 0.08738930523395538, + 0.12299696356058121, + 0.13070489466190338, + 0.09040816128253937, + 0.03286544978618622, + -0.006979941390454769, + -0.0010930931894108653, + 0.04313739389181137, + 0.10121051222085953, + 0.11390950530767441, + 0.11383924633264542, + 0.06694260239601135, + -0.00425445893779397, + -0.0666416585445404, + -0.09225274622440338, + -0.0977785512804985, + -0.07118111103773117, + -0.026749763637781143, + -0.019425569102168083, + 0.03321055322885513, + -0.0033978468272835016, + -0.08309262245893478, + -0.15557922422885895, + -0.14969374239444733, + -0.07188998907804489, + -0.018716221675276756, + 0.022834330797195435, + 0.004232254344969988, + -0.04141783341765404, + -0.125192329287529, + -0.14545302093029022, + -0.12225300818681717, + -0.05844716727733612, + 0.010607236064970493, + 0.024218380451202393, + -0.002702374942600727, + -0.030814893543720245, + 0.03507756441831589, + -0.0506589449942112, + 0.03415676951408386, + 0.0011444400297477841, + 0.0026324463542550802, + 0.028514407575130463, + -0.01849454641342163, + -0.030959082767367363, + -0.05565863475203514, + 0.05771413818001747, + 0.003916156478226185, + -0.004474544432014227, + 0.04403551295399666, + 0.1733711212873459, + -0.37650829553604126, + 0.22322984039783478, + 0.0032540319953113794, + -0.01139416079968214, + -0.039046600461006165, + 0.0021948080975562334, + 0.5777754783630371, + -1.1944804191589355, + 0.769478976726532, + -0.1349843591451645, + 0.0004430754925124347, + -0.0061850035563111305, + -0.08340868353843689, + 0.8327823877334595, + -1.649588942527771, + 1.126111388206482, + -0.2918313145637512, + 0.003614947199821472, + 0.0016799914883449674, + -0.03255167230963707, + 0.6123784184455872, + -1.1993682384490967, + 0.8305437564849854, + -0.13622376322746277, + 0.00905851274728775, + -0.006772476714104414, + 0.07578610628843307, + 0.05859832838177681, + -0.4543764293193817, + 0.26330503821372986, + 0.0259060300886631, + -0.0007997890934348106, + 0.01269856933504343, + 0.006897627376019955, + -0.02491801232099533, + -0.03139931708574295, + 0.0028456314466893673, + 0.0008253560517914593, + -0.01086023822426796, + -0.004186873324215412, + 0.06299160420894623, + -0.039931319653987885, + -0.09315146505832672, + 0.05495935305953026, + 0.027547571808099747, + -0.010900916531682014, + -0.025233760476112366, + 0.060600072145462036, + 0.21010243892669678, + -0.5445898771286011, + 0.35070353746414185, + -0.033771682530641556, + -0.0269146841019392, + -0.025363197550177574, + -0.021729450672864914, + 0.70921790599823, + -1.4368270635604858, + 0.9582043290138245, + -0.1708265244960785, + 0.010022420436143875, + -0.032301150262355804, + -0.08667651563882828, + 1.0338889360427856, + -1.913576364517212, + 1.262008547782898, + -0.23795078694820404, + -0.032233912497758865, + -0.01397701445966959, + -0.05402921140193939, + 0.7621430158615112, + -1.387437343597412, + 0.8621506094932556, + -0.14765247702598572, + -0.004747485741972923, + 0.0017516895895823836, + 0.08154146373271942, + 0.16601374745368958, + -0.5324177742004395, + 0.27442997694015503, + 0.03274058923125267, + -0.008812552317976952, + 0.005774920806288719, + 0.04165825620293617, + -0.011749272234737873, + -0.01953396573662758, + -0.009672109968960285, + 0.01170953270047903, + 0.003071938641369343, + -0.018979815766215324, + 0.062123894691467285, + -0.004921444226056337, + -0.03380037844181061, + 0.01310884952545166, + 0.007953890599310398, + -0.0012086924398317933, + -0.03317898139357567, + -0.0015596294542774558, + 0.08166785538196564, + -0.2291223704814911, + 0.11783571541309357, + -0.016078786924481392, + 0.018957575783133507, + 0.025793947279453278, + -0.09036394208669662, + 0.3833881616592407, + -0.5794023871421814, + 0.4610825777053833, + -0.14165280759334564, + -0.007412370759993792, + 0.05252876877784729, + -0.21435455977916718, + 0.6177686452865601, + -0.8516795635223389, + 0.667263925075531, + -0.22572898864746094, + -0.004465761594474316, + 0.02589319832623005, + -0.1893543303012848, + 0.43213585019111633, + -0.6462821364402771, + 0.434274822473526, + -0.15750259160995483, + -0.01198036689311266, + -2.4281514924950898e-05, + 0.039562296122312546, + 0.11126027256250381, + -0.23193514347076416, + 0.1412443071603775, + -0.011839920654892921, + 0.007880321703851223, + 0.02950354479253292, + 0.011689653620123863, + -0.07272310554981232, + -0.03319466486573219, + -0.003948990721255541, + 0.03549842908978462, + -0.02165558747947216, + -0.09912239760160446, + -0.08742356300354004, + 0.30591821670532227, + 0.23934677243232727, + 0.02658180706202984, + -0.022127188742160797, + -0.02769642136991024, + 0.16399237513542175, + 0.5140998959541321, + 0.007951628416776657, + -0.5589093565940857, + -0.24106110632419586, + -0.02753414213657379, + 0.06947467476129532, + 0.048558495938777924, + -0.5370690822601318, + -0.761831521987915, + 0.16272802650928497, + 0.29426246881484985, + 0.07943751662969589, + -0.022394873201847076, + -0.217612162232399, + -0.03093647211790085, + 0.5945476293563843, + 0.2873935103416443, + -0.16481661796569824, + -0.02931203693151474, + -0.029083512723445892, + 0.06754925847053528, + 0.20200076699256897, + -0.07271742075681686, + -0.1976277083158493, + -0.04189611226320267, + 0.06403793394565582, + -0.00022445111244451255, + -0.01032529678195715, + -0.03415631130337715, + 0.009091783314943314, + 0.04317992925643921, + 0.07196266949176788, + -0.025028688833117485, + -0.02722775563597679, + -0.017168480902910233, + -0.027666645124554634, + -0.06734028458595276, + 0.10843724757432938, + 0.08066407591104507, + -0.027849983423948288, + -0.0045820740051567554, + -0.03388727456331253, + 0.16772156953811646, + 0.651636004447937, + 0.34874194860458374, + -0.1454945057630539, + -0.18056720495224, + 0.11703842133283615, + 0.43017855286598206, + 0.7624525427818298, + -0.3420296907424927, + -1.272199273109436, + -0.5284644365310669, + -0.005667245015501976, + 0.08240436762571335, + -0.13299596309661865, + -1.3164156675338745, + -1.659982442855835, + 0.19898656010627747, + 0.6253566741943359, + 0.25137946009635925, + -0.18244975805282593, + -0.5360167622566223, + -0.06195700913667679, + 1.2547520399093628, + 1.0296341180801392, + 0.10651036351919174, + -0.023540280759334564, + -0.07594245672225952, + 0.1492130160331726, + 0.5033117532730103, + 0.09394379705190659, + -0.22459803521633148, + -0.22473134100437164, + -0.04738321527838707, + 0.04127531498670578, + 0.0682951882481575, + -0.02095615118741989, + -0.1233135387301445, + -0.10028401762247086, + -0.008111395873129368, + -0.000617706507910043, + 0.018859047442674637, + 0.028446361422538757, + -0.06159031391143799, + -0.1292838156223297, + 0.051308393478393555, + 0.11001072078943253, + -0.02056661807000637, + -0.012175443582236767, + -0.1313694268465042, + 0.0067574759013950825, + 0.4612729251384735, + 0.323080450296402, + -0.09392253309488297, + -0.1256203055381775, + 0.03537299111485481, + 0.2556088864803314, + 0.6467183232307434, + -0.16340143978595734, + -0.8799455165863037, + -0.3312987685203552, + 0.01464154850691557, + 0.07046713680028915, + 0.053634822368621826, + -0.8514915108680725, + -1.176972508430481, + 0.2056443840265274, + 0.4998764395713806, + 0.1268644779920578, + -0.10905193537473679, + -0.3750888705253601, + -0.06701061874628067, + 0.9052186608314514, + 0.6792045831680298, + -0.00323892361484468, + -0.0007412935374304652, + -0.03608793020248413, + 0.1009129211306572, + 0.36775916814804077, + 0.035214491188526154, + -0.2273784875869751, + -0.15815992653369904, + -0.004773923195898533, + 0.06374036520719528, + 0.04737555980682373, + -0.0563247986137867, + -0.09587392956018448, + -0.043853096663951874, + 0.032572731375694275, + -0.0036250585690140724, + 0.07889056205749512, + -0.03589344769716263, + -0.019771328195929527, + 0.04937156289815903, + 0.039052557200193405, + -0.013377528637647629, + -0.0841481015086174, + -0.03358105197548866, + -0.2128981053829193, + -0.14468812942504883, + 0.14675867557525635, + 0.2550889551639557, + 0.22369499504566193, + -0.0032973098568618298, + 0.006679064594209194, + -0.11752036958932877, + 0.025247232988476753, + 0.23064176738262177, + 0.25043538212776184, + 0.3474777638912201, + 0.2151806503534317, + 0.051294319331645966, + 0.16301114857196808, + 0.25422143936157227, + -0.1796918362379074, + -0.6128425598144531, + -0.42049655318260193, + 0.07740531116724014, + -0.007960617542266846, + 0.2504507601261139, + 0.2932300865650177, + -0.5157915949821472, + -1.2904177904129028, + -1.0362532138824463, + -0.22443994879722595, + 0.007411653641611338, + 0.16024430096149445, + 0.33939966559410095, + -0.2748318016529083, + -0.8487470149993896, + -0.5955387949943542, + 0.033155132085084915, + -0.09185351431369781, + -0.05639262869954109, + 0.17084303498268127, + 0.11292264610528946, + -0.046329669654369354, + 0.11495561897754669, + 0.31740760803222656, + -0.13903948664665222, + 0.05507560819387436, + 0.10180198401212692, + -0.1369788944721222, + -0.10618618875741959, + -0.001083499751985073, + 0.16340164840221405, + 0.07591762393712997, + 0.3417445123195648, + 0.27897438406944275, + -0.32192930579185486, + -0.5731648206710815, + -0.46150147914886475, + -0.03230089321732521, + 0.04096771031618118, + 0.22242987155914307, + 0.027000218629837036, + -0.4113498628139496, + -0.433158278465271, + -0.5252256393432617, + -0.3510502874851227, + -0.133863165974617, + -0.38554033637046814, + -0.45547229051589966, + 0.2475612610578537, + 1.154951572418213, + 0.8282179236412048, + -0.13197137415409088, + -0.03350961208343506, + -0.5282800197601318, + -0.5297923684120178, + 0.9037952423095703, + 2.516275405883789, + 2.086421489715576, + 0.3573826849460602, + -0.010694397613406181, + -0.31418153643608093, + -0.5325371026992798, + 0.48083701729774475, + 1.7732245922088623, + 1.2747145891189575, + -0.06401863694190979, + 0.14296381175518036, + 0.07267159968614578, + -0.28001847863197327, + -0.29204103350639343, + 0.12853951752185822, + -0.1998838633298874, + -0.6375644207000732, + 0.06310836225748062, + -0.020014479756355286, + -0.08150970935821533, + 0.08175478130578995, + 0.07667485624551773, + 0.0025236753281205893, + -0.08504530042409897, + -0.035742271691560745, + -0.1332666128873825, + -0.15150736272335052, + 0.18459312617778778, + 0.3363596200942993, + 0.2501969635486603, + 0.029292423278093338, + -0.060296736657619476, + -0.1142202764749527, + -0.05918247997760773, + 0.18826954066753387, + 0.2183520495891571, + 0.21247169375419617, + 0.14935970306396484, + 0.09923429787158966, + 0.21808095276355743, + 0.21930061280727386, + -0.060535889118909836, + -0.5729222297668457, + -0.4199080169200897, + 0.058897778391838074, + 0.050647757947444916, + 0.2784770131111145, + 0.2754706144332886, + -0.40136128664016724, + -1.3269731998443604, + -1.124815583229065, + -0.11878778040409088, + -0.005137663800269365, + 0.17839783430099487, + 0.2115524858236313, + -0.24165289103984833, + -0.9655010104179382, + -0.7425088286399841, + 0.0304054357111454, + -0.07012742757797241, + -0.015557953156530857, + 0.1128007024526596, + 0.18957749009132385, + -0.07996463775634766, + 0.09505810588598251, + 0.34419506788253784, + -0.3072076439857483, + 0.03868290036916733, + 0.11494885385036469, + 0.03748936951160431, + 0.0797261893749237, + -0.003397951368242502, + -0.07380004972219467, + -0.11507676541805267, + -0.10298885405063629, + 0.10698320716619492, + 0.06602972000837326, + 0.08226803690195084, + 0.0037747276946902275, + -0.162277951836586, + 0.01671667955815792, + 0.09137773513793945, + 0.18799471855163574, + 0.04144813120365143, + 0.1285877376794815, + 0.1820434182882309, + 0.04940629005432129, + 0.0991915687918663, + 0.10219171643257141, + -0.013141660951077938, + -0.051191627979278564, + 0.05468929558992386, + 0.087598517537117, + 0.15897324681282043, + 0.11863455921411514, + -0.00814050156623125, + -0.07701541483402252, + -0.14013728499412537, + -0.044140227138996124, + -0.05328791216015816, + 0.06760499626398087, + 0.12053386867046356, + 0.09780212491750717, + -0.053725965321063995, + -0.07915244251489639, + -0.0032519602682441473, + 0.019637396559119225, + 0.07848430424928665, + 0.019138827919960022, + 0.1460287868976593, + 0.1281038075685501, + 0.024417784065008163, + 0.059176862239837646, + 0.0658111497759819, + -0.016405148431658745, + -0.18877744674682617, + 0.16666102409362793, + 0.1610611230134964, + 0.08374520391225815, + 0.11570518463850021, + 0.11903064697980881, + 0.1294964700937271, + 0.06379758566617966, + 0.08417274057865143, + 0.12754113972187042, + 0.025328608229756355, + 0.05170705169439316, + 0.0835295170545578, + 0.07477264851331711, + 0.11244285851716995, + 0.11559426784515381, + 0.045258160680532455, + -0.14825093746185303, + -0.08153342455625534, + 0.06288623809814453, + 0.11952362209558487, + 0.11784297972917557, + 0.011141132563352585, + -0.21666541695594788, + -0.29976174235343933, + -0.2279169261455536, + -0.11828474700450897, + 0.12436322867870331, + 0.10465826094150543, + -0.09751085937023163, + -0.292611300945282, + -0.37374064326286316, + -0.31437963247299194, + -0.25637903809547424, + 0.06173908710479736, + 0.14131486415863037, + 0.008434675633907318, + -0.23816508054733276, + -0.30330890417099, + -0.22094152867794037, + -0.11608295142650604, + 0.13235151767730713, + 0.15353602170944214, + 0.15839524567127228, + 0.012247815728187561, + -0.08126968890428543, + -0.003756331978365779, + 0.10660683363676071, + 0.21976575255393982, + -0.04188326746225357, + 0.15462253987789154, + 0.06303395330905914, + 0.006879634689539671, + 0.008284888230264187, + 0.07084798067808151, + 0.1211942657828331, + 0.10190404951572418, + 0.02935362420976162, + -0.05645999684929848, + -0.16800500452518463, + -0.1850246787071228, + -0.09476880729198456, + -0.025327544659376144, + 0.054355036467313766, + -0.035813912749290466, + -0.18694879114627838, + -0.34871891140937805, + -0.3151862621307373, + -0.1943007856607437, + -0.09755205363035202, + 0.014881589449942112, + -0.14875493943691254, + -0.37112873792648315, + -0.37739917635917664, + -0.3241480886936188, + -0.2915399968624115, + -0.11268249899148941, + -0.019726404920220375, + -0.2510305941104889, + -0.38005372881889343, + -0.3622463345527649, + -0.2932804226875305, + -0.28574010729789734, + -0.1505027860403061, + -0.004947682376950979, + -0.18587322533130646, + -0.34759166836738586, + -0.28965193033218384, + -0.21052972972393036, + -0.18780536949634552, + -0.07400713860988617, + 0.11154936999082565, + -0.03556853160262108, + -0.1896934062242508, + -0.18135806918144226, + -0.10117948800325394, + -0.0393117293715477, + 0.06517928093671799, + -0.016659021377563477, + -0.011290309950709343, + -0.007930322550237179, + 0.008189777843654156, + 0.03678786754608154, + 0.021890517324209213, + 0.0034292477648705244, + 0.02200375869870186, + 0.0014921070542186499, + -0.0800287202000618, + -0.17657361924648285, + -0.18702608346939087, + -0.12880444526672363, + -0.022084584459662437, + 0.026420501992106438, + -0.023968446999788284, + -0.07948111742734909, + -0.16741475462913513, + -0.18733707070350647, + -0.16539834439754486, + -0.07347387820482254, + -0.009723886847496033, + -0.02016977220773697, + -0.061092622578144073, + -0.13145211338996887, + -0.15919029712677002, + -0.15043555200099945, + -0.10107766091823578, + 0.0016151965828612447, + 0.0627974420785904, + 0.08695066720247269, + 0.11727584898471832, + 0.11745581030845642, + 0.11329426616430283, + 0.0533670075237751, + -0.016355818137526512, + 0.008450252935290337, + 0.06448577344417572, + 0.1538505256175995, + 0.21232697367668152, + 0.14713847637176514, + 0.039088234305381775, + -0.015588105656206608, + 0.026483291760087013, + 0.060862988233566284, + 0.18265819549560547, + 0.23042462766170502, + 0.168768972158432, + 0.034099943935871124, + -0.018249109387397766, + -0.0321880541741848, + -0.03254542127251625, + -0.03061222843825817, + -0.0026304698549211025, + 0.017764942720532417, + 0.010707704350352287, + 0.009254949167370796, + -0.04533161595463753, + -0.1483704000711441, + -0.2637183666229248, + -0.2678598165512085, + -0.1737881749868393, + -0.049990858882665634, + 0.013515918515622616, + -0.054345693439245224, + -0.1467861533164978, + -0.24911582469940186, + -0.2831358015537262, + -0.22300836443901062, + -0.13739243149757385, + -0.017879672348499298, + -0.040345460176467896, + -0.09990613907575607, + -0.16936856508255005, + -0.2266550064086914, + -0.2020808756351471, + -0.1509508341550827, + 0.014163740910589695, + 0.07591170817613602, + 0.09185601025819778, + 0.10455341637134552, + 0.09514842182397842, + 0.09877350926399231, + 0.053898438811302185, + 0.005704578943550587, + 0.0591997392475605, + 0.13600079715251923, + 0.21777905523777008, + 0.2574957311153412, + 0.20117221772670746, + 0.11415109038352966, + -0.001181072206236422, + 0.09470006823539734, + 0.18978413939476013, + 0.3073742389678955, + 0.36875811219215393, + 0.3069853186607361, + 0.1708926260471344, + -0.0325310118496418, + -0.02656698040664196, + 0.016060845926404, + 0.02459372952580452, + 0.04165660962462425, + 0.033969976007938385, + 0.012855498120188713, + 0.030497560277581215, + 0.004896117839962244, + -0.030887477099895477, + -0.13454437255859375, + -0.1294785887002945, + -0.06398608535528183, + 0.016156472265720367, + 0.03577340394258499, + -0.0033482143189758062, + -0.07112833857536316, + -0.16465041041374207, + -0.1621057391166687, + -0.09478478878736496, + -0.03555302321910858, + -0.001592929707840085, + -0.01719600521028042, + -0.06598587334156036, + -0.1411861628293991, + -0.1496778130531311, + -0.11535074561834335, + -0.0905962884426117, + -0.013807609677314758, + 0.029542237520217896, + 0.039138730615377426, + 0.03988270089030266, + 0.02665030211210251, + 0.049553126096725464, + -0.0015685928519815207, + -0.018007200211286545, + 0.009533192962408066, + 0.06910547614097595, + 0.1034330427646637, + 0.15017645061016083, + 0.10221225768327713, + 0.020978443324565887, + -0.023747621104121208, + 0.02295384369790554, + 0.09313814342021942, + 0.1771395057439804, + 0.21169933676719666, + 0.17989481985569, + 0.05862005427479744, + -0.004540165886282921, + 0.021994179114699364, + -0.003493826137855649, + -0.000224211675231345, + 0.031808022409677505, + -0.05090906098484993, + 0.001970196608453989, + 0.01633802428841591, + 0.0049764602445065975, + 0.0006027702474966645, + -0.005952450912445784, + -0.009886081330478191, + -0.08520589768886566, + 0.030780712142586708, + 0.00037104589864611626, + 0.011886775493621826, + -0.023506291210651398, + 0.08029806613922119, + -0.005086984951049089, + -0.07738454639911652, + 0.06721897423267365, + -0.02397127076983452, + 0.006669329944998026, + -0.016343094408512115, + 0.06056324020028114, + 0.15656796097755432, + -0.49836501479148865, + 0.2475810945034027, + -0.009270203299820423, + -0.006855266634374857, + 0.0034896093420684338, + -0.027938276529312134, + 0.5722692012786865, + -1.1357109546661377, + 0.5644665956497192, + 0.015787361189723015, + -0.015141892246901989, + -0.0032788251992315054, + -0.04797150194644928, + 0.6196744441986084, + -1.1540743112564087, + 0.6065864562988281, + 0.0019708566833287477, + 0.006332532037049532, + 0.014192940667271614, + 0.03773411735892296, + 0.27323007583618164, + -0.594700813293457, + 0.2488076239824295, + -0.008853388018906116, + 0.005692378617823124, + 0.000576167949475348, + -0.027197014540433884, + 0.022015029564499855, + -0.02571249194443226, + 0.004507753532379866, + -0.002439734758809209, + -0.01994609646499157, + 0.03601142391562462, + 0.008136607706546783, + 0.01658148691058159, + -0.06548810750246048, + 0.022721221670508385, + -0.0038820707704871893, + -0.0007800398161634803, + 0.001392301986925304, + 0.09576108306646347, + -0.014628835022449493, + -0.14505760371685028, + 0.07135403156280518, + -0.00839388556778431, + -0.004555124789476395, + -0.04466082155704498, + 0.1456393599510193, + 0.3475525975227356, + -0.7879117131233215, + 0.36262738704681396, + 0.008226356469094753, + 0.0055343699641525745, + -0.061139706522226334, + 0.08975803852081299, + 0.9340736269950867, + -1.7307822704315186, + 0.796896755695343, + -0.024700213223695755, + -0.013090251013636589, + -0.05148586630821228, + 0.050525497645139694, + 0.927090048789978, + -1.7473385334014893, + 0.7727715373039246, + -0.005721901543438435, + 0.010676853358745575, + -0.012798544019460678, + 0.11131046712398529, + 0.4181194007396698, + -0.8475598096847534, + 0.33206430077552795, + 0.018843427300453186, + 0.0006885005859658122, + 0.027498219162225723, + 0.00207257061265409, + 0.0032615051604807377, + -0.021950624883174896, + -0.008452882058918476, + -0.007631891872733831, + -0.028561849147081375, + 0.04865337535738945, + -0.0023105579894036055, + -0.026170270517468452, + -0.011794357560575008, + 0.004327487666159868, + 0.01756221242249012, + 0.0011611212976276875, + -0.008793564513325691, + 0.0741758644580841, + -0.057649385184049606, + -0.006000686902552843, + -0.022717488929629326, + -0.0047143916599452496, + 0.005709030199795961, + -0.05611564591526985, + 0.05792170390486717, + 0.1873699128627777, + -0.3856293857097626, + 0.1371920108795166, + 0.018953431397676468, + 0.015250314958393574, + -0.0016827551880851388, + -0.08515634387731552, + 0.6517581939697266, + -0.9557326436042786, + 0.46986615657806396, + -0.014306572265923023, + -0.01625121757388115, + -0.016088897362351418, + -0.13429272174835205, + 0.6437729001045227, + -1.0167845487594604, + 0.5061463117599487, + 0.00879831612110138, + -0.008598369546234608, + 0.02747279778122902, + 0.007245234213769436, + 0.2527446150779724, + -0.47163763642311096, + 0.15560215711593628, + 0.005050336476415396, + -0.024848125874996185, + -0.0006449198699556291, + -0.008673148229718208, + -0.06940636038780212, + -0.016248086467385292, + 0.1250494420528412, + 0.026387182995676994, + 0.009615709073841572, + -0.0025482974015176296, + -0.04534498229622841, + -0.2626228630542755, + -0.2753732204437256, + 0.052055053412914276, + 0.010792221873998642, + 0.007360508665442467, + 0.10271529853343964, + 0.1113760769367218, + -0.31120774149894714, + -0.49849262833595276, + -0.2206398844718933, + 0.04994913563132286, + 0.054614756256341934, + 0.27786919474601746, + 0.56647789478302, + 0.20970205962657928, + -0.22717078030109406, + -0.17321231961250305, + -0.07836200296878815, + -0.09607961028814316, + 0.10685958713293076, + 0.40848156809806824, + 0.34087467193603516, + -0.005242985673248768, + -0.0682876780629158, + -0.0694413110613823, + -0.1886596381664276, + -0.04473332315683365, + 0.18096435070037842, + 0.1961163580417633, + 0.0014336564345285296, + 0.014584851451218128, + 0.0462430939078331, + -0.1556192934513092, + -0.12809665501117706, + 0.0213937908411026, + 0.10984069108963013, + -0.023050926625728607, + -0.013447473756968975, + 0.007857509888708591, + -0.027979737147688866, + -0.04768490046262741, + -0.09350565075874329, + -0.1659490317106247, + 0.007927919737994671, + 0.26641780138015747, + 0.03398526459932327, + 0.02118881419301033, + -0.006898822728544474, + -0.15209096670150757, + -0.4939330220222473, + -0.42655149102211, + 0.08215854316949844, + 0.02115131914615631, + 0.08892140537500381, + 0.2164168655872345, + 0.12431265413761139, + -0.47813764214515686, + -0.6588870882987976, + -0.3097454905509949, + 0.0837375745177269, + 0.1548176258802414, + 0.49661529064178467, + 0.7337944507598877, + 0.1966201215982437, + -0.29367199540138245, + -0.2547970116138458, + -0.11655519157648087, + -0.11720486730337143, + 0.21941716969013214, + 0.5902130603790283, + 0.42572125792503357, + 0.020460324361920357, + -0.12768393754959106, + -0.12030418962240219, + -0.2582310736179352, + -0.0355166494846344, + 0.2766987085342407, + 0.28080257773399353, + 0.08665957301855087, + 0.027141664177179337, + 0.02690703421831131, + -0.25276950001716614, + -0.23180679976940155, + 0.015180152840912342, + 0.11523276567459106, + 0.041165824979543686, + 0.017444534227252007, + 0.0009439520072191954, + -0.025763530284166336, + -0.022880665957927704, + -0.024819007143378258, + -0.04901815578341484, + 0.027672944590449333, + 0.11211585998535156, + 0.024664992466568947, + -0.010093869641423225, + 0.009466213174164295, + -0.043605536222457886, + -0.17007218301296234, + -0.1366996467113495, + 0.08740171790122986, + -0.014591479673981667, + -0.0031720874831080437, + 0.0835830345749855, + 0.028662094846367836, + -0.21436777710914612, + -0.24753160774707794, + -0.06092096120119095, + 0.03788171336054802, + 0.04295210912823677, + 0.19064708054065704, + 0.3095722496509552, + 0.08003447204828262, + -0.09509303420782089, + -0.05495578795671463, + -0.052218906581401825, + -0.07204427570104599, + 0.07710819691419601, + 0.18033725023269653, + 0.0834946483373642, + -0.049662720412015915, + -0.06561554968357086, + -0.013351643458008766, + -0.11217659711837769, + 0.031957074999809265, + 0.12180440872907639, + 0.06891122460365295, + -0.013705568388104439, + 0.0011150656500831246, + 0.03281388059258461, + -0.11285661906003952, + -0.06422404199838638, + 0.04218210279941559, + 0.014165353029966354, + -0.006244795396924019, + 0.01745765097439289, + 0.08924975246191025, + 0.01710040494799614, + -0.14013372361660004, + -0.21913501620292664, + 0.03613810986280441, + 0.14273521304130554, + 0.05801931768655777, + 0.021427493542432785, + 0.23185034096240997, + 0.2427377849817276, + -0.4384608566761017, + -0.7205182909965515, + -0.18313364684581757, + 0.033575087785720825, + -0.0809125304222107, + 0.04173902049660683, + 0.7251381874084473, + 1.1058244705200195, + -0.015065462328493595, + -0.6434917449951172, + -0.3080260753631592, + -0.090518057346344, + -0.3659006655216217, + -0.4520319700241089, + 0.5924424529075623, + 1.4148176908493042, + 0.5285682082176208, + -0.027211233973503113, + -0.07359065115451813, + -0.08583711832761765, + -0.5631492137908936, + -1.0246236324310303, + -0.1835726648569107, + 0.3307121694087982, + 0.22562064230442047, + 0.05237145721912384, + 0.13263091444969177, + 0.13899636268615723, + -0.1626550555229187, + -0.3918432295322418, + -0.03585565462708473, + 0.06904798001050949, + 0.029870154336094856, + 0.04289601743221283, + 0.05758490040898323, + 0.10055387020111084, + -0.011962685734033585, + -0.13269846141338348, + 0.0012237781193107367, + 0.05511128902435303, + 0.03764793649315834, + -0.07580426335334778, + -0.1750984787940979, + 0.0189101230353117, + 0.08156414330005646, + 0.01691802591085434, + 0.004023027140647173, + 0.18009696900844574, + 0.22744491696357727, + -0.38747039437294006, + -0.6413040161132812, + -0.19208981096744537, + 0.01971367374062538, + -0.036756888031959534, + 0.004946697968989611, + 0.7331712245941162, + 1.1178003549575806, + 0.03220612183213234, + -0.5881579518318176, + -0.24453559517860413, + -0.11856977641582489, + -0.43593257665634155, + -0.5339378118515015, + 0.49467018246650696, + 1.3376370668411255, + 0.5238692164421082, + 0.04584280773997307, + 0.004761924035847187, + -0.032823480665683746, + -0.5419207811355591, + -1.0093209743499756, + -0.19847697019577026, + 0.20687319338321686, + 0.12301573902368546, + 0.07981085777282715, + 0.14125365018844604, + 0.19885297119617462, + -0.1678825318813324, + -0.4042292535305023, + 0.004483209457248449, + 0.03009556047618389, + 0.010802071541547775, + 0.005967534612864256, + 0.0892769992351532, + 0.07342032343149185, + -0.0588892325758934, + -0.09044717997312546, + 0.06307072192430496, + -0.012583961710333824, + -0.006880680099129677, + 0.0030021765269339085, + 0.01633061282336712, + 0.06990820169448853, + 0.0070900083519518375, + -0.03546716272830963, + -0.022131899371743202, + -0.02906683459877968, + 0.010664403438568115, + -0.18731924891471863, + -0.158770352602005, + 0.08571326732635498, + 0.039154618978500366, + 0.032578419893980026, + -0.005781106185168028, + 0.17460086941719055, + 0.2787456810474396, + -0.13416190445423126, + -0.23966801166534424, + -0.004878139588981867, + 0.02796499989926815, + -0.06610933691263199, + -0.19162042438983917, + 0.11163146048784256, + 0.371842622756958, + 0.06444671750068665, + 0.016595548018813133, + 0.01164282951503992, + 0.08330011367797852, + -0.03192862868309021, + -0.2867860198020935, + -0.07080501317977905, + -0.016348646953701973, + -0.06306261569261551, + -0.016291450709104538, + 0.010558445006608963, + 0.13014638423919678, + 0.06202690303325653, + -0.03361419215798378, + 0.0691375732421875, + 0.003561250865459442, + -0.013095442205667496, + -0.050333790481090546, + -0.019117066636681557, + 0.0012089330703020096, + -0.004555183462798595, + -0.022682132199406624, + 0.04747068136930466, + -0.06425238400697708, + -0.0010437731398269534, + -0.0071629988960921764, + -0.04302623122930527, + -0.04830477759242058, + -0.04069536179304123, + -0.06627446413040161, + -0.011470981873571873, + 0.03961857780814171, + 0.026594260707497597, + -0.020662540569901466, + -0.05999285355210304, + -0.053548794239759445, + -0.025959201157093048, + -0.015834785997867584, + 0.013910192996263504, + -0.015868371352553368, + -0.056620921939611435, + -0.06785159558057785, + -0.061030179262161255, + -0.03560228645801544, + -0.04177624359726906, + -0.024657463654875755, + -0.04889696091413498, + 0.004557035397738218, + 0.15414470434188843, + 0.21642963588237762, + 0.035425592213869095, + -0.04339970648288727, + -0.05034525692462921, + -0.08522290736436844, + 0.10652441531419754, + 0.6791198253631592, + 0.7785530686378479, + 0.19941796362400055, + -0.05430706962943077, + -0.02583213709294796, + -0.055139996111392975, + 0.17940561473369598, + 0.6757862567901611, + 0.8240399360656738, + 0.25826773047447205, + -0.062254682183265686, + -0.026456547901034355, + -0.027271386235952377, + -0.0026193747762590647, + 0.11893659085035324, + 0.1915995329618454, + 0.013776157051324844, + 0.08452087640762329, + 0.009950258769094944, + 0.01774573139846325, + 0.06609759479761124, + 0.06512798368930817, + 0.07601971179246902, + 0.09192144125699997, + 0.007696932647377253, + -0.056120894849300385, + -0.03937293961644173, + 0.043086692690849304, + 0.055803027004003525, + 0.08208976686000824, + 0.03658852353692055, + 0.025779196992516518, + -0.0340605266392231, + 0.03186321631073952, + 0.09720855951309204, + 0.10651290416717529, + 0.09562067687511444, + 0.08120692521333694, + 0.06832587718963623, + 0.03940538689494133, + 0.09561086446046829, + -0.03726261481642723, + -0.3520663380622864, + -0.4187469184398651, + -0.11643502116203308, + 0.06203937157988548, + 0.056670401245355606, + 0.11540547758340836, + -0.2742924690246582, + -1.1301417350769043, + -1.2482489347457886, + -0.4411431849002838, + 0.08538330346345901, + 0.036888301372528076, + 0.08759869635105133, + -0.32129940390586853, + -1.1163593530654907, + -1.26430082321167, + -0.48638999462127686, + 0.1056363582611084, + 0.042436979711055756, + 0.07075526565313339, + -0.08341801166534424, + -0.30567145347595215, + -0.39268070459365845, + -0.10187282413244247, + -0.02507772110402584, + -0.0044433241710066795, + -0.009278317913413048, + -0.02964872494339943, + -0.018799586221575737, + -0.03760084509849548, + -0.030454028397798538, + -0.004638439975678921, + 0.026587119325995445, + 0.0095819728448987, + -0.007110759150236845, + -0.006491640582680702, + -0.028083719313144684, + -0.009543413296341896, + -0.005706887226551771, + 0.013012710027396679, + -0.010281933471560478, + -0.0544208325445652, + -0.023230208083987236, + -0.05344587564468384, + -0.04052828997373581, + -0.028035156428813934, + -0.011922319419682026, + -0.045427750796079636, + 0.020700184628367424, + 0.2117788940668106, + 0.21090814471244812, + 0.07214333862066269, + -0.019348343834280968, + -0.014455118216574192, + -0.03561105206608772, + 0.17339389026165009, + 0.49509289860725403, + 0.5219546556472778, + 0.26121678948402405, + -0.029803339391946793, + -0.013761913403868675, + -0.04028521850705147, + 0.17008572816848755, + 0.45583003759384155, + 0.4757367670536041, + 0.22357690334320068, + -0.050064269453287125, + -0.021086007356643677, + -0.039873600006103516, + 0.06433176249265671, + 0.20187893509864807, + 0.2078690379858017, + 0.07802058011293411, + 0.022050827741622925, + -0.05272649601101875, + -0.024311071261763573, + -0.12387345731258392, + -0.20065246522426605, + 0.0262442696839571, + 0.20101603865623474, + 0.056791841983795166, + -0.008266052231192589, + 0.025132112205028534, + -0.23289933800697327, + -0.5296569466590881, + -0.282010018825531, + 0.025113720446825027, + 0.13172000646591187, + 0.16999290883541107, + 0.31588253378868103, + 0.05583454668521881, + -0.5321000814437866, + -0.5585035085678101, + -0.23885560035705566, + 0.0461968369781971, + 0.13807418942451477, + 0.6536149382591248, + 0.6385176777839661, + -0.15636183321475983, + -0.5484278798103333, + -0.5470613241195679, + -0.06269911676645279, + -0.06726553291082382, + 0.5561463236808777, + 1.0985187292099, + 0.6801460385322571, + 0.12841203808784485, + -0.21693651378154755, + -0.19168342649936676, + -0.43073776364326477, + -0.15226863324642181, + 0.41150590777397156, + 0.47421786189079285, + 0.25146934390068054, + -0.017203813418745995, + -0.09694849699735641, + -0.4082376956939697, + -0.3549531400203705, + -0.023591510951519012, + 0.12086013704538345, + 0.08050766587257385, + -0.044960521161556244, + -0.0031193571630865335, + 0.014398006722331047, + -0.005931032355874777, + 0.01548685971647501, + 0.05407734215259552, + -0.006386967841535807, + 0.021660227328538895, + 0.01656133122742176, + 0.002835798542946577, + 0.0008500503608956933, + 0.021802745759487152, + 0.13470955193042755, + 0.06802596151828766, + 0.0033256933093070984, + 0.03037848509848118, + 0.054654810577631, + -0.034221138805150986, + 0.015171438455581665, + 0.23395732045173645, + 0.24771827459335327, + 0.16352902352809906, + -0.07505007833242416, + -0.0814652070403099, + -0.21493901312351227, + -0.3109704852104187, + 0.013416547328233719, + 0.12807825207710266, + 0.12044191360473633, + -0.007915153168141842, + 0.0100772799924016, + -0.15165796875953674, + -0.4013277292251587, + -0.24811144173145294, + -0.06641282886266708, + 0.022568246349692345, + 0.061083581298589706, + 0.09920243173837662, + 0.0695505365729332, + -0.12213064730167389, + -0.12606006860733032, + -0.04593949392437935, + -0.040190644562244415, + 0.03899035230278969, + 0.12688779830932617, + 0.114081971347332, + -0.013348283246159554, + 0.03325115144252777, + 0.007111718878149986, + 0.048056699335575104, + -0.003726312192156911, + 0.05401211231946945, + 0.05355936661362648, + 0.21303032338619232, + 0.2944865822792053, + -0.13604623079299927, + -0.3770989775657654, + -0.0808275118470192, + -0.006103217601776123, + -0.02005188539624214, + 0.37605899572372437, + 0.7776278853416443, + 0.32064270973205566, + -0.23708422482013702, + -0.23380732536315918, + -0.22103570401668549, + -0.45596328377723694, + 0.07213663309812546, + 0.9384943246841431, + 0.8762810230255127, + 0.3557227551937103, + -0.09239326417446136, + -0.25462013483047485, + -0.9858288168907166, + -0.9860153198242188, + 0.2600172162055969, + 0.7731484770774841, + 0.7665594816207886, + 0.14806008338928223, + 0.13109923899173737, + -0.6917864680290222, + -1.580305814743042, + -0.9557210803031921, + -0.16357193887233734, + 0.3189502954483032, + 0.28703632950782776, + 0.5599567890167236, + 0.2459551841020584, + -0.5451022982597351, + -0.6926754713058472, + -0.4368602931499481, + 0.027606861665844917, + 0.025857241824269295, + 0.5376880764961243, + 0.535673975944519, + 0.09012678265571594, + -0.14688564836978912, + -0.1812361180782318, + 0.050619762390851974, + 0.021388273686170578, + -0.05923623591661453, + -0.006538081914186478, + 0.05171535536646843, + -0.051560595631599426, + -0.007643367163836956, + 0.027748188003897667, + 0.0024676925968378782, + -0.008760283701121807, + 0.13039670884609222, + 0.18568934500217438, + 0.06342563778162003, + 0.030788781121373177, + -0.004423442296683788, + -0.041261281818151474, + 0.013299684040248394, + 0.22491391003131866, + 0.27831292152404785, + 0.0883866474032402, + 0.048967570066452026, + 0.0012756097130477428, + -0.03215779736638069, + 0.02710782177746296, + 0.20178261399269104, + 0.22446107864379883, + 0.06052157282829285, + 0.019020315259695053, + 0.02715166099369526, + -0.03146626800298691, + -0.017960363999009132, + 0.11820292472839355, + 0.16114193201065063, + 0.05221821367740631, + -0.02201441302895546, + -0.026308327913284302, + 0.008580431342124939, + -0.02444308064877987, + 0.061380185186862946, + 0.11184953153133392, + 0.006053542252629995, + -0.03248603641986847, + -0.037558719515800476, + 0.01881473697721958, + -0.02349863201379776, + 0.02150980569422245, + 0.09881952404975891, + 0.03962325677275658, + -0.0031782283913344145, + -0.0030868228059262037, + -0.007606725674122572, + -0.06136326491832733, + 0.022755015641450882, + 0.09683670848608017, + 0.0016674631042405963, + 0.01306125894188881, + 0.011335537768900394, + -0.01769089885056019, + 0.005807302892208099, + 0.19103741645812988, + 0.2631426155567169, + 0.10424992442131042, + 0.025223100557923317, + -0.024689532816410065, + -0.03370697423815727, + 0.0512213259935379, + 0.2983294129371643, + 0.37597405910491943, + 0.18788966536521912, + 0.056492965668439865, + -0.006051253993064165, + -0.027141474187374115, + 0.06733105331659317, + 0.29171472787857056, + 0.32160115242004395, + 0.14176633954048157, + 0.008538221009075642, + -0.013039524666965008, + -0.04279422387480736, + 0.03345612436532974, + 0.19111940264701843, + 0.25728005170822144, + 0.09830093383789062, + -0.03371569141745567, + -0.05277566984295845, + -0.0011038694065064192, + -0.013657800853252411, + 0.10037966072559357, + 0.1724642813205719, + 0.04436478391289711, + -0.02240786701440811, + -0.02181128039956093, + 0.019526727497577667, + -0.050060197710990906, + 0.017275504767894745, + 0.07785085588693619, + -0.001727179973386228, + -0.0014453287003561854, + 0.019352080300450325, + -0.003202121239155531, + -0.04241566359996796, + 0.005586653482168913, + 0.06037082523107529, + 0.014115821570158005, + -0.00568200321868062, + 0.018071964383125305, + -0.0007147599244490266, + 0.011219227686524391, + 0.10582104325294495, + 0.15557849407196045, + 0.06189450994133949, + 0.014160261489450932, + 0.00814653467386961, + -0.028064200654625893, + 0.026086319237947464, + 0.1474728286266327, + 0.18273885548114777, + 0.06638553738594055, + 0.019263381138443947, + 0.028977060690522194, + -0.02551555074751377, + 0.01937149092555046, + 0.12000202387571335, + 0.1285850703716278, + 0.047506313771009445, + -0.011383740231394768, + 0.02826755866408348, + -0.009583448991179466, + -0.02093282900750637, + 0.07994058728218079, + 0.0926218256354332, + 0.0318426676094532, + -0.024409465491771698, + 0.020994359627366066, + 0.03295197710394859, + -0.034276511520147324, + 0.037398867309093475, + 0.0794353187084198, + 0.022805212065577507, + 0.0015407208120450377, + 0.013169347308576107, + 0.038584139198064804, + -0.002118688775226474, + 0.03358406573534012, + 0.09085306525230408, + 0.04255761206150055, + 0.010275964625179768, + 0.025351760908961296, + 0.04205995053052902, + 0.1319226324558258, + 0.049708493053913116, + -0.03743802383542061, + -0.04293569549918175, + -0.07646205276250839, + -0.04986324533820152, + 0.15992362797260284, + 0.011027384549379349, + -0.32150742411613464, + -0.3761928677558899, + -0.1654653549194336, + -0.08728181570768356, + 0.044714685529470444, + -0.007500737439841032, + -0.41376256942749023, + -0.6625701189041138, + -0.21809393167495728, + 0.10641554743051529, + 0.09274336695671082, + 0.10189083218574524, + -0.1175118163228035, + -0.2905261516571045, + -0.06248515099287033, + 0.4791955053806305, + 0.49865299463272095, + 0.23415400087833405, + 0.12729482352733612, + -0.05814196541905403, + -0.003843356389552355, + 0.16410382091999054, + 0.40895968675613403, + 0.22034852206707, + 0.021014101803302765, + -0.05658271536231041, + -0.012199933640658855, + 0.034277670085430145, + 0.09565535932779312, + 0.18921032547950745, + 0.010441004298627377, + -0.07427560538053513, + -0.09049694985151291, + -0.00554919708520174, + 0.021386168897151947, + 0.0297325998544693, + 0.06431404501199722, + -0.07367311418056488, + -0.08734254539012909, + -0.059512097388505936, + 0.11382041126489639, + 0.19622667133808136, + 0.02534862980246544, + -0.09704668819904327, + -0.10857658833265305, + -0.10241919010877609, + -0.037928055971860886, + 0.17917697131633759, + -0.0396210141479969, + -0.472421795129776, + -0.5453466176986694, + -0.23921693861484528, + -0.06353127211332321, + 0.033679377287626266, + -0.011634309776127338, + -0.523267924785614, + -0.8400278091430664, + -0.3026646375656128, + 0.17986975610256195, + 0.20296970009803772, + 0.14190459251403809, + -0.12953802943229675, + -0.3968985378742218, + -0.13779792189598083, + 0.548722505569458, + 0.7039015293121338, + 0.4025704264640808, + 0.19535738229751587, + -0.08568660169839859, + -0.0589536651968956, + 0.1868993639945984, + 0.5782724618911743, + 0.43018248677253723, + 0.08876730501651764, + -0.10219226032495499, + -0.04660544916987419, + 0.018129168078303337, + 0.14359626173973083, + 0.3174169361591339, + 0.07668197154998779, + -0.13716676831245422, + -0.2058524489402771, + -0.023707473650574684, + 0.03213014453649521, + 0.06718969345092773, + 0.0917893499135971, + -0.10766899585723877, + -0.206499844789505, + -0.12713390588760376, + -0.03174767270684242, + 0.046395305544137955, + 0.018318502232432365, + -0.002416136907413602, + -0.027143845334649086, + -0.0036621293984353542, + -0.019220896065235138, + 0.05427055433392525, + 0.05058867856860161, + -0.05274957790970802, + -0.11321325600147247, + -0.07062514126300812, + -0.01720590703189373, + -0.00901520811021328, + 0.01746262051165104, + -0.08946436643600464, + -0.2304752618074417, + -0.1021895483136177, + 0.013501768000423908, + 0.029721295461058617, + -0.010094762779772282, + 0.009764805436134338, + -0.06424269080162048, + -0.03032868541777134, + 0.13044297695159912, + 0.12166891992092133, + 0.07157951593399048, + 0.029467372223734856, + -0.03827595338225365, + -0.031337328255176544, + -0.026486340910196304, + 0.05953369289636612, + 0.029497025534510612, + 0.022669093683362007, + -0.01055963709950447, + -0.025020133703947067, + 0.002589448355138302, + 0.017152298241853714, + 0.062067389488220215, + 0.008266719058156013, + 0.00563611788675189, + -0.0044869836419820786, + 0.003065212396904826, + 0.014371387660503387, + 0.013636622577905655, + 0.021183570846915245, + -0.012462744489312172, + -0.02493542619049549, + 0.009652925655245781, + -0.09309647232294083, + -0.09614148736000061, + 0.020278261974453926, + 0.262399286031723, + 0.0025974283926188946, + -0.09532646089792252, + -0.0391894206404686, + -0.003332971129566431, + -0.25919869542121887, + -0.2104814499616623, + 0.5975717306137085, + 0.20378711819648743, + -0.20521192252635956, + 0.005045099183917046, + 0.16707547008991241, + -0.08322134613990784, + -1.1734565496444702, + 0.4060916006565094, + 0.9109339714050293, + -0.22450445592403412, + -0.14085394144058228, + 0.19534644484519958, + 0.6220589280128479, + -1.0614460706710815, + -1.2444484233856201, + 1.1965712308883667, + 0.5032565593719482, + -0.26604175567626953, + -0.13583213090896606, + 0.6453277468681335, + 0.4994892477989197, + -1.7917202711105347, + -0.15182015299797058, + 0.7891079783439636, + 0.10711944103240967, + -0.11587982624769211, + 0.08287231624126434, + 0.7848142981529236, + -0.1764022707939148, + -1.0492321252822876, + 0.15281184017658234, + 0.3100045919418335, + -0.0461110882461071, + -0.06824400275945663, + 0.25544390082359314, + 0.3444065451622009, + -0.3189513683319092, + -0.3503313362598419, + 0.05462741479277611, + -0.041028521955013275, + 0.00624969182536006, + -0.0014677124563604593, + 0.10383514314889908, + -0.03467189520597458, + -0.03946290910243988, + 0.012734192423522472, + -0.003676857566460967, + -0.1616411954164505, + -0.034441810101270676, + 0.34758275747299194, + -0.0017601394793018699, + -0.17407774925231934, + 0.05167992413043976, + 0.12394318729639053, + -0.018228475004434586, + -0.71342533826828, + 0.39672648906707764, + 0.4870489537715912, + -0.27272745966911316, + -0.02687050960958004, + 0.09090551733970642, + 0.46698617935180664, + -0.6089348196983337, + -0.7488552331924438, + 0.8327828645706177, + 0.19947239756584167, + -0.17806877195835114, + -0.09197663515806198, + 0.3198661506175995, + 0.42619431018829346, + -1.1321229934692383, + -0.05452701821923256, + 0.4155597984790802, + -0.001295815804041922, + -0.06596186012029648, + -0.05821318179368973, + 0.4515152871608734, + 0.06321248412132263, + -0.6065720319747925, + 0.10882120579481125, + 0.13767170906066895, + 0.01809641905128956, + -0.070295050740242, + 0.04035783186554909, + 0.22459834814071655, + -0.048405971378088, + -0.14622822403907776, + -0.01119917817413807, + 0.00666345190256834, + 0.04815478250384331, + -0.017866114154458046, + -0.04813665896654129, + -0.02366034686565399, + 0.03589487820863724, + -0.0066519430838525295, + 0.0004148671869188547, + -0.014153627678751945, + 0.04403751716017723, + 0.04098428785800934, + -0.10525348782539368, + -0.0078808031976223, + 0.0444580540060997, + -0.027595041319727898, + 0.010916849598288536, + -0.1390431821346283, + 0.20334453880786896, + -0.006475532427430153, + -0.16053295135498047, + 0.06964287906885147, + -0.025649840012192726, + 0.12622858583927155, + -0.09694403409957886, + -0.09791161119937897, + 0.2617567479610443, + -0.06268735229969025, + -0.03128494322299957, + -0.017743078991770744, + -0.02372320368885994, + 0.2195650041103363, + -0.2456466406583786, + 0.031090563163161278, + 0.010196326300501823, + -0.04323133826255798, + 0.02746250294148922, + -0.079569011926651, + 0.06894756853580475, + 0.11414647102355957, + -0.12175147980451584, + 0.025397513061761856, + 0.006027852185070515, + 0.013360690325498581, + -0.024561991915106773, + -0.10966529697179794, + 0.04913714900612831, + 0.09801583737134933, + 0.00013951699656900018, + -0.03194398432970047, + 0.002382949460297823, + -0.003335593966767192, + 0.023621119558811188, + 0.024585755541920662, + -0.016027197241783142, + -0.02846739999949932, + -0.012949706055223942, + -0.020852699875831604, + -0.016913240775465965, + 0.016088848933577538, + 0.141468346118927, + 0.07285624742507935, + -0.008997173048555851, + -0.033306676894426346, + -0.03418722003698349, + -0.15127411484718323, + -0.047440435737371445, + 0.2687169015407562, + 0.17237843573093414, + 0.03505166247487068, + -0.06994523108005524, + -0.031143782660365105, + -0.3024960458278656, + -0.1552041918039322, + 0.33517369627952576, + 0.28441429138183594, + 0.06471730768680573, + -0.0613982267677784, + -0.02271229960024357, + -0.29379361867904663, + -0.3259792923927307, + 0.16062304377555847, + 0.29220375418663025, + 0.10862076282501221, + -0.005909152328968048, + 0.049116987735033035, + -0.20140305161476135, + -0.3278747797012329, + -0.02566053718328476, + 0.14338354766368866, + 0.006411381531506777, + -0.007274044211953878, + 0.08232597261667252, + -0.04198717698454857, + -0.17330540716648102, + -0.01131037063896656, + 0.08018575608730316, + -0.02374250255525112, + -0.002276432001963258, + 0.00019528658594936132, + -0.024716932326555252, + 0.026509074494242668, + 0.08361849933862686, + 0.012956380844116211, + -0.06030649319291115, + -0.020338360220193863, + -0.03577016666531563, + -0.06858085840940475, + 0.008245388977229595, + 0.25225168466567993, + 0.16135559976100922, + -0.03690743073821068, + -0.09188401699066162, + -0.10410526394844055, + -0.25971388816833496, + -0.07926154136657715, + 0.3933144509792328, + 0.33186599612236023, + 0.059405017644166946, + -0.11824909597635269, + -0.10528354346752167, + -0.4808295667171478, + -0.25224801898002625, + 0.4267246127128601, + 0.4853539764881134, + 0.16933484375476837, + -0.073345847427845, + -0.02648857608437538, + -0.4723232388496399, + -0.4904792010784149, + 0.1938265562057495, + 0.44070878624916077, + 0.22439399361610413, + 0.03877745941281319, + 0.08536087721586227, + -0.31432414054870605, + -0.5158097743988037, + -0.09537900239229202, + 0.20227058231830597, + 0.07895126938819885, + 0.059195615351200104, + 0.14728911221027374, + -0.059377528727054596, + -0.2884902060031891, + -0.12288203090429306, + 0.05220698565244675, + -0.045279599726200104, + 0.019795719534158707, + -0.009819806553423405, + -0.013713877648115158, + 0.0012175077572464943, + 0.03281072899699211, + 0.0017424041870981455, + -0.028847966343164444, + -0.0032059827353805304, + -0.020358575507998466, + 0.0009416870889253914, + -0.007760196924209595, + 0.07921157032251358, + 0.03826644644141197, + -0.02976907789707184, + -0.03300238028168678, + -0.017963968217372894, + -0.055836472660303116, + -0.03299689665436745, + 0.15166012942790985, + 0.06786434352397919, + 0.008589516393840313, + -0.05790036544203758, + -0.0029997669626027346, + -0.14070068299770355, + -0.08799122273921967, + 0.19680362939834595, + 0.14703704416751862, + 0.03569985553622246, + -0.02847554162144661, + 0.03601403906941414, + -0.1339161992073059, + -0.20527805387973785, + 0.1060374304652214, + 0.16269326210021973, + 0.0575268417596817, + 0.0029672966338694096, + 0.018848277628421783, + -0.1029881089925766, + -0.19446833431720734, + -0.055140964686870575, + 0.09632515162229538, + 0.01196608692407608, + 0.01994382217526436, + 0.0030014747753739357, + 0.0029817752074450254, + -0.09395840018987656, + -0.038611751049757004, + 0.03793984279036522, + -0.006295992527157068, + 0.01736803539097309, + -0.0961727425456047, + 0.1318971812725067, + 0.00169672432821244, + 0.02773740515112877, + -0.03737606480717659, + -0.02413480542600155, + -0.07371329516172409, + 0.04465596005320549, + 0.34972262382507324, + 0.269726425409317, + 0.14907677471637726, + -0.15323053300380707, + -0.24987848103046417, + -0.32931339740753174, + 0.05209995433688164, + 0.5192161798477173, + 0.5108750462532043, + 0.2627664804458618, + -0.26889729499816895, + -0.49891141057014465, + -0.5081418752670288, + 0.13535383343696594, + 0.7318623661994934, + 0.7116816639900208, + 0.2973657250404358, + -0.38982102274894714, + -0.7131763100624084, + -0.5916072130203247, + 0.1200462281703949, + 0.7752112746238708, + 0.6947993636131287, + 0.21100594103336334, + -0.5576100945472717, + -0.7797606587409973, + -0.6058254837989807, + 0.08617032319307327, + 0.6432424187660217, + 0.522933304309845, + 0.16018754243850708, + -0.5134027004241943, + -0.6838728189468384, + -0.5088241100311279, + 0.10101393610239029, + 0.4321025311946869, + 0.3330003023147583, + 0.10116448998451233, + -0.2786642014980316, + -0.4134466052055359, + -0.3247438967227936, + 0.009768294170498848, + 0.008712833747267723, + -0.029476309195160866, + 0.007709377445280552, + 0.025279967114329338, + 0.01615188643336296, + 0.01585867628455162, + -0.0031516810413450003, + -0.06462288647890091, + -0.055517926812171936, + -0.013180199079215527, + -0.014849795028567314, + 0.05535515025258064, + 0.04162544384598732, + 0.0022392054088413715, + -0.09408581256866455, + -0.07889631390571594, + -0.032870080322027206, + 0.0382377915084362, + 0.07495865970849991, + 0.08439645916223526, + 0.008036677725613117, + -0.1167779192328453, + -0.10782196372747421, + -0.06854722648859024, + 0.06310252100229263, + 0.09643208235502243, + 0.08629462122917175, + -0.016969647258520126, + -0.10456187278032303, + -0.10410942137241364, + -0.017384463921189308, + 0.03931420296430588, + 0.11296819150447845, + 0.08688211441040039, + -0.018024103716015816, + -0.0985492691397667, + -0.10534191876649857, + 0.016594627872109413, + 0.024613894522190094, + 0.09626104682683945, + 0.056779902428388596, + -0.01314453687518835, + -0.1173979789018631, + -0.07576211541891098, + -0.00741730397567153, + 0.04463285952806473, + 0.06365535408258438, + 0.029472019523382187, + 0.06097950413823128, + -0.0884813666343689, + -0.020469073206186295, + -0.004499382339417934, + 0.006147715728729963, + 0.0061135985888540745, + 0.046618249267339706, + -0.024977274239063263, + -0.2809607684612274, + -0.20776452124118805, + -0.10792756825685501, + 0.10520339012145996, + 0.2195160835981369, + 0.27846819162368774, + -0.0425783209502697, + -0.4539273977279663, + -0.4210258722305298, + -0.24160517752170563, + 0.2377386838197708, + 0.4254952371120453, + 0.40258923172950745, + -0.08894401043653488, + -0.6261403560638428, + -0.6177268624305725, + -0.2941279113292694, + 0.36115866899490356, + 0.6176164746284485, + 0.5170959234237671, + -0.12760992348194122, + -0.6392932534217834, + -0.6288641095161438, + -0.20397846400737762, + 0.4859760105609894, + 0.7283636927604675, + 0.5233575105667114, + -0.08038943260908127, + -0.513219952583313, + -0.4611802101135254, + -0.08622774481773376, + 0.41959214210510254, + 0.6145293116569519, + 0.4252074360847473, + -0.08993257582187653, + -0.3586794435977936, + -0.23889268934726715, + -0.07402873039245605, + 0.2362663745880127, + 0.33187127113342285, + 0.24442552030086517, + -0.10037989169359207, + -0.1200498715043068, + -0.06188809871673584, + 0.009648810140788555, + 0.07703708112239838, + -0.07734857499599457, + -0.16337357461452484, + -0.13160429894924164, + -0.037760209292173386, + 0.10750655829906464, + 0.21975228190422058, + 0.21332265436649323, + 0.1482381671667099, + -0.012174196541309357, + -0.03128019720315933, + 0.06983920931816101, + 0.2055918425321579, + 0.16611628234386444, + 0.20955723524093628, + 0.21407610177993774, + 0.13214662671089172, + 0.01558306161314249, + 0.20919384062290192, + 0.21453723311424255, + 0.10980720072984695, + 0.10323476791381836, + 0.1754676252603531, + 0.16320686042308807, + 0.076839879155159, + 0.2669583261013031, + 0.29500535130500793, + 0.18005967140197754, + 0.14900699257850647, + 0.2337430715560913, + 0.2607984244823456, + -0.08909865468740463, + 0.12383633106946945, + 0.27329200506210327, + 0.2634970247745514, + 0.2298160344362259, + 0.22673286497592926, + 0.1753624528646469, + -0.14258335530757904, + -0.033422429114580154, + 0.09338828176259995, + 0.21975602209568024, + 0.2488732784986496, + 0.21165378391742706, + 0.08514796197414398, + 0.0776415765285492, + -0.028732767328619957, + -0.0827818363904953, + -0.14784079790115356, + -0.06101813539862633, + -0.10570015013217926, + -0.07298385351896286, + -0.03352680057287216, + -0.08094660192728043, + -0.08546923100948334, + -0.025722583755850792, + -0.04828448221087456, + -0.15816760063171387, + -0.22295169532299042, + -0.04976325109601021, + -0.12255501747131348, + -0.04869991913437843, + 0.09818085283041, + 0.2285904735326767, + 0.015187943354249, + -0.19952231645584106, + -0.1415022611618042, + -0.09511925280094147, + 0.10828559100627899, + 0.35640013217926025, + 0.5399265289306641, + 0.3026861250400543, + -0.10532847791910172, + -0.0455780103802681, + -0.09365752339363098, + 0.2482689470052719, + 0.5483031272888184, + 0.6572608947753906, + 0.4098849594593048, + -0.0039499495178461075, + -0.11641024053096771, + -0.22666053473949432, + -0.03133581206202507, + 0.2815704643726349, + 0.3229265809059143, + 0.009749597869813442, + -0.19616934657096863, + -0.05046992748975754, + -0.15597671270370483, + -0.22775587439537048, + -0.14872166514396667, + -0.12174414098262787, + -0.23433859646320343, + -0.238412007689476, + 0.09725375473499298, + 0.08522887527942657, + 0.006490080617368221, + -0.024619178846478462, + 0.07278231531381607, + 0.13406167924404144, + 0.22993306815624237, + 0.10250072181224823, + 0.09119024127721786, + -0.07687287777662277, + -0.1012108325958252, + -0.09500063210725784, + -0.10082961618900299, + 0.09466016292572021, + 0.11299365013837814, + -0.033278487622737885, + -0.20269805192947388, + -0.21449527144432068, + -0.08820098638534546, + -0.18970704078674316, + -0.050536416471004486, + -0.03471578657627106, + -0.13205547630786896, + -0.18150201439857483, + -0.03963223099708557, + 0.13029472529888153, + -0.11594776809215546, + -0.173879474401474, + 0.017406627535820007, + -0.11885572224855423, + -0.06966021656990051, + 0.1687183529138565, + 0.2677668035030365, + -0.020446041598916054, + -0.11710261553525925, + 0.044354867190122604, + -0.10054060816764832, + -0.1287878155708313, + -0.03600803390145302, + -0.03198331966996193, + -0.22372953593730927, + -0.11045534163713455, + 0.22963544726371765, + 0.16736479103565216, + -0.023956498131155968, + -0.0882943719625473, + -0.11904646456241608, + -0.10481738299131393, + 0.083598293364048, + 0.058089643716812134, + -0.04821285232901573, + 0.16764044761657715, + -0.13788309693336487, + -0.1412951946258545, + 0.059633608907461166, + 0.012824267148971558, + -0.03141501545906067, + -0.017422236502170563, + 0.3908282518386841, + -0.31520241498947144, + -0.27876099944114685, + 0.17109407484531403, + 0.011913848109543324, + -0.04440265893936157, + 0.05610174685716629, + 0.5290316343307495, + -0.4506116211414337, + -0.2946499288082123, + 0.2802693545818329, + 0.04180249199271202, + -0.05673402547836304, + 0.0445592887699604, + 0.4933576285839081, + -0.4903600513935089, + -0.3259376883506775, + 0.26069584488868713, + 0.047843094915151596, + -0.053804315626621246, + 0.029928382486104965, + 0.3588394224643707, + -0.39090782403945923, + -0.18598265945911407, + 0.1703576147556305, + 0.010407418012619019, + 0.019840527325868607, + -0.017079327255487442, + 0.21012797951698303, + -0.1586841642856598, + -0.12738685309886932, + 0.12431345880031586, + 0.028149213641881943, + 0.05083676427602768, + -0.07053223252296448, + 0.12090320140123367, + -0.13737183809280396, + -0.09807822853326797, + 0.07203921675682068, + -0.01965559460222721, + 0.036479320377111435, + -0.02657422423362732, + 0.2924504280090332, + -0.19397024810314178, + -0.20908842980861664, + 0.07435549795627594, + 0.011985386721789837, + -0.051603686064481735, + 0.039122600108385086, + 0.5911946892738342, + -0.45937344431877136, + -0.43863579630851746, + 0.23180224001407623, + 0.05592876672744751, + -0.10227655619382858, + 0.1371937245130539, + 0.7193072438240051, + -0.6789532899856567, + -0.5275344252586365, + 0.4098500609397888, + 0.09136661887168884, + -0.08802130073308945, + 0.12226735055446625, + 0.6819202303886414, + -0.7316576838493347, + -0.5229181051254272, + 0.37578293681144714, + 0.09086397290229797, + -0.05128701403737068, + 0.09287497401237488, + 0.5103837251663208, + -0.6150248646736145, + -0.3208717107772827, + 0.29780012369155884, + 0.071808360517025, + 0.04605705663561821, + 0.028153980150818825, + 0.30872926115989685, + -0.32211968302726746, + -0.1925150454044342, + 0.18948692083358765, + 0.07391810417175293, + 0.08546463400125504, + -0.07042243331670761, + 0.14390304684638977, + -0.22509464621543884, + -0.12615789473056793, + 0.09681600332260132, + 0.0030679223127663136, + 0.06206878274679184, + -0.0493885837495327, + 0.11675205081701279, + -0.09476804733276367, + -0.0708041712641716, + 0.027848264202475548, + 0.018535451963543892, + 0.01112216804176569, + -0.023546719923615456, + 0.2808285057544708, + -0.2312571257352829, + -0.16320407390594482, + 0.15229304134845734, + -0.007220278959721327, + -0.026767488569021225, + -0.008487970568239689, + 0.39064091444015503, + -0.3746477961540222, + -0.22930599749088287, + 0.23297259211540222, + -0.020648201927542686, + -0.03918099403381348, + -0.03193120285868645, + 0.37857353687286377, + -0.38306936621665955, + -0.25103962421417236, + 0.2414209097623825, + 0.007709929719567299, + -0.041483473032712936, + -0.001570625347085297, + 0.315625935792923, + -0.276553213596344, + -0.13154125213623047, + 0.17517149448394775, + 0.03219839558005333, + 0.002647437620908022, + -0.012777225114405155, + 0.17064248025417328, + -0.13943275809288025, + -0.10204917937517166, + 0.09418098628520966, + 0.026260169222950935, + 0.05167905613780022, + -0.024634944275021553, + 0.0931941494345665, + -0.11875593662261963, + -0.0752263143658638, + 0.0569780170917511, + 0.00024334408226422966, + -0.001991289434954524, + -0.012094452045857906, + -0.0012201170902699232, + 0.01342268567532301, + 0.006425719242542982, + 0.01147665549069643, + -0.002208880614489317, + -0.019385183230042458, + -0.024868011474609375, + 0.00465290667489171, + 0.009205960668623447, + 0.0016242304118350148, + 0.0059639886021614075, + -0.03436571732163429, + 0.01672518253326416, + 0.008815832436084747, + 0.06389293074607849, + 0.06249547377228737, + 0.06542838364839554, + 0.043118152767419815, + 0.04117512330412865, + 0.014435848221182823, + 0.0065850247628986835, + 0.03811212629079819, + -0.006077034864574671, + -0.004025861620903015, + 0.006247953977435827, + 0.014478449709713459, + 0.0009701942908577621, + -0.002422194229438901, + 0.009390920400619507, + -0.052253514528274536, + -0.05192738026380539, + -0.010346310213208199, + -0.001328076352365315, + -0.002972622634842992, + 0.0015572139527648687, + 0.022503724321722984, + -0.002475353656336665, + 0.001927886507473886, + 0.02994818612933159, + 0.02062363363802433, + -0.0010653833160176873, + -0.005995174869894981, + 0.024450020864605904, + 0.013005194254219532, + 0.0496530681848526, + 0.029475165531039238, + 0.004157512914389372, + -0.0007043799851089716, + 0.01860312558710575, + 0.03839566186070442, + 0.00014980587002355605, + 0.018569663166999817, + 0.05668198689818382, + 0.04645680636167526, + 0.01642409712076187, + 0.03577466681599617, + 0.03575601801276207, + -0.03680748492479324, + -0.01865880750119686, + 0.041660092771053314, + 0.033268485218286514, + 0.03338993713259697, + 0.04665865749120712, + -0.03322917968034744, + -0.2860279381275177, + -0.28877392411231995, + -0.09617949277162552, + 0.014234350994229317, + 0.038012001663446426, + -0.016850680112838745, + -0.27252569794654846, + -0.6714493632316589, + -0.686245322227478, + -0.3376169502735138, + -0.0812990590929985, + 0.003058002796024084, + -0.026376569643616676, + -0.29216718673706055, + -0.6779875159263611, + -0.6917123198509216, + -0.3184400796890259, + -0.058261968195438385, + 0.06338769942522049, + 0.03199980780482292, + -0.09837217628955841, + -0.3355932831764221, + -0.30900436639785767, + -0.04878076910972595, + 0.061543505638837814, + 0.04651529714465141, + 0.0263908002525568, + 0.0030237447936087847, + -0.10458099842071533, + -0.07959774881601334, + 0.05430716276168823, + 0.056767694652080536, + 0.00796051137149334, + -0.016737859696149826, + -0.042338743805885315, + -0.0198048185557127, + -0.03085070475935936, + -0.058721307665109634, + -0.036032311618328094, + -0.0035414688754826784, + -8.359456842299551e-05, + -0.02213932015001774, + 0.02032857947051525, + 0.021788733080029488, + -0.03522418439388275, + -0.025317413732409477, + -0.042937491089105606, + -0.05680134892463684, + -0.012510996311903, + 0.226289302110672, + 0.24401520192623138, + 0.022300971671938896, + -0.030825607478618622, + -0.05485948920249939, + 0.007590078748762608, + 0.2208130657672882, + 0.6964298486709595, + 0.7457719445228577, + 0.3470557630062103, + 0.06941442936658859, + -0.03543366119265556, + 0.035853609442710876, + 0.2872598171234131, + 0.7504303455352783, + 0.7509996294975281, + 0.34327855706214905, + 0.024429334327578545, + -0.05711393058300018, + -0.034500252455472946, + 0.057939525693655014, + 0.33292675018310547, + 0.3141649067401886, + 0.033748809248209, + -0.062175147235393524, + -0.041224412620067596, + -0.01891348883509636, + -0.014519350603222847, + 0.08635713160037994, + 0.03148616850376129, + -0.08749162405729294, + -0.05658482387661934, + 0.00018510188965592533, + 0.002624311950057745, + -0.003570129396393895, + 0.0067627751268446445, + -0.01349653396755457, + -0.003961967770010233, + 0.0034001911990344524, + -0.00385954394005239, + 0.018012456595897675, + -0.018755480647087097, + -0.03163064643740654, + -0.0035233700182288885, + 0.011690095998346806, + -0.014693490229547024, + 0.017746854573488235, + 0.05693097040057182, + 0.1272590607404709, + 0.23477119207382202, + 0.19823509454727173, + 0.05071045830845833, + -0.007188393268734217, + -0.05571149289608002, + -0.06468938291072845, + -0.017831332981586456, + -0.07572834193706512, + -0.19599483907222748, + -0.15608063340187073, + -0.039450764656066895, + -0.035583946853876114, + -0.1605951488018036, + -0.5041624307632446, + -0.6836286783218384, + -0.3773191571235657, + -0.08623629808425903, + -0.04881078004837036, + 0.029403403401374817, + 0.15516817569732666, + 0.4108496308326721, + 0.6393839716911316, + 0.4688946008682251, + 0.2135964334011078, + 0.0623941570520401, + 0.02426956780254841, + -8.065254223765805e-05, + -0.00816427543759346, + -0.09353788942098618, + -0.06872912496328354, + -0.029405562207102776, + 0.012364620342850685, + 0.0060868943110108376, + 0.017015695571899414, + -0.0076495204120874405, + -0.006090708542615175, + -0.016521835699677467, + 0.009218892082571983, + 0.030833140015602112, + -0.0002345978282392025, + 0.03332215175032616, + 0.0030349211301654577, + 0.009600857272744179, + 0.05706647038459778, + 0.06095677986741066, + -0.016137542203068733, + 0.03195658698678017, + 0.13535599410533905, + 0.28229761123657227, + 0.4573267698287964, + 0.39102476835250854, + 0.17547546327114105, + 0.005337159149348736, + -0.07699840515851974, + -0.12667469680309296, + -0.16613735258579254, + -0.2908898890018463, + -0.44942277669906616, + -0.34229782223701477, + -0.16225378215312958, + -0.1100199744105339, + -0.4044281840324402, + -0.9058251976966858, + -1.1549302339553833, + -0.7502554059028625, + -0.2716369032859802, + -0.13495275378227234, + 0.08614412695169449, + 0.3164423108100891, + 0.7155097723007202, + 1.0356683731079102, + 0.7939887642860413, + 0.39567017555236816, + 0.16957539319992065, + 0.02675812318921089, + 0.048314403742551804, + 0.053107086569070816, + -0.009243623353540897, + -0.011442561633884907, + 0.004911235999315977, + 0.012210517190396786, + 0.006660772021859884, + -0.004562888294458389, + -0.009606098756194115, + -0.01610635593533516, + -0.03475078567862511, + 0.007796770427376032, + 0.02015513926744461, + 0.020311446860432625, + 0.009043446741998196, + -0.01929326355457306, + -0.04183953255414963, + -0.003052672604098916, + 0.020744286477565765, + 0.01371331699192524, + 0.004048139322549105, + 0.0692848190665245, + 0.16867054998874664, + 0.2799474000930786, + 0.28119951486587524, + 0.13579942286014557, + -0.0015732255997136235, + -0.05406518653035164, + -0.05831173434853554, + -0.034435681998729706, + -0.11925295740365982, + -0.2570647895336151, + -0.19120880961418152, + -0.09981344640254974, + -0.011702792719006538, + -0.22477947175502777, + -0.5395713448524475, + -0.7111374139785767, + -0.4207299053668976, + -0.11811137199401855, + -0.035199034959077835, + 0.024358956143260002, + 0.16262274980545044, + 0.46769100427627563, + 0.677872896194458, + 0.4637402892112732, + 0.15558630228042603, + 0.04467496648430824, + 0.03221412003040314, + 0.02430277317762375, + -0.006398700177669525, + -0.07235423475503922, + -0.03669704124331474, + -0.000992153538390994, + 0.02220241352915764, + -0.03329842537641525, + 0.05199713259935379, + -0.14053553342819214, + 0.1906905472278595, + -0.13544943928718567, + 0.08535720407962799, + -0.009813228622078896, + 0.03578176349401474, + -0.05863757058978081, + 0.33848440647125244, + -0.49837300181388855, + 0.15308170020580292, + 0.14865124225616455, + -0.12349266558885574, + -0.025796135887503624, + 0.17790427803993225, + -0.7813658714294434, + 0.853188693523407, + 0.2489670068025589, + -0.7378701567649841, + 0.2207188457250595, + 0.05207442864775658, + -0.4280349314212799, + 1.1408430337905884, + -0.24505679309368134, + -1.5490919351577759, + 1.4560288190841675, + -0.31143030524253845, + -0.03536878153681755, + 0.5640448331832886, + -0.6874421834945679, + -1.210310697555542, + 2.6637399196624756, + -1.6589887142181396, + 0.2221546173095703, + 0.10179737955331802, + -0.4354941248893738, + 0.034149203449487686, + 1.480568528175354, + -2.072199821472168, + 0.9205833673477173, + 0.021510563790798187, + -0.07755836099386215, + 0.17983688414096832, + 0.040537625551223755, + -0.5325585603713989, + 0.550999641418457, + -0.11060550063848495, + -0.09052976220846176, + -0.048361390829086304, + 0.03450514376163483, + -0.11854307353496552, + 0.23462797701358795, + -0.17563995718955994, + 0.0653814822435379, + -0.009748813696205616, + 0.07013920694589615, + -0.08628369867801666, + 0.3019683063030243, + -0.630340576171875, + 0.274477481842041, + 0.15417183935642242, + -0.036220982670784, + -0.07344137132167816, + 0.2339126616716385, + -1.0395091772079468, + 1.2002928256988525, + 0.085142120718956, + -0.7080597281455994, + 0.23101751506328583, + 0.016307154670357704, + -0.45877355337142944, + 1.617128849029541, + -0.6593433618545532, + -1.8957709074020386, + 1.746606469154358, + -0.37062564492225647, + 0.01213759370148182, + 0.5851964354515076, + -1.0307577848434448, + -1.4803766012191772, + 3.812014102935791, + -2.0028398036956787, + 0.12008816003799438, + 0.01813559979200363, + -0.5065457820892334, + 0.17598780989646912, + 2.0418734550476074, + -2.680522918701172, + 0.7466094493865967, + 0.16271913051605225, + -0.04379571974277496, + 0.21930621564388275, + 0.041255541145801544, + -0.6644601821899414, + 0.481300413608551, + 0.05410065874457359, + -0.09025495499372482, + 0.01954805478453636, + 0.01899997517466545, + -0.1337241530418396, + 0.19821906089782715, + -0.06395180523395538, + -0.03586877882480621, + 0.01973363384604454, + 0.013873124495148659, + -0.09288538247346878, + 0.4300728440284729, + -0.4235192537307739, + 0.03646458685398102, + 0.10077393800020218, + -0.07569073140621185, + -0.08176662772893906, + 0.3834531605243683, + -0.747482419013977, + 0.4493187367916107, + 0.2960513234138489, + -0.5245057344436646, + 0.27831950783729553, + 0.0731748417019844, + -0.45574328303337097, + 0.6987965703010559, + 0.019539732486009598, + -1.1160184144973755, + 1.0756875276565552, + -0.3804619312286377, + -0.040626902133226395, + 0.2780243456363678, + -0.32946258783340454, + -0.8122196793556213, + 1.9535348415374756, + -1.300661563873291, + 0.3443142771720886, + 0.04858396574854851, + -0.17409801483154297, + -0.07783844321966171, + 1.0875797271728516, + -1.5148566961288452, + 0.8014272451400757, + -0.19643208384513855, + -0.033590562641620636, + 0.11178025603294373, + 0.08284300565719604, + -0.5165408849716187, + 0.5841389894485474, + -0.24739950895309448, + 0.027926180511713028, + -0.028708497062325478, + 0.0037401756271719933, + -0.0047450135461986065, + 0.008427698165178299, + 0.009801353327929974, + -0.0029346586670726538, + -0.010193527676165104, + 0.014876358211040497, + 0.009861295111477375, + -0.005554665345698595, + -0.06270359456539154, + -0.0316256619989872, + 0.006706684362143278, + 0.04316525161266327, + 0.008637072518467903, + -0.03666357323527336, + -0.0719730481505394, + -0.1525861918926239, + -0.14396126568317413, + -0.05387119948863983, + 0.01955549605190754, + 0.007112634833902121, + -0.05175568535923958, + -0.16772602498531342, + -0.20807777345180511, + -0.18768996000289917, + -0.17093753814697266, + -0.03334345668554306, + 0.0011808606795966625, + -0.01579100452363491, + -0.12589050829410553, + -0.17219413816928864, + -0.19648219645023346, + -0.21980451047420502, + -0.04920821264386177, + 0.0012217299081385136, + 0.023885242640972137, + -0.056074876338243484, + -0.13907776772975922, + -0.19139252603054047, + -0.13652737438678741, + -0.0027339402586221695, + 0.004720518831163645, + -0.00037206560955382884, + 0.017924504354596138, + -0.02118082158267498, + -0.06553903222084045, + -0.0435921773314476, + 0.02721239998936653, + 0.020702000707387924, + 0.024033410474658012, + 0.005382229574024677, + -0.01273527555167675, + -0.01742861233651638, + 0.007402990944683552, + 0.010333286598324776, + 0.02598601020872593, + 0.012456837110221386, + -0.03471057116985321, + -0.10051856189966202, + -0.08084382116794586, + -0.023420603945851326, + 0.031205907464027405, + 0.00424322672188282, + -0.03734385594725609, + -0.1152661070227623, + -0.2012551724910736, + -0.1995576024055481, + -0.07972321659326553, + -0.011126434430480003, + -0.0185835100710392, + -0.06944561004638672, + -0.21481844782829285, + -0.26795628666877747, + -0.24916253983974457, + -0.17833945155143738, + -0.06658200174570084, + -0.00305415247566998, + -0.054028186947107315, + -0.19072681665420532, + -0.256619930267334, + -0.26868295669555664, + -0.21621295809745789, + -0.06564134359359741, + 0.0031192339956760406, + 0.013205861672759056, + -0.08044812828302383, + -0.18137820065021515, + -0.23007699847221375, + -0.13054916262626648, + -0.01135951280593872, + 0.013734308071434498, + 0.010981118306517601, + -0.02249351143836975, + -0.05804377421736717, + -0.10652261227369308, + -0.04163172468543053, + 0.017101088538765907, + -0.028687385842204094, + -0.0019976652693003416, + 0.009987232275307178, + 0.010130539536476135, + 0.0015575449215248227, + -0.000983694102615118, + -0.012845008634030819, + 0.01329281460493803, + 0.0029350779950618744, + -0.003755913581699133, + -0.036475058645009995, + -0.0245466697961092, + -0.0020879909861832857, + 0.025867130607366562, + -0.0065954397432506084, + 0.008656582795083523, + -0.04037104919552803, + -0.11718368530273438, + -0.13506115972995758, + -0.024255141615867615, + 0.014097613282501698, + -0.0009370348998345435, + -0.010953565128147602, + -0.12869219481945038, + -0.18789908289909363, + -0.19098156690597534, + -0.12795749306678772, + -0.002666366985067725, + -0.004907527007162571, + -0.014610078185796738, + -0.11913872510194778, + -0.19921070337295532, + -0.21869640052318573, + -0.1849898099899292, + -0.03470952808856964, + 0.0064156935550272465, + 0.03401843458414078, + -0.04000416398048401, + -0.12354391813278198, + -0.16908879578113556, + -0.10385500639677048, + 0.002833302365615964, + -0.036176733672618866, + -0.001048827893100679, + 0.010002595372498035, + -0.020798830315470695, + -0.0488261841237545, + -0.002972641494125128, + 0.016395021229982376, + -0.045770127326250076, + -0.12710650265216827, + -0.1637774109840393, + -0.1411965787410736, + 0.20447289943695068, + 0.509396493434906, + 0.07264503091573715, + 0.12041529268026352, + -0.015143441036343575, + -0.2673257887363434, + -0.3589763641357422, + 0.11289574205875397, + 0.8517020344734192, + 0.7068799138069153, + 0.067301444709301, + -0.02102830447256565, + -0.5235708355903625, + -1.2064802646636963, + -0.856619656085968, + 0.26774707436561584, + 0.6825867295265198, + 0.13516077399253845, + 0.3054035007953644, + -0.0727991834282875, + -1.4912222623825073, + -1.906838297843933, + -0.8574200868606567, + -0.15282419323921204, + 0.39327505230903625, + 0.9758505821228027, + 1.2323224544525146, + 0.18179064989089966, + -0.947610080242157, + -0.6657719016075134, + -0.19935055077075958, + -0.09150458872318268, + 0.34379544854164124, + 1.2025749683380127, + 0.9517407417297363, + -0.12023784220218658, + -0.3146151900291443, + -0.1049022302031517, + -0.34867578744888306, + -0.32945582270622253, + 0.28920575976371765, + 0.7844374179840088, + 0.35520124435424805, + 0.007452746387571096, + 0.018862545490264893, + -0.0021927610505372286, + 0.0321974977850914, + 0.05439181253314018, + -0.030729038640856743, + -0.03517322614789009, + -0.037830010056495667, + -0.056672073900699615, + -0.017769837751984596, + 0.06385952979326248, + 0.08161566406488419, + 0.07809178531169891, + 0.06333671510219574, + -0.036322008818387985, + -0.06432312726974487, + -0.03629852458834648, + 0.010879911482334137, + 0.088901087641716, + 0.0021402277052402496, + 0.09618857502937317, + 0.02661084569990635, + -0.03414442762732506, + -0.08736730366945267, + -0.048222169280052185, + 0.03507986292243004, + -0.053828027099370956, + 0.006044292356818914, + 0.04232194274663925, + 0.001624415279366076, + -0.028371643275022507, + -0.08724038302898407, + -0.005835397634655237, + 0.01057528518140316, + 0.04210871085524559, + 0.06106603890657425, + 0.04250370338559151, + 0.0028668276499956846, + -0.07583706080913544, + -0.06849333643913269, + -0.08538331836462021, + -0.021475542336702347, + 0.044341571629047394, + 0.03604369983077049, + 0.05146002024412155, + 0.00280605535954237, + -0.004615028854459524, + -0.07857430726289749, + -0.03716180846095085, + 0.010876243002712727, + -0.03418488800525665, + 0.007391764782369137, + 0.05969953536987305, + 0.08769611269235611, + 0.066011443734169, + -0.10404568910598755, + -0.27194535732269287, + -0.05224551260471344, + -0.03618992492556572, + -0.023098375648260117, + 0.13832588493824005, + 0.21510572731494904, + -0.07285867631435394, + -0.489085853099823, + -0.33285844326019287, + -0.04830349236726761, + 0.014211038127541542, + 0.2612524926662445, + 0.6911754608154297, + 0.5294638276100159, + -0.2706173360347748, + -0.39350029826164246, + -0.05156399682164192, + -0.16490484774112701, + 0.1161464974284172, + 0.8029336929321289, + 1.1809980869293213, + 0.5025736689567566, + 0.07084998488426208, + -0.1901131123304367, + -0.4918227195739746, + -0.603122889995575, + -0.09460704773664474, + 0.5786081552505493, + 0.35392242670059204, + 0.1328991800546646, + -0.008106965571641922, + -0.2159435749053955, + -0.6369062662124634, + -0.5241336822509766, + 0.06276796758174896, + 0.1139409989118576, + 0.05483332276344299, + 0.1703934520483017, + 0.14603517949581146, + -0.16187912225723267, + -0.4139055907726288, + -0.14918148517608643, + -0.06163417547941208, + 0.005302567034959793, + 0.015524876303970814, + -0.11895350366830826, + -0.19724233448505402, + 0.03412429615855217, + 0.10862118750810623, + 0.08550503104925156, + -0.008599682711064816, + -0.03031114675104618, + -0.33224624395370483, + -0.27994298934936523, + 0.196475550532341, + 0.31109708547592163, + 0.17151644825935364, + -0.04994147643446922, + -0.167176753282547, + -0.5247878432273865, + -0.21136601269245148, + 0.54701828956604, + 0.6110883951187134, + 0.04194486886262894, + -0.27640673518180847, + -0.0795169249176979, + -0.360530287027359, + 0.3472684621810913, + 1.5428175926208496, + 1.0249378681182861, + -0.2724844515323639, + -0.3013695478439331, + 0.020736562088131905, + -0.019495302811264992, + 0.7758124470710754, + 1.5381159782409668, + 0.028625331819057465, + -1.289720892906189, + -0.5894255638122559, + 0.0526396706700325, + 0.11443997919559479, + 0.5935031771659851, + 0.47169724106788635, + -1.2507063150405884, + -1.351940631866455, + -0.03894977271556854, + 0.05095001682639122, + 0.01581231690943241, + 0.11137383431196213, + -0.22327138483524323, + -0.9629225730895996, + -0.2607772946357727, + 0.5907121300697327, + 0.006906076334416866, + 0.002633580705150962, + 0.01940075121819973, + 0.0143396882340312, + 0.020781584084033966, + -0.07249777764081955, + -0.016355905681848526, + 0.016553230583667755, + -0.027528395876288414, + 0.0244428887963295, + 0.024910561740398407, + 0.027229825034737587, + -0.04104151204228401, + 0.007100561633706093, + 0.0157785601913929, + -0.06626633554697037, + 0.006520191207528114, + 0.021171070635318756, + 0.036674920469522476, + -0.06950324773788452, + -0.03003627620637417, + 2.178798422391992e-05, + -0.07278106361627579, + 0.014382920227944851, + 0.0982266515493393, + 0.1454961597919464, + -0.10096189379692078, + 0.022237209603190422, + -0.00040665315464138985, + -0.013766243122518063, + 0.06440296769142151, + 0.21751047670841217, + 0.02519127167761326, + -0.23383572697639465, + 0.0038903038948774338, + -0.042271602898836136, + -0.012596859596669674, + 0.023778460919857025, + 0.07685687392950058, + -0.21480663120746613, + -0.19205358624458313, + 0.04876565560698509, + -0.016765035688877106, + -0.02620583213865757, + 0.01641852967441082, + 0.02201787941157818, + -0.07457322627305984, + -0.003633625339716673, + 0.07550841569900513, + 0.024774253368377686, + 0.04710151255130768, + 0.09110233932733536, + -0.017366377636790276, + -0.04366954043507576, + -0.039786458015441895, + 0.005311290733516216, + 0.037867460399866104, + 0.05367766693234444, + 0.07434491813182831, + -0.07251215726137161, + -0.04231821000576019, + -0.023427855223417282, + 0.036294277757406235, + 0.07782749086618423, + 0.11835407465696335, + 0.08753973245620728, + -0.20742319524288177, + -0.13341759145259857, + -0.008225077763199806, + 0.07292432337999344, + 0.006392402108758688, + 0.021914338693022728, + -0.09218581020832062, + -0.44192466139793396, + -0.1744878888130188, + 0.014938815496861935, + 0.10678526759147644, + -0.012087192386388779, + -0.024533385410904884, + -0.1804407387971878, + -0.3253834545612335, + 0.040678758174180984, + 0.2011708915233612, + 0.17262929677963257, + -0.0045212251134216785, + -0.033313386142253876, + -0.10575363039970398, + -0.07636692374944687, + 0.20343273878097534, + 0.28330928087234497, + 0.043149981647729874, + -0.01109551265835762, + -0.0027725452091544867, + 0.003926735837012529, + 0.029440222308039665, + 0.23945140838623047, + 0.09122566133737564, + -0.15140119194984436, + 0.08737201988697052, + 0.07120998948812485, + 0.05722665786743164, + -0.04388495534658432, + 0.02116825245320797, + 0.023315919563174248, + 0.10898162424564362, + 0.11808467656373978, + 0.03412344306707382, + 0.002771642990410328, + -0.1959579437971115, + -0.05181330814957619, + -0.0044630044139921665, + 0.12481725960969925, + 0.09140311926603317, + 0.03444851189851761, + -0.10931172221899033, + -0.3204459846019745, + -0.21193139255046844, + -0.11101037263870239, + 0.04186606407165527, + -0.07420916110277176, + -0.2004990428686142, + -0.26937955617904663, + -0.12928874790668488, + 0.20819628238677979, + -0.17379426956176758, + -0.2181481271982193, + 0.005387924611568451, + -0.24132733047008514, + -0.23942433297634125, + 0.41489261388778687, + 1.0702778100967407, + 0.024913936853408813, + -0.28405970335006714, + 0.083008773624897, + -0.11059781163930893, + -0.17623695731163025, + -0.17386195063591003, + 0.010644182562828064, + -0.32716259360313416, + -0.2135595828294754, + 0.1223129853606224, + 0.07060510665178299, + -0.048680394887924194, + -0.3332099914550781, + -0.25886017084121704, + -0.18619979918003082, + -0.00733158877119422, + 0.03393476828932762, + -0.010564662516117096, + -0.01817108877003193, + -0.05650597810745239, + -0.01891104131937027, + -0.0554141066968441, + -0.004592927638441324, + -0.0013615720672532916, + -0.05552899092435837, + -0.0560498908162117, + -0.1080632209777832, + -0.013965745456516743, + -0.03290533646941185, + -0.02599845454096794, + -0.02877708151936531, + -0.05670137703418732, + -0.07158109545707703, + -0.08808472007513046, + -0.03919175639748573, + -0.08478893339633942, + -0.08045543730258942, + -0.10066724568605423, + -0.048338882625103, + -0.06750114262104034, + 0.08164039999246597, + 0.3343777060508728, + 0.004952755756676197, + -0.14891156554222107, + 0.032855477184057236, + -0.03277512267231941, + 0.0474768728017807, + 0.6316664814949036, + 1.2214386463165283, + 0.2548498213291168, + -0.13185030221939087, + -0.018188906833529472, + -0.07653989642858505, + -0.01643386110663414, + 0.06630122661590576, + 0.23864209651947021, + -0.013703612610697746, + -0.09347789734601974, + -0.0900193303823471, + -0.04930814355611801, + -0.02791711315512657, + -0.15441712737083435, + -0.01623091846704483, + -0.0447690524160862, + -0.06071227043867111, + -0.04737209901213646, + -0.059769801795482635, + -0.04375007003545761, + -0.00650476710870862, + 0.021540174260735512, + -0.05590728670358658, + -0.13030850887298584, + -0.022067781537771225, + -0.05066747963428497, + 0.00609770929440856, + 0.108611099421978, + 0.1621929407119751, + 0.05232185125350952, + -0.049729123711586, + -0.11906369775533676, + -0.030973592773079872, + 0.057787079364061356, + 0.1610448956489563, + 0.18756121397018433, + 0.07277501374483109, + -0.05777435004711151, + -0.05227195844054222, + 0.14434091746807098, + 0.1889694482088089, + 0.26951169967651367, + 0.4710105359554291, + 0.2164669781923294, + 0.05052375793457031, + -0.0038236663676798344, + 0.20267778635025024, + 0.31214746832847595, + 0.7506387829780579, + 1.2302387952804565, + 0.4363090693950653, + 0.16759593784809113, + -0.049752235412597656, + 0.044786907732486725, + 0.14537742733955383, + 0.2227499932050705, + 0.37362414598464966, + 0.16590620577335358, + 0.0864599421620369, + -0.14058542251586914, + -0.04404178634285927, + -0.0325944609940052, + -0.019113417714834213, + 0.17414243519306183, + 0.11160623282194138, + -0.034911543130874634, + 0.1523953527212143, + 0.04554234445095062, + -0.054958827793598175, + -0.11794494092464447, + -0.19570015370845795, + -0.21358126401901245, + -0.1885669231414795, + -0.08286706358194351, + -0.29818814992904663, + -0.52330082654953, + -0.6190353631973267, + -0.682529091835022, + -0.6171367764472961, + -0.4793100655078888, + -0.11180876195430756, + -0.3490432798862457, + -0.5531057715415955, + -0.6426181793212891, + -0.6420838832855225, + -0.4970071613788605, + -0.27038174867630005, + -0.09740017354488373, + -0.1929621547460556, + -0.30848363041877747, + -0.27204805612564087, + -0.2515120208263397, + -0.07497832179069519, + 0.03551386669278145, + -0.05060403421521187, + 0.08276989310979843, + 0.14321963489055634, + 0.3583574593067169, + 0.40667927265167236, + 0.39398193359375, + 0.27561235427856445, + 0.005085935816168785, + 0.2793635427951813, + 0.48155927658081055, + 0.7088037729263306, + 0.7394692897796631, + 0.6158861517906189, + 0.3986552655696869, + 0.025508087128400803, + 0.38533228635787964, + 0.5305332541465759, + 0.6659612059593201, + 0.6396889090538025, + 0.5396444797515869, + 0.39010515809059143, + -0.03072960674762726, + 0.014305810444056988, + 0.029885446652770042, + 0.038084372878074646, + 0.012448564171791077, + 0.034353457391262054, + 0.048626724630594254, + 0.048866890370845795, + 0.07561437785625458, + 0.09152165800333023, + 0.08432324975728989, + 0.09332144260406494, + 0.07517607510089874, + 0.049146559089422226, + 0.03146318346261978, + 0.06335246562957764, + 0.06438779830932617, + 0.06851581484079361, + 0.09263566881418228, + 0.06460423022508621, + 0.011992924846708775, + 0.03396693989634514, + 0.04433950409293175, + 0.04642309248447418, + 0.0022602551616728306, + -0.0361824594438076, + -0.0005105047021061182, + 0.030808264389634132, + 0.0022333709057420492, + -0.017826544120907784, + -0.03796307370066643, + -0.012887164019048214, + -0.028499294072389603, + -0.03367336839437485, + -0.03668365254998207, + -0.02807682938873768, + -0.07444571703672409, + -0.081318199634552, + -0.09610070288181305, + -0.05368436127901077, + -0.09006591141223907, + -0.10038736462593079, + -0.04115951433777809, + -0.056811004877090454, + -0.09935522079467773, + -0.11107856035232544, + -0.07852742075920105, + -0.0942930206656456, + -0.07625897973775864, + -0.12966541945934296, + -0.038938648998737335, + 0.04580259323120117, + 0.10179819911718369, + 0.17127273976802826, + 0.17857632040977478, + 0.13426578044891357, + 0.04687841981649399, + 0.2424812912940979, + 0.42633309960365295, + 0.5291624069213867, + 0.6012980937957764, + 0.5449428558349609, + 0.3945220708847046, + 0.07037744671106339, + 0.26918724179267883, + 0.44614800810813904, + 0.5331310629844666, + 0.568580687046051, + 0.43367546796798706, + 0.25516101717948914, + 0.08428427577018738, + 0.177769735455513, + 0.24885930120944977, + 0.2178547978401184, + 0.13834305107593536, + 0.07452446967363358, + 0.005187708884477615, + 0.050621017813682556, + -0.08428733795881271, + -0.15576106309890747, + -0.25531095266342163, + -0.34646397829055786, + -0.3276817202568054, + -0.24377694725990295, + 0.02817704901099205, + -0.2531633675098419, + -0.3907041549682617, + -0.5944734811782837, + -0.6062930822372437, + -0.5171639919281006, + -0.3501560389995575, + -0.019397703930735588, + -0.2758809030056, + -0.4118667244911194, + -0.5375933051109314, + -0.5525977611541748, + -0.44681206345558167, + -0.2748269736766815, + -0.04229651764035225, + -0.005005967803299427, + -0.011332424357533455, + 0.011387092061340809, + -0.015463154762983322, + -0.012038768269121647, + 0.011360889300704002, + 0.03551746904850006, + 0.05123865604400635, + 0.020377267152071, + 0.1065637394785881, + 0.18875306844711304, + 0.18516196310520172, + 0.12519532442092896, + -0.042940977960824966, + -0.03246130794286728, + -0.016645772382616997, + 0.07807288318872452, + -0.7815885543823242, + -0.5930942296981812, + 0.03312799707055092, + -0.04537777230143547, + -0.022234303876757622, + 0.009241255931556225, + 0.16947965323925018, + -0.0700032040476799, + -0.06346366554498672, + 0.09555318206548691, + 0.02858082763850689, + 0.009246457368135452, + 0.03902693837881088, + 0.007071994710713625, + 0.10085106641054153, + 0.0881502702832222, + 0.011019160971045494, + 0.006030070595443249, + -0.012882355600595474, + -0.01701420359313488, + 0.022596944123506546, + -0.05345382168889046, + 0.02355102449655533, + -0.0091088330373168, + 0.00015542628534603864, + -0.0004997836658731103, + -0.006951311603188515, + 0.01267238613218069, + -0.0033983420580625534, + -0.0030770134180784225, + 0.02975126914680004, + 0.010702245868742466, + -0.016947058960795403, + 0.007774800062179565, + 0.09566964209079742, + 0.07426714897155762, + 0.1621979922056198, + 0.12728945910930634, + 0.06112523376941681, + 0.06061968579888344, + 0.07934501022100449, + 0.11534841358661652, + 0.10001469403505325, + 0.15475066006183624, + 0.1828109323978424, + 0.02134544588625431, + -0.015320047736167908, + 0.012000483460724354, + -0.014393450692296028, + -1.5520576238632202, + -1.2115217447280884, + 0.017239907756447792, + -0.007013735361397266, + 0.0019166347337886691, + 0.025112343952059746, + 0.1803419440984726, + -0.30807924270629883, + -0.33957329392433167, + 0.10846519470214844, + 0.06151076406240463, + 0.054799750447273254, + 0.06235412135720253, + 0.09605015069246292, + 0.16495031118392944, + 0.12624189257621765, + 0.12234552949666977, + 0.006969878450036049, + 0.0033541936427354813, + 0.008165130391716957, + 0.035377491265535355, + -0.03170061111450195, + 0.019396571442484856, + -0.011411413550376892, + 0.019043665379285812, + 0.00957057997584343, + 0.0055394587107002735, + 0.05569477006793022, + 0.0076510305516421795, + 0.018707536160945892, + 0.06073765829205513, + 0.006503407843410969, + -0.0058801183477044106, + -0.03229741007089615, + 0.0386439748108387, + 0.03167358413338661, + 0.027749545872211456, + -0.04634377732872963, + -0.00019781991431955248, + 0.024982664734125137, + 0.009453915059566498, + 0.1091528981924057, + 0.21055325865745544, + 0.23810525238513947, + 0.13829846680164337, + -0.019112061709165573, + -0.0014926757430657744, + 0.01856786385178566, + 0.10649964213371277, + -0.8599057793617249, + -0.6383436322212219, + 0.10839059948921204, + -0.038730181753635406, + -0.030203847214579582, + -0.033147793263196945, + 0.18132103979587555, + -0.1427767276763916, + -0.11132896691560745, + 0.10957232862710953, + -0.00349965482018888, + 0.03486581891775131, + 0.016247740015387535, + 0.060106489807367325, + 0.1439678966999054, + 0.07201634347438812, + 0.07603273540735245, + -0.0072280303575098515, + 0.01600506529211998, + -0.012912745587527752, + 0.015192546881735325, + -0.034853674471378326, + 0.026164958253502846, + 0.001483929343521595, + 0.0508253313601017, + -0.010546445846557617, + -0.024398569017648697, + -0.0043407524935901165, + 0.0030393539927899837, + -0.009643012657761574, + -0.008882591500878334, + 0.01182172168046236, + 0.003359999740496278, + -0.01145304087549448, + -7.34154018573463e-05, + 0.007416137028485537, + -0.012022661976516247, + 0.013550116680562496, + -0.005982181057333946, + -0.019205773249268532, + -0.0811527743935585, + -0.06323252618312836, + -0.026379290968179703, + -0.04671972244977951, + -0.006205265875905752, + 0.05242094770073891, + 0.05065605416893959, + 0.01961991749703884, + 0.021542323753237724, + 0.04147094115614891, + 0.04451332613825798, + 0.05155060812830925, + 0.15659169852733612, + 0.4448348879814148, + 0.7207449078559875, + 0.8680058717727661, + 0.7269517779350281, + 0.36259666085243225, + 0.10394725203514099, + -0.20449180901050568, + -0.42664405703544617, + -0.7290332317352295, + -0.9376083016395569, + -0.735107958316803, + -0.3541502356529236, + -0.23789332807064056, + -0.10901623964309692, + -0.26809337735176086, + -0.38465574383735657, + -0.44440212845802307, + -0.4070444703102112, + -0.22405119240283966, + -0.14190013706684113, + 0.07151509076356888, + 0.21848519146442413, + 0.41893038153648376, + 0.4783499836921692, + 0.4281534254550934, + 0.28631147742271423, + 0.057699400931596756, + 0.0029010034631937742, + -0.02580493874847889, + -0.02152368798851967, + -0.025850815698504448, + 0.004789783153682947, + 0.021941278129816055, + 0.00574735039845109, + -0.004016151186078787, + -0.014377521350979805, + -0.0828985944390297, + -0.06380187720060349, + -0.048879947513341904, + -0.04580164700746536, + -0.030843649059534073, + 0.024663949385285378, + 0.03409295156598091, + 0.060452476143836975, + 0.037006158381700516, + 0.058853648602962494, + 0.07275765389204025, + 0.02882941998541355, + 0.14549848437309265, + 0.4268765151500702, + 0.7150183320045471, + 0.8942612409591675, + 0.7532845139503479, + 0.3846176564693451, + 0.15604183077812195, + -0.19108416140079498, + -0.42633384466171265, + -0.7508237361907959, + -0.9448286890983582, + -0.719300389289856, + -0.3583783805370331, + -0.2060524821281433, + -0.10382426530122757, + -0.2624296545982361, + -0.4049411416053772, + -0.4338999092578888, + -0.41390693187713623, + -0.22797809541225433, + -0.14593803882598877, + 0.08197329193353653, + 0.2430788278579712, + 0.3906225562095642, + 0.47147202491760254, + 0.42429792881011963, + 0.29326340556144714, + 0.06683206558227539, + 0.004355552606284618, + -0.007973028346896172, + 0.0035172239877283573, + -0.0018502225866541266, + -0.015291260555386543, + 0.0025160792283713818, + 0.0015979957534000278, + 0.011951611377298832, + -0.0004334237310104072, + -0.00172338483389467, + 0.017284434288740158, + -0.00445173867046833, + -0.004828867502510548, + 0.004030159674584866, + 0.03321678191423416, + -0.016998661682009697, + -0.029765218496322632, + -0.07912255078554153, + -0.0494595468044281, + 0.012136446312069893, + 0.029541414231061935, + -0.01129366084933281, + 0.09502168744802475, + 0.21533286571502686, + 0.3453419804573059, + 0.22987395524978638, + 0.04720258712768555, + 0.0032486498821526766, + -0.0042808204889297485, + -0.10162857174873352, + -0.21601493656635284, + -0.3040534257888794, + -0.19600912928581238, + -0.0568307563662529, + -0.0062937624752521515, + -0.021828925237059593, + -0.03831009939312935, + -0.08992031216621399, + -0.08103442937135696, + -0.07600760459899902, + -0.02319694682955742, + -0.008472982794046402, + -0.004151565954089165, + 0.05002164468169212, + 0.0985124409198761, + 0.11273156106472015, + 0.10279814153909683, + 0.032678257673978806, + -0.023295480757951736, + -0.022312145680189133, + 0.032877422869205475, + 0.08301658928394318, + -0.049675002694129944, + -0.05956050381064415, + 0.006878976244479418, + 0.011597251519560814, + -0.03617611899971962, + -0.005020621232688427, + 0.0066283573396503925, + 0.061849869787693024, + 0.0668889507651329, + -0.1120104044675827, + 0.0215831957757473, + -0.008177083916962147, + 0.019240612164139748, + -0.03794482350349426, + -0.21581093966960907, + 0.3248063623905182, + 0.0525924488902092, + -0.13873063027858734, + -0.030904211103916168, + -0.004122832324355841, + 0.2784009277820587, + -0.42068102955818176, + -0.15351417660713196, + 0.4266241192817688, + -0.10780557245016098, + 0.03840374946594238, + -0.15116721391677856, + 0.2292502224445343, + 0.23400554060935974, + -0.5023872256278992, + 0.14868289232254028, + 0.09809935092926025, + 0.03480924293398857, + -0.046804867684841156, + -0.14212554693222046, + 0.3073779344558716, + -0.029529480263590813, + -0.13998086750507355, + -0.02750661037862301, + 0.010526027530431747, + 0.032874979078769684, + -0.07645174115896225, + -0.02746269293129444, + 0.10902399569749832, + -0.00446560001000762, + -0.01339190173894167, + 0.003540819976478815, + -0.04410126060247421, + -0.10884726047515869, + 0.016081949695944786, + 0.15211890637874603, + 0.04027504846453667, + -0.05552368983626366, + 0.04718002676963806, + 0.014503135345876217, + -0.2764658033847809, + -0.16068166494369507, + 0.3356778621673584, + 0.06485499441623688, + -0.07164154946804047, + 0.084479421377182, + 0.2702949047088623, + -0.1339409202337265, + -0.9642015695571899, + 0.47433769702911377, + 0.4715694189071655, + -0.17669782042503357, + -0.04434441775083542, + 0.2641690671443939, + 0.7357130646705627, + -1.2222046852111816, + -0.8205837607383728, + 0.9091072678565979, + 0.14896778762340546, + -0.09332367032766342, + -0.16173647344112396, + 0.8782246708869934, + 0.3819980323314667, + -1.619883418083191, + 0.059255462139844894, + 0.42745286226272583, + -0.03186821565032005, + -0.16420172154903412, + 0.12124066799879074, + 0.8650834560394287, + -0.3728218674659729, + -0.5816569328308105, + 0.10949260741472244, + -0.010671291500329971, + -0.07903271913528442, + -0.09700250625610352, + 0.3192030191421509, + 0.2756008505821228, + -0.2616698145866394, + -0.11051242798566818, + 0.016789941117167473, + -0.0484573096036911, + -0.12333080172538757, + 0.0158428642898798, + 0.11172449588775635, + 0.014953864738345146, + -0.011746960692107677, + 0.05310823395848274, + 0.030244171619415283, + -0.23969320952892303, + -0.1039247065782547, + 0.285805881023407, + -0.04652552306652069, + -0.05380000174045563, + 0.05430186912417412, + 0.25547218322753906, + -0.06164371967315674, + -0.7386756539344788, + 0.4393811821937561, + 0.2623714804649353, + -0.1849273294210434, + -0.049713607877492905, + 0.1656467467546463, + 0.6638666391372681, + -0.899787187576294, + -0.5747878551483154, + 0.7465870976448059, + -0.025567445904016495, + -0.051771312952041626, + -0.19754628837108612, + 0.6828271746635437, + 0.4451557695865631, + -1.2559787034988403, + 0.07448688894510269, + 0.27905938029289246, + 0.003908769693225622, + -0.18454433977603912, + -0.011183545924723148, + 0.7449039816856384, + -0.228777676820755, + -0.47592073678970337, + 0.13784541189670563, + 0.019371675327420235, + -0.06424596160650253, + -0.1660400629043579, + 0.2080633044242859, + 0.2942465841770172, + -0.20263032615184784, + -0.0709841251373291, + -0.0021153483539819717, + -0.028180474415421486, + -0.021557176485657692, + 0.012511649169027805, + 0.06533018499612808, + 0.006560645066201687, + -0.01908997632563114, + -0.020228691399097443, + 0.10450740903615952, + 0.04476405307650566, + -0.20389842987060547, + -0.36356496810913086, + -0.18690945208072662, + 0.06581642478704453, + 0.005246834829449654, + -0.14777734875679016, + 0.04554577171802521, + 0.7314760088920593, + 1.1759854555130005, + 0.7747871279716492, + 0.08771117031574249, + 0.04425497353076935, + 0.14875195920467377, + -0.05036012455821037, + -1.0561891794204712, + -1.7835016250610352, + -1.313464879989624, + -0.4041728973388672, + -0.08825081586837769, + -0.18483860790729523, + -0.09619659930467606, + 0.6506555676460266, + 1.2331949472427368, + 1.057729721069336, + 0.3030258119106293, + 0.053314659744501114, + 0.10696353763341904, + 0.19720971584320068, + -0.19457301497459412, + -0.3546113669872284, + -0.3773464560508728, + 0.007737448439002037, + 0.007112926337867975, + -0.026632368564605713, + -0.07708505541086197, + 0.016982559114694595, + 0.03331448882818222, + 0.03235285356640816, + -0.04479134455323219, + 0.0062864539213478565, + -0.04983896017074585, + -0.014209658838808537, + 0.025105496868491173, + 0.07187403738498688, + -0.019782420247793198, + -0.0387532040476799, + 0.01098113413900137, + 0.10765481740236282, + -0.005502769257873297, + -0.29967597126960754, + -0.5370010733604431, + -0.25729984045028687, + 0.0341138020157814, + -0.01927473582327366, + -0.11736954003572464, + 0.09457080066204071, + 0.8881804943084717, + 1.5049697160720825, + 1.0347492694854736, + 0.22410355508327484, + -0.004720119293779135, + 0.1449226438999176, + -0.11916695535182953, + -1.2009364366531372, + -2.080855369567871, + -1.5549882650375366, + -0.5231477618217468, + -0.005029830615967512, + -0.11258674412965775, + 0.03710457682609558, + 0.9192798137664795, + 1.525830626487732, + 1.3018689155578613, + 0.44408130645751953, + 0.006972550880163908, + 0.07937697321176529, + 0.060622286051511765, + -0.4068094491958618, + -0.5964561104774475, + -0.6058750152587891, + -0.1743212193250656, + -0.0038881103973835707, + -0.04932431876659393, + -0.04989266395568848, + 0.07228495925664902, + 0.10359980911016464, + 0.11054171621799469, + 0.017031395807862282, + -0.012849675491452217, + -0.02224516123533249, + -0.019851619377732277, + 0.04567919671535492, + 0.12134519219398499, + 0.018673665821552277, + -0.03933878242969513, + 0.03506385162472725, + 0.07499910145998001, + -0.004981306381523609, + -0.269795298576355, + -0.4478399455547333, + -0.3141564130783081, + 0.014856644906103611, + -0.01102763693779707, + -0.11778493225574493, + -0.00048367868294008076, + 0.46917271614074707, + 0.8380635976791382, + 0.5829758048057556, + 0.14924737811088562, + 0.00504975114017725, + 0.1242799386382103, + 0.027800291776657104, + -0.5343790054321289, + -0.9185061454772949, + -0.6974499225616455, + -0.1733488291501999, + 0.028415951877832413, + -0.07513032108545303, + 0.010947657749056816, + 0.5501428246498108, + 0.8556726574897766, + 0.6854383945465088, + 0.21023745834827423, + -0.04757346957921982, + 0.028925150632858276, + -0.05005616322159767, + -0.4106282889842987, + -0.5990055203437805, + -0.5274976491928101, + -0.18928098678588867, + 0.007199999876320362, + 0.004744168370962143, + -0.006203897297382355, + 0.16117095947265625, + 0.20310591161251068, + 0.17358633875846863, + 0.057794276624917984, + 0.0018837900133803487, + -0.021730661392211914, + 0.03705505281686783, + 0.048999205231666565, + 0.017187459394335747, + -0.04760497808456421, + -0.06534644961357117, + 0.027641354128718376, + -0.02722003310918808, + -0.09557735174894333, + 0.2721945643424988, + 0.06861108541488647, + -0.17862513661384583, + 0.029542427510023117, + -0.028343068435788155, + -0.24357359111309052, + 0.2928915321826935, + 0.6317090392112732, + -0.5675624012947083, + -0.31298428773880005, + 0.119928739964962, + -0.04503166303038597, + 0.1997436285018921, + 0.9068917632102966, + -0.6105388402938843, + -1.176649808883667, + 0.391012579202652, + 0.21436090767383575, + 0.06404570490121841, + 0.4306352436542511, + -0.18372972309589386, + -1.6093186140060425, + 0.5129231810569763, + 0.8333584666252136, + -0.11607109010219574, + 0.024050598964095116, + -0.027272621169686317, + -0.8072280883789062, + 0.15613007545471191, + 1.0115277767181396, + -0.1886059194803238, + -0.1662863790988922, + -0.07484262436628342, + -0.11359186470508575, + -0.05765556916594505, + 0.48085057735443115, + 0.031143836677074432, + -0.20803743600845337, + 0.005643316078931093, + -0.011422591283917427, + -0.02063453011214733, + 0.010139239020645618, + 0.026931140571832657, + 0.02650240994989872, + 0.014503400772809982, + -0.030498046427965164, + 0.01038119662553072, + -0.041832923889160156, + -0.11747029423713684, + 0.24838468432426453, + 0.08126607537269592, + -0.17684465646743774, + 0.009867151267826557, + -0.04349489137530327, + -0.22892898321151733, + 0.3097872734069824, + 0.6229272484779358, + -0.5710748434066772, + -0.2540203332901001, + 0.15970031917095184, + -0.05765099450945854, + 0.24631772935390472, + 0.9121918678283691, + -0.6539115309715271, + -1.1680796146392822, + 0.43742635846138, + 0.1981748640537262, + 0.060766786336898804, + 0.48115089535713196, + -0.2704729437828064, + -1.668082594871521, + 0.6258481740951538, + 0.8217618465423584, + -0.17844447493553162, + 0.07583325356245041, + -0.031355466693639755, + -0.884739100933075, + 0.21298757195472717, + 1.0279508829116821, + -0.2118954360485077, + -0.16616611182689667, + -0.025157395750284195, + -0.11329160630702972, + -0.08147483319044113, + 0.46636614203453064, + 0.023730026558041573, + -0.21343427896499634, + -0.015201984904706478, + -0.00498165050521493, + 0.022955382242798805, + 0.020228328183293343, + -0.029405873268842697, + -0.032065436244010925, + 0.047389160841703415, + -0.01793060638010502, + 0.01669210195541382, + 0.05227159336209297, + -0.11703876405954361, + 0.006789325270801783, + 0.03741219639778137, + -0.04651298373937607, + -0.012846981175243855, + 0.024231625720858574, + -0.13399703800678253, + -0.024073680862784386, + 0.2970501184463501, + -0.1497301310300827, + -0.04287628084421158, + 0.08405227214097977, + -0.06020639091730118, + -0.01648692972958088, + 0.4150170087814331, + -0.17000712454319, + -0.43461430072784424, + 0.27202337980270386, + 0.006708468310534954, + -0.04474359005689621, + 0.15199843049049377, + -0.03348325565457344, + -0.6591396331787109, + 0.4057810306549072, + 0.25226324796676636, + -0.16070741415023804, + 0.03464199975132942, + 0.023064177483320236, + -0.35642316937446594, + 0.22774185240268707, + 0.37138837575912476, + -0.24171461164951324, + -0.023513946682214737, + 0.028774995356798172, + -0.02702418342232704, + -0.012504744343459606, + 0.17893734574317932, + -0.1554262489080429, + -0.09501983970403671, + 0.06177212670445442, + -0.013536165468394756, + 0.012441401369869709, + 0.006566522642970085, + -0.018207622691988945, + 0.003373368876054883, + -0.034891802817583084, + 0.002223123563453555, + 0.006169564090669155, + 0.022658145055174828, + -0.005327044054865837, + -0.023764559999108315, + -0.004386506043374538, + -0.02777106687426567, + 0.01950058527290821, + 0.004401096608489752, + 0.02882237359881401, + 0.01790205016732216, + -0.007827110588550568, + -0.005222277250140905, + -0.05361752584576607, + 0.008359426632523537, + -0.026494475081562996, + -0.015572195872664452, + -0.04412947595119476, + -0.006163781508803368, + 0.180303692817688, + 0.17117105424404144, + -0.014117442071437836, + 0.014543564058840275, + 0.03875281661748886, + 0.002004631096497178, + 0.11982911080121994, + 0.609316349029541, + 0.5792325735092163, + 0.10267578810453415, + -0.02287464588880539, + -0.011516223661601543, + -0.02587946131825447, + 0.019127164036035538, + 0.2742871046066284, + 0.23896890878677368, + -0.013414637185633183, + 0.012439075857400894, + 0.01148916780948639, + 0.0024075021501630545, + -0.028374193236231804, + -0.02938784286379814, + -0.061723873019218445, + -0.03288640081882477, + 0.010918691754341125, + 0.01171314436942339, + 0.00894222967326641, + -0.0050367508083581924, + 0.00322812981903553, + -0.01958087645471096, + 0.000401448953198269, + 0.00655051926150918, + 0.008647873997688293, + -0.015351405367255211, + -0.022286182269454002, + -0.0018973759142681956, + -0.032965533435344696, + 0.009401706047356129, + 0.01680464670062065, + 0.01722409576177597, + 0.017367251217365265, + -0.0012145076179876924, + 0.015895379707217216, + -0.013976357877254486, + 0.01587546430528164, + -0.019388504326343536, + -0.004597584251314402, + -0.026080038398504257, + 0.020517753437161446, + 0.20680218935012817, + 0.20302064716815948, + 0.03813354671001434, + 0.027738921344280243, + 0.02183712273836136, + 0.023807305842638016, + 0.14632326364517212, + 0.5991678237915039, + 0.608651340007782, + 0.15929070115089417, + -0.02112223394215107, + -0.020013611763715744, + -0.03723381832242012, + 0.032139480113983154, + 0.27032363414764404, + 0.24862462282180786, + 0.02374681644141674, + 0.007894856855273247, + 0.00042308925185352564, + -0.004832752980291843, + -0.024313796311616898, + -0.0018940505106002092, + -0.02681432105600834, + 0.002362651750445366, + 0.013330202549695969, + 0.012553646229207516, + 0.002630018163472414, + 0.002979951212182641, + 0.0015847217291593552, + -0.03376828506588936, + -0.010844729840755463, + -0.002748559694737196, + 0.012938202358782291, + -0.011872833594679832, + -0.0025761008728295565, + 0.003677211469039321, + -0.04305516183376312, + 0.001133457524701953, + 0.0020396243780851364, + 0.01797032356262207, + 0.016580887138843536, + 0.04445189982652664, + 0.013270077295601368, + -0.04839251935482025, + 0.011546633206307888, + -0.015829432755708694, + 0.019473392516374588, + -0.011464826762676239, + 0.018693143501877785, + 0.18201367557048798, + 0.16157257556915283, + 0.02082117274403572, + 0.015915032476186752, + 0.010720869526267052, + -0.0020238866563886404, + 0.09329187124967575, + 0.46998023986816406, + 0.5186727046966553, + 0.09814783185720444, + -0.016547314822673798, + 0.00325066689401865, + -0.028936590999364853, + 0.01002424769103527, + 0.21822214126586914, + 0.22012007236480713, + 0.008229314349591732, + 0.015599996782839298, + 0.014740276150405407, + 0.0019725109450519085, + 0.003613655688241124, + -0.03043546713888645, + -0.06308998167514801, + 0.014664110727608204, + 0.06775129586458206, + -0.12990300357341766, + -0.03638269379734993, + -0.03883139044046402, + 0.05194637551903725, + 0.03896122798323631, + -0.05132362246513367, + -0.07234688848257065, + -0.36106064915657043, + -0.2839237451553345, + -0.11496391147375107, + 0.3026673197746277, + 0.3528609871864319, + 0.21559017896652222, + -0.11970120668411255, + -0.5473688244819641, + -0.5362005233764648, + -0.21015112102031708, + 0.4089161455631256, + 0.6033567786216736, + 0.38614287972450256, + -0.12437233328819275, + -0.6394402384757996, + -0.6945835947990417, + -0.3482857942581177, + 0.5189254283905029, + 0.8457668423652649, + 0.6248002648353577, + -0.12700730562210083, + -0.6978924870491028, + -0.7764106392860413, + -0.4171960651874542, + 0.44747814536094666, + 0.8406224846839905, + 0.6821274161338806, + -0.07793218642473221, + -0.5459966659545898, + -0.6139025092124939, + -0.35998886823654175, + 0.27800890803337097, + 0.6048891544342041, + 0.591307520866394, + -0.04850815609097481, + -0.3863481283187866, + -0.3542836606502533, + -0.2491992861032486, + 0.1616278886795044, + 0.3402666747570038, + 0.4610227644443512, + -0.010262396186590195, + 0.0408165417611599, + 0.006382474210113287, + -0.011430315673351288, + -0.027895113453269005, + -0.009767768904566765, + 0.005882019177079201, + 0.05225436016917229, + 0.0415218211710453, + 0.08244743943214417, + 0.026765575632452965, + -0.05404946208000183, + -0.06101839989423752, + -0.028233220800757408, + 0.03128793090581894, + 0.07133004069328308, + 0.0718698799610138, + 0.042146697640419006, + -0.08380170166492462, + -0.09263177216053009, + -0.07569421827793121, + 0.032425008714199066, + 0.12351400405168533, + 0.09103626012802124, + -0.004768018145114183, + -0.05960838869214058, + -0.11922567337751389, + -0.10132396221160889, + 0.044341862201690674, + 0.100867860019207, + 0.09607693552970886, + -0.00129030947573483, + -0.05481477826833725, + -0.1278291642665863, + -0.12058380991220474, + 0.016678951680660248, + 0.09958931058645248, + 0.08456224203109741, + 0.061599165201187134, + -0.049776893109083176, + -0.11354166269302368, + -0.09844806790351868, + 0.004753128159791231, + 0.07868346571922302, + 0.06464104354381561, + 0.020981626585125923, + -0.010770543478429317, + -0.08838209509849548, + -0.07265795767307281, + -0.058313023298978806, + 0.10897739976644516, + 0.026735201478004456, + 0.03972309082746506, + -0.019998662173748016, + -0.048948734998703, + 0.03377270698547363, + 0.053406376391649246, + 0.27304399013519287, + 0.20850272476673126, + 0.07890326529741287, + -0.22241365909576416, + -0.2816997468471527, + -0.1745096743106842, + 0.08957889676094055, + 0.4962941110134125, + 0.4586986303329468, + 0.20177948474884033, + -0.3625744581222534, + -0.47758376598358154, + -0.32412785291671753, + 0.0669194757938385, + 0.5394997596740723, + 0.601328432559967, + 0.24388420581817627, + -0.4319041073322296, + -0.6893490552902222, + -0.5106037259101868, + 0.10174300521612167, + 0.5457565784454346, + 0.6549625992774963, + 0.38772058486938477, + -0.3778320252895355, + -0.6820934414863586, + -0.551069438457489, + 0.049600999802351, + 0.45137161016464233, + 0.5143972039222717, + 0.3713279068470001, + -0.26546329259872437, + -0.5121409893035889, + -0.47691628336906433, + 0.03843758627772331, + 0.30808231234550476, + 0.3185756504535675, + 0.22629432380199432, + -0.14860986173152924, + -0.2915389835834503, + -0.3552006185054779, + -0.003137432038784027, + -0.01327254343777895, + -0.027139298617839813, + 0.04800891876220703, + 0.05380738899111748, + -0.01380784809589386, + 0.0022881641052663326, + -0.012132279574871063, + 0.06182793900370598, + 0.03762871399521828, + 0.0966145321726799, + 0.08963571488857269, + 0.06551238149404526, + 0.031640589237213135, + -0.010532311163842678, + 0.07195396721363068, + 0.11343465745449066, + 0.11621421575546265, + 0.047318290919065475, + 0.1111951395869255, + 0.044054243713617325, + 0.016777141019701958, + 0.03392713516950607, + 0.06047024950385094, + -0.7924502491950989, + -0.7310910224914551, + 0.031088173389434814, + 0.0906061977148056, + 0.022829236462712288, + 0.04470035433769226, + 0.025999872013926506, + -0.8246837258338928, + -0.723675549030304, + 0.15835590660572052, + 0.07358791679143906, + -0.015819497406482697, + -0.014207872562110424, + 0.08506257086992264, + 0.08868777751922607, + 0.0976945012807846, + 0.11740022897720337, + 0.016287995502352715, + -0.024363648146390915, + 0.04249691963195801, + 0.02909177541732788, + 0.12011238187551498, + 0.10729824751615524, + 0.05927390977740288, + 0.04731644690036774, + 0.008210064843297005, + 0.03859357163310051, + -0.005175672471523285, + 0.01984376832842827, + -0.0011626111809164286, + -0.0010909241391345859, + 0.02311880886554718, + 0.007646523881703615, + 0.04582137614488602, + -0.0027255103923380375, + 0.027656713500618935, + 0.02781369723379612, + 0.015750093385577202, + 0.040563344955444336, + -0.007784596644341946, + 0.006534814368933439, + 0.002403199439868331, + -0.020037032663822174, + -0.011717663146555424, + 0.07826739549636841, + 0.018203573301434517, + 0.021228624507784843, + 0.014112413860857487, + -0.02866269089281559, + -0.9502679109573364, + -0.825043797492981, + 0.05938851460814476, + 0.06553053110837936, + 0.015418429858982563, + 0.0616452619433403, + -0.0094453701749444, + -0.9471839666366577, + -0.7922234535217285, + 0.13069523870944977, + 0.04939320683479309, + 0.007429714780300856, + 0.022599652409553528, + 0.0820123627781868, + 0.06440276652574539, + 0.09897352755069733, + 0.0856291800737381, + 0.006608777679502964, + -0.0005533680086955428, + 0.021656949073076248, + 0.014818831346929073, + 0.03757459297776222, + -0.001428246614523232, + 0.03473127633333206, + 0.03607869893312454, + 0.017313262447714806, + 0.0025767614133656025, + -0.033292777836322784, + 0.027883101254701614, + -0.007534499745815992, + -0.04302362725138664, + -0.01795666106045246, + -0.007667913101613522, + 0.012547189369797707, + -0.021762438118457794, + 0.03789107874035835, + 0.06384614109992981, + 0.0014223429607227445, + -0.01393786258995533, + -0.041693057864904404, + -0.01813604310154915, + 0.065328449010849, + 0.15736474096775055, + 0.1531635969877243, + 0.09920474886894226, + -0.04044449329376221, + 0.010558396577835083, + 0.05559245124459267, + 0.10931257158517838, + -0.5784384608268738, + -0.5109886527061462, + 0.17690584063529968, + 0.07484250515699387, + 0.010378374718129635, + 0.0890144556760788, + 0.13172735273838043, + -0.6058865785598755, + -0.49908995628356934, + 0.1835336685180664, + 0.005293308291584253, + -0.03870566934347153, + -0.025229454040527344, + 0.12571711838245392, + 0.14792272448539734, + 0.14905226230621338, + 0.0700206533074379, + -0.035034529864788055, + 0.013128797523677349, + 0.015581230632960796, + 0.005400130525231361, + 0.07070232182741165, + 0.03829728811979294, + -0.013876918703317642, + -0.019958000630140305, + -0.020086020231246948, + -0.019999003037810326, + -0.015111410059034824, + 0.11963249742984772, + -0.08270428329706192, + -0.0025947154499590397, + -0.010668564587831497, + 0.016670405864715576, + -0.03206938877701759, + -0.053453829139471054, + 0.1236601173877716, + -0.020077411085367203, + 0.00779569149017334, + -0.0318986251950264, + 0.03579804673790932, + -0.060723867267370224, + -0.009301809594035149, + 0.09249342232942581, + -0.13378725945949554, + 0.17496798932552338, + -0.0935625433921814, + 0.06569044291973114, + -0.18187756836414337, + 0.06397300213575363, + 0.3793930113315582, + -0.5664302706718445, + 0.23658618330955505, + -0.03206830099225044, + 0.03155658766627312, + 0.039305318146944046, + -0.6008145213127136, + 1.0417630672454834, + -0.5062726140022278, + -0.04698493704199791, + 0.0979752242565155, + -0.037326715886592865, + 0.26255178451538086, + -0.590207576751709, + 0.4195419251918793, + 0.12212422490119934, + -0.26122942566871643, + 0.06442253291606903, + -0.07682429254055023, + 0.12608948349952698, + -0.13872937858104706, + -0.030260663479566574, + 0.2047160565853119, + -0.13068141043186188, + 0.016608506441116333, + -0.021629147231578827, + 0.04659907519817352, + 0.024417348206043243, + 0.06751634925603867, + -0.1705978959798813, + 0.0655774399638176, + -0.0041802311316132545, + -0.02263445220887661, + -0.014069054275751114, + 0.06242800131440163, + 0.08984102308750153, + -0.19382472336292267, + 0.09380361437797546, + -0.0032764992211014032, + -0.03950225189328194, + -0.08896161615848541, + 0.28387022018432617, + 0.1668996810913086, + -0.5457127094268799, + 0.21796099841594696, + 0.012032964266836643, + 0.030721815302968025, + -0.4431600570678711, + 0.3104412257671356, + 1.0070439577102661, + -1.1077969074249268, + 0.08187273889780045, + 0.1387241780757904, + 0.09014563262462616, + -0.25378379225730896, + -0.9253583550453186, + 1.9745515584945679, + -0.6605072617530823, + -0.4394792318344116, + 0.11501576751470566, + 0.03007262572646141, + 0.2538164258003235, + -1.1462018489837646, + 0.7988958954811096, + 0.46934643387794495, + -0.4244523048400879, + -0.0001816617150325328, + -0.04351970925927162, + 0.20500127971172333, + -0.40710335969924927, + -0.15871365368366241, + 0.4640160799026489, + -0.06024328991770744, + -0.016036653891205788, + -0.012419192120432854, + 0.05552554875612259, + 0.050986770540475845, + -0.0171927809715271, + -0.12105240672826767, + 0.03947274759411812, + 0.009537882171571255, + -0.026668362319469452, + 0.017273351550102234, + 0.10812800377607346, + -0.015008139424026012, + -0.14154496788978577, + 0.08008233457803726, + -0.01306608971208334, + -0.05574854835867882, + -0.06091056764125824, + 0.2888447940349579, + 0.05022002384066582, + -0.4581625759601593, + 0.21146118640899658, + -0.01495362538844347, + 0.02946372702717781, + -0.38554418087005615, + 0.30167311429977417, + 0.7605867981910706, + -0.898481547832489, + 0.11953620612621307, + 0.12686115503311157, + 0.09949854761362076, + -0.14409342408180237, + -0.7404491901397705, + 1.5449001789093018, + -0.5307857394218445, + -0.3347839415073395, + 0.09940771013498306, + 0.009087899699807167, + 0.3081797957420349, + -0.9053899049758911, + 0.5102643370628357, + 0.4646914303302765, + -0.36200836300849915, + -0.043260715901851654, + -0.05309509113430977, + 0.22480911016464233, + -0.2674587666988373, + -0.25316888093948364, + 0.435017466545105, + -0.017485838383436203, + -0.049459364265203476, + 0.012460661120712757, + -0.02262282371520996, + -0.04392899200320244, + 0.013330060057342052, + 0.05963548645377159, + -0.020561739802360535, + -0.013496879488229752, + -0.02310933545231819, + -0.06549905985593796, + 0.12132573872804642, + 0.22165189683437347, + -0.07683887332677841, + -0.12427931278944016, + 0.05543455854058266, + 0.009089780040085316, + 0.19844494760036469, + 0.07650767266750336, + -0.48934996128082275, + -0.35080164670944214, + 0.13422781229019165, + 0.022217294201254845, + -0.006589306052774191, + -0.18357548117637634, + -0.6055922508239746, + 0.09492127597332001, + 0.7073907256126404, + 0.1777055710554123, + -0.05434347689151764, + 0.04566245526075363, + -0.023967979475855827, + 0.4856843054294586, + 0.8131930828094482, + -0.2068077027797699, + -0.3863125145435333, + 0.02887917123734951, + -0.05048410966992378, + 0.051201049238443375, + 0.057671088725328445, + -0.6412642002105713, + -0.39739903807640076, + 0.11036981642246246, + 0.06687764078378677, + -0.018151026219129562, + 0.0022760110441595316, + -0.09328305721282959, + 0.1352599710226059, + 0.19680921733379364, + 0.032235175371170044, + -0.06123670935630798, + -0.013810456730425358, + -0.01821190118789673, + -0.029903864488005638, + 0.027588335797190666, + 0.0762094110250473, + -0.046041399240493774, + 0.017117975279688835, + -0.018925148993730545, + 0.00423092395067215, + 0.2065701186656952, + 0.157025545835495, + -0.26491472125053406, + -0.24569831788539886, + 0.0873267725110054, + 0.004694689530879259, + 0.1838335543870926, + -0.18973900377750397, + -0.9744532108306885, + -0.41959065198898315, + 0.409589946269989, + 0.22223009169101715, + -0.0989728644490242, + -0.40883490443229675, + -0.8418471813201904, + 0.40256521105766296, + 1.4742398262023926, + 0.4913789629936218, + -0.14741277694702148, + -0.0028576564509421587, + 0.0861843004822731, + 1.0056577920913696, + 1.479182481765747, + -0.21940617263317108, + -0.8383130431175232, + -0.30560192465782166, + 0.12028121203184128, + 0.24013034999370575, + 0.11750353127717972, + -1.1071972846984863, + -0.9066778421401978, + -0.055051110684871674, + 0.15361995995044708, + 0.0032418384216725826, + -0.08823435008525848, + -0.3188804090023041, + -0.02160414680838585, + 0.2972750663757324, + 0.17006494104862213, + 0.03401973098516464, + 0.017106015235185623, + 0.010733614675700665, + 0.004688877146691084, + 0.02985573373734951, + 0.046415988355875015, + -0.05177726596593857, + -0.04624386876821518, + 0.026672907173633575, + 0.03479000926017761, + 0.22761401534080505, + 0.12049756944179535, + -0.23494181036949158, + -0.2207801640033722, + 0.06036320701241493, + 0.02112250216305256, + 0.16173022985458374, + -0.14196650683879852, + -0.8236543536186218, + -0.3530665934085846, + 0.3715725541114807, + 0.25781863927841187, + -0.09806561470031738, + -0.341796338558197, + -0.7201419472694397, + 0.2111824005842209, + 1.1648427248001099, + 0.3866075575351715, + -0.1955428272485733, + -0.13164694607257843, + -0.06048528477549553, + 0.7989920973777771, + 1.143347144126892, + -0.19509637355804443, + -0.6719933152198792, + -0.26912447810173035, + 0.16733723878860474, + 0.32526257634162903, + 0.1910397708415985, + -0.8516904711723328, + -0.6005953550338745, + 0.10627525299787521, + 0.16700856387615204, + 0.032433755695819855, + -0.11345972120761871, + -0.270126610994339, + -0.012052524834871292, + 0.25489771366119385, + 0.14647918939590454, + -0.014324051328003407, + -0.011148945428431034, + -0.0011708218371495605, + -0.018903911113739014, + -0.010648071765899658, + -0.017981043085455894, + 0.014055400155484676, + -0.020784996449947357, + -0.030126383528113365, + 0.1150858998298645, + -0.1112036183476448, + -0.023664508014917374, + 0.1651369333267212, + -0.055412910878658295, + -0.007318025920540094, + -0.07404221594333649, + 0.3068569302558899, + -0.6175673007965088, + 0.35226404666900635, + 0.1940349042415619, + -0.22921296954154968, + 0.06411048769950867, + 0.001689439988695085, + 0.23336739838123322, + -0.9470900893211365, + 1.2042961120605469, + -0.44587329030036926, + -0.15847182273864746, + 0.07572423666715622, + 0.11138042062520981, + -0.2075018584728241, + -0.2651064693927765, + 0.8896074295043945, + -0.7130936980247498, + 0.10370831191539764, + 0.07730382680892944, + 0.02368813008069992, + -0.20520009100437164, + 0.13611918687820435, + 0.31062978506088257, + -0.471883624792099, + 0.21489326655864716, + -0.0216743852943182, + -0.04020361602306366, + -0.022920167073607445, + 0.16054102778434753, + -0.002624030224978924, + -0.14670424163341522, + 0.12018264085054398, + -0.043656397610902786, + -0.005084550939500332, + 0.03873870149254799, + -0.07967288792133331, + -0.007439201697707176, + 0.027688704431056976, + 0.08916077762842178, + -0.0036629599053412676, + -0.01389122661203146, + 0.1402083784341812, + -0.2923351228237152, + -0.01932896114885807, + 0.224355086684227, + -0.013193303719162941, + -0.03984276205301285, + -0.04474477842450142, + 0.3302844762802124, + -0.9746807217597961, + 0.5603556036949158, + 0.3556183874607086, + -0.2713812589645386, + 0.01890619471669197, + 0.06983876973390579, + 0.09052442759275436, + -1.3613605499267578, + 1.8220031261444092, + -0.40902698040008545, + -0.31302449107170105, + 0.03893759846687317, + 0.11448371410369873, + -0.4220678210258484, + -0.3677598237991333, + 1.539440631866455, + -0.8297391533851624, + -0.08504960685968399, + 0.0629446730017662, + -0.016804160550236702, + -0.31778836250305176, + 0.2363198846578598, + 0.6452136635780334, + -0.700931191444397, + 0.09927428513765335, + 0.0019635935313999653, + -0.05397690460085869, + -0.014552262611687183, + 0.2352754771709442, + 0.09991656988859177, + -0.28891685605049133, + 0.07818552106618881, + -0.021534763276576996, + -0.009461677633225918, + -0.01069199200719595, + -0.008059840649366379, + -0.0129952197894454, + 0.038492631167173386, + 0.018906958401203156, + -0.025432486087083817, + -0.03420932963490486, + 0.09104404598474503, + -0.10342919826507568, + -0.035048507153987885, + 0.1415904313325882, + -0.052986644208431244, + -0.021596742793917656, + -0.049690280109643936, + 0.3079117238521576, + -0.5487046837806702, + 0.27024003863334656, + 0.15158434212207794, + -0.16488635540008545, + 0.027642132714390755, + 0.004561549983918667, + 0.21555493772029877, + -0.9188903570175171, + 1.0972669124603271, + -0.3528037667274475, + -0.07574182748794556, + 0.021962830796837807, + 0.08826783299446106, + -0.18681983649730682, + -0.2789378762245178, + 0.864517331123352, + -0.5642455816268921, + 0.07469761371612549, + 0.03803368657827377, + 0.014268620871007442, + -0.17712704837322235, + 0.1349189728498459, + 0.3181247115135193, + -0.45067182183265686, + 0.1391848623752594, + 0.009777083061635494, + -0.028080958873033524, + -0.03586730733513832, + 0.14503192901611328, + -0.014655024744570255, + -0.1472700834274292, + 0.07361634075641632, + -0.0029754601418972015, + -0.006887470372021198, + -0.019166842103004456, + 0.0034907464869320393, + -0.015169994905591011, + 0.053831856697797775, + -0.028789488598704338, + -0.02033298648893833, + 0.0018537036376073956, + 0.07567961513996124, + -0.07041627168655396, + -0.047083087265491486, + 0.17573483288288116, + -0.04860217124223709, + 0.013171656988561153, + 0.020158233121037483, + -0.006270059384405613, + -0.28434091806411743, + 0.2760852873325348, + 0.32198208570480347, + -0.43535903096199036, + 0.03188510239124298, + 0.019360313192009926, + -0.20063988864421844, + 0.04450676590204239, + 0.9678076505661011, + -0.683987021446228, + -0.3979112207889557, + 0.2618143558502197, + -0.049711134284734726, + -0.06456997990608215, + 0.6518288850784302, + -0.1357039213180542, + -1.1304017305374146, + 0.4881652295589447, + 0.19583553075790405, + -0.03677722439169884, + 0.21429045498371124, + 0.09559855610132217, + -0.7311355471611023, + 0.10988117009401321, + 0.4949330687522888, + -0.17359353601932526, + 0.03822369873523712, + 0.011371256783604622, + -0.1900172382593155, + -0.04778448864817619, + 0.2897090017795563, + -0.02235160581767559, + -0.05582524091005325, + 0.007624597754329443, + -0.027456223964691162, + -0.029680097475647926, + -0.023810429498553276, + 0.15409281849861145, + 0.013284318149089813, + -0.0788225457072258, + -0.025637971237301826, + 0.01406402699649334, + -0.13676859438419342, + 0.027384959161281586, + 0.30458444356918335, + -0.11150643229484558, + -0.06806201487779617, + 0.009601237252354622, + -0.0866582989692688, + -0.2328706979751587, + 0.5188567638397217, + 0.3787381649017334, + -0.655829906463623, + 0.0072118742391467094, + -0.0031494891736656427, + -0.2424815446138382, + 0.28893929719924927, + 1.2396824359893799, + -1.0406886339187622, + -0.6376030445098877, + 0.4103420078754425, + -0.05929668992757797, + 0.03918358311057091, + 0.9274081587791443, + -0.28890565037727356, + -1.6682262420654297, + 0.66976398229599, + 0.35488471388816833, + 0.027932289987802505, + 0.3169145882129669, + 0.09107685089111328, + -1.2099432945251465, + 0.11623579263687134, + 0.7632684707641602, + -0.16506360471248627, + 0.037474747747182846, + -0.005203985143452883, + -0.35939401388168335, + -0.17138688266277313, + 0.525232195854187, + 0.10247340798377991, + -0.14317406713962555, + 0.007572649512439966, + -0.006046198774129152, + 0.06188087910413742, + -0.050851333886384964, + 0.032844241708517075, + 0.0544477179646492, + -0.07947597652673721, + -0.03073730878531933, + 0.04025515541434288, + -0.010001083835959435, + -0.11831062287092209, + 0.17422229051589966, + -0.05468267202377319, + -0.04996664077043533, + 0.023996006697416306, + 0.02888253889977932, + -0.18709556758403778, + 0.13987921178340912, + 0.32867854833602905, + -0.31714990735054016, + 0.019951285794377327, + 0.027247004210948944, + -0.19416090846061707, + -0.006519266404211521, + 0.7540720105171204, + -0.5474190711975098, + -0.27137213945388794, + 0.20772530138492584, + -0.042619917541742325, + -0.09566087275743484, + 0.548494815826416, + -0.1599852293729782, + -0.9178788661956787, + 0.5456539988517761, + 0.07497559487819672, + 0.003984459210187197, + 0.18640351295471191, + 0.12121234089136124, + -0.7249511480331421, + 0.2559764087200165, + 0.4684237241744995, + -0.19216996431350708, + 0.018075481057167053, + 0.02684594877064228, + -0.221074178814888, + -0.09164194762706757, + 0.3596596121788025, + -0.08310746401548386, + -0.10815230011940002, + -0.015406409278512001, + -0.011985878460109234, + 0.028467312455177307, + -0.0879230722784996, + 0.0347294844686985, + 0.05081191286444664, + 0.00362736196257174, + 0.010529003106057644, + -0.002672453410923481, + 0.025318201631307602, + -0.06232529878616333, + 0.008822780102491379, + 0.06744717806577682, + 0.003999210894107819, + -0.0022885131184011698, + -0.046704765409231186, + 0.13673964142799377, + -0.2590992748737335, + -0.022161437198519707, + 0.258914053440094, + -0.10650330036878586, + 0.023435762152075768, + 0.06992689520120621, + 0.03760937228798866, + -0.5444027185440063, + 0.4131152629852295, + 0.25325170159339905, + -0.2482522875070572, + 0.010479461401700974, + 0.045747850090265274, + -0.1541248857975006, + -0.35291528701782227, + 0.9078133702278137, + -0.34428781270980835, + -0.14787709712982178, + -0.024105649441480637, + -0.007651817053556442, + -0.14991067349910736, + 0.17544956505298615, + 0.3692120611667633, + -0.46861159801483154, + 0.10201738774776459, + 0.003734431229531765, + -0.010433703660964966, + 0.022045455873012543, + 0.0944862961769104, + 0.01679016835987568, + -0.16537833213806152, + 0.07900089025497437, + -0.004211293533444405, + -0.01076442189514637, + 0.09729930013418198, + -0.1490965485572815, + -0.02511671558022499, + 0.0766475573182106, + 0.010980346240103245, + -0.010220799595117569, + -0.0004861881607212126, + 0.09204736351966858, + -0.179045170545578, + -0.025164175778627396, + 0.15608654916286469, + 0.004787537269294262, + -0.0005253870622254908, + 0.034556396305561066, + 0.1509256660938263, + -0.5432079434394836, + -0.03155849874019623, + 0.513609766960144, + -0.14458952844142914, + 0.015178131870925426, + 0.09172039479017258, + -0.12612608075141907, + -0.926306962966919, + 0.8281942009925842, + 0.5954549908638, + -0.492740273475647, + 0.007195526268333197, + -0.018258413299918175, + -0.4074647128582001, + -0.43008187413215637, + 1.7370752096176147, + -0.350849986076355, + -0.5158001780509949, + -0.017458094283938408, + -0.08306471258401871, + -0.2334563285112381, + 0.445117712020874, + 0.7808031439781189, + -0.7913723587989807, + -0.11814796179533005, + -0.00913319457322359, + 0.0223994143307209, + 0.1012248545885086, + 0.25349485874176025, + 0.028286214917898178, + -0.4809858798980713, + 0.05953341722488403, + 0.015634188428521156, + 0.005101620219647884, + 0.10901974141597748, + -0.11964976042509079, + -0.09117673337459564, + 0.0734483003616333, + 0.01821213960647583, + 5.350751234800555e-05, + -0.020279232412576675, + 0.1097220927476883, + -0.1354990452528, + -0.08653146773576736, + 0.11775246262550354, + -0.012575668282806873, + 0.0310806967318058, + 0.010271146893501282, + 0.20337054133415222, + -0.3854014277458191, + -0.09943562000989914, + 0.3921409249305725, + -0.08432158827781677, + 0.010676748119294643, + 0.040244489908218384, + -0.0015478944405913353, + -0.7022866010665894, + 0.49858638644218445, + 0.42338883876800537, + -0.2982582449913025, + -0.005396307446062565, + -0.008777705952525139, + -0.2325415015220642, + -0.4083922803401947, + 1.186205506324768, + -0.26399391889572144, + -0.2621048092842102, + -0.015712907537817955, + -0.04675402492284775, + -0.1797540783882141, + 0.2992522716522217, + 0.4747498333454132, + -0.5266988277435303, + 0.04581758379936218, + -0.04037958011031151, + 0.0071074217557907104, + 0.047499995678663254, + 0.16617828607559204, + -0.03973710536956787, + -0.2953551113605499, + 0.10628587752580643, + -0.00904526561498642, + 0.010427894070744514, + 0.08035022020339966, + 0.03841109946370125, + -0.06335253268480301, + -0.06992083787918091, + 0.015409895218908787, + -0.026900725439190865, + -0.04523912072181702, + 0.08087682723999023, + 0.12542113661766052, + 0.018750213086605072, + -0.23430712521076202, + 0.11755944788455963, + -0.019747508689761162, + -0.03171322122216225, + -0.12132623791694641, + 0.2640603184700012, + 0.38445138931274414, + -0.5724408030509949, + 0.15661633014678955, + 0.01949799247086048, + -0.021771302446722984, + -0.18984957039356232, + -0.23499636352062225, + 1.2112919092178345, + -0.7037869095802307, + -0.14260035753250122, + 0.01848726160824299, + 0.06443414837121964, + -0.11740390956401825, + -0.8794785141944885, + 1.4160369634628296, + 0.016899125650525093, + -0.5444768071174622, + 0.017313210293650627, + 0.0508052259683609, + 0.11102095246315002, + -0.790285587310791, + 0.3501206636428833, + 0.7238660454750061, + -0.49468666315078735, + -0.019021952524781227, + -0.01212992612272501, + 0.15032203495502472, + -0.3573611080646515, + -0.1293754130601883, + 0.45295456051826477, + -0.08407819271087646, + -0.008717959746718407, + 0.022566653788089752, + -0.012640242464840412, + 0.03181227669119835, + 0.0638526976108551, + -0.058120664209127426, + -0.042917650192976, + 0.02129550836980343, + -0.018790805712342262, + -0.00655191857367754, + 0.05951414257287979, + 0.12890471518039703, + -0.1886381357908249, + 0.059096939861774445, + -0.016928592696785927, + 0.02327263168990612, + -0.17282842099666595, + 0.13812857866287231, + 0.38889989256858826, + -0.5282873511314392, + 0.07564643770456314, + -0.006128210574388504, + -0.00876594614237547, + -0.18427829444408417, + -0.26697441935539246, + 1.2529815435409546, + -0.6549165844917297, + -0.2111111879348755, + 0.011410325765609741, + 0.07089994102716446, + -0.12627695500850677, + -0.8245998024940491, + 1.4581915140151978, + -0.01822204887866974, + -0.5626582503318787, + -0.01661459542810917, + 0.03759436681866646, + 0.10841676592826843, + -0.7652962803840637, + 0.4360819458961487, + 0.7012669444084167, + -0.47011038661003113, + 0.01529701892286539, + -0.0033166150096803904, + 0.12170535326004028, + -0.3871544301509857, + -0.05247795954346657, + 0.4504147171974182, + -0.11442532390356064, + -0.00882577896118164, + 0.005190832540392876, + -0.05153197422623634, + 0.0055236960761249065, + 0.09320031106472015, + -0.03762076050043106, + -0.021778371185064316, + 0.00750907463952899, + 0.014965789392590523, + -0.015135630965232849, + -0.037086039781570435, + 0.08020154386758804, + -0.04429963231086731, + 0.0038218852132558823, + -0.01712334342300892, + 0.053772956132888794, + -0.05226677283644676, + -0.024439912289381027, + 0.12774989008903503, + -0.18722355365753174, + 0.0683830976486206, + -0.010828870348632336, + -0.012880662456154823, + 0.02679484151303768, + -0.13696907460689545, + 0.46868517994880676, + -0.322968989610672, + 0.052930932492017746, + 0.009463602676987648, + -0.046861011534929276, + 0.07714711129665375, + -0.35792097449302673, + 0.5517901182174683, + -0.13382655382156372, + -0.12921281158924103, + 0.018562642857432365, + -0.03842621296644211, + 0.10284601897001266, + -0.28243398666381836, + 0.13314206898212433, + 0.20769073069095612, + -0.1551610678434372, + 0.018036767840385437, + -0.03553476929664612, + 0.036686040461063385, + -0.09568552672863007, + 0.008917863480746746, + 0.11340243369340897, + -0.04745811969041824, + 0.005833764094859362, + -0.04174824804067612, + 0.022730106487870216, + 0.0013601485406979918, + -0.07473982870578766, + -0.004801879171282053, + 0.05632775276899338, + -0.04081303998827934, + 0.11509573459625244, + 0.004507652949541807, + -0.24791881442070007, + 0.43171870708465576, + -0.1362573653459549, + -0.10758046060800552, + 0.02746163308620453, + -0.2954745888710022, + 0.30186471343040466, + 0.3135572075843811, + -1.2296111583709717, + 0.8754236102104187, + -0.11699853837490082, + 0.022482017055153847, + 0.24945153295993805, + -0.7858022451400757, + 0.5181443095207214, + 1.4243930578231812, + -1.876152515411377, + 0.4689188003540039, + 0.04258054122328758, + -0.030832920223474503, + 0.9340220093727112, + -1.512351632118225, + -0.3731614947319031, + 2.021338701248169, + -0.7801089286804199, + -0.09288544207811356, + -0.12423597276210785, + -0.36861127614974976, + 1.1679530143737793, + -0.4960964024066925, + -1.0398281812667847, + 0.686152458190918, + 0.02052121050655842, + 0.07246638089418411, + -0.01763315312564373, + -0.37442535161972046, + 0.33217450976371765, + 0.22260302305221558, + -0.2657756209373474, + 0.00016369696822948754, + 0.008136127144098282, + -0.03592197597026825, + 0.022231513634324074, + 0.041430093348026276, + -0.06439317017793655, + 0.03496818616986275, + -0.05143435671925545, + 0.09930871427059174, + 0.017110232263803482, + -0.3834381699562073, + 0.44344815611839294, + -0.00280396337620914, + -0.11487428843975067, + 0.050503507256507874, + -0.22837062180042267, + 0.47540077567100525, + 0.5802375674247742, + -1.7325034141540527, + 0.8587368130683899, + 0.10429240018129349, + -0.02456486038863659, + 0.1340152472257614, + -1.2299835681915283, + 0.7986555099487305, + 2.2204456329345703, + -2.4498374462127686, + 0.33742472529411316, + 0.1001473218202591, + 0.08700849115848541, + 0.9933257102966309, + -2.5278031826019287, + -0.5935835242271423, + 2.710871934890747, + -0.87749183177948, + -0.06125229224562645, + -0.19061818718910217, + -0.04017600044608116, + 1.7519460916519165, + -0.7798219919204712, + -1.28012216091156, + 0.7500321269035339, + 0.02245335467159748, + 0.08263842761516571, + -0.1563340127468109, + -0.3502165377140045, + 0.5060794949531555, + 0.11768018454313278, + -0.2394258826971054, + 0.0027446788735687733, + -0.0012661140644922853, + 0.010839025489985943, + 0.04500429332256317, + -0.04333498701453209, + -0.027386408299207687, + 0.04357098788022995, + -0.04407481476664543, + 0.08443310111761093, + -0.08108946681022644, + -0.20346391201019287, + 0.3825778365135193, + -0.16498182713985443, + -0.04287993535399437, + 0.05340999737381935, + -0.14011172950267792, + 0.29446643590927124, + 0.2738667130470276, + -1.1299961805343628, + 0.7827413082122803, + -0.07552053779363632, + -0.03602323681116104, + 0.16167275607585907, + -0.6924317479133606, + 0.4478289783000946, + 1.2428895235061646, + -1.4833877086639404, + 0.4690392315387726, + -0.00820756796747446, + -0.09873292595148087, + 0.692342221736908, + -1.0981175899505615, + -0.3906446695327759, + 1.438644528388977, + -0.719068169593811, + 0.026173872873187065, + -0.09383898228406906, + -0.3282022774219513, + 1.0363390445709229, + -0.23960772156715393, + -0.7638148069381714, + 0.5488630533218384, + -0.015319733880460262, + 0.11911362409591675, + 0.017409542575478554, + -0.4231888949871063, + 0.23724795877933502, + 0.1191876158118248, + -0.15694500505924225, + -0.03534351661801338, + 0.06342366337776184, + 0.17738288640975952, + 0.012300643138587475, + -0.06408121436834335, + -0.06030220910906792, + 0.0018237337935715914, + 0.07659764587879181, + 0.1820947527885437, + 0.24410061538219452, + -0.06998514384031296, + -0.1491813361644745, + -0.06184092164039612, + 0.04607890918850899, + 0.15362663567066193, + 0.18308304250240326, + 0.08175522834062576, + -0.305602103471756, + -0.2915116548538208, + -0.08144206553697586, + 0.07138665020465851, + -0.03521484509110451, + -0.0914112851023674, + -0.2766699492931366, + -0.6285344362258911, + -0.38168880343437195, + -0.0033710987772792578, + 0.14477019011974335, + -0.03885374590754509, + -0.11367184668779373, + -0.1979650855064392, + -0.3575190007686615, + 0.016150522977113724, + 0.28292712569236755, + 0.2836199402809143, + -0.016672370955348015, + -0.034946177154779434, + -0.014770845882594585, + -0.0004113636096008122, + 0.29938748478889465, + 0.3562523126602173, + 0.13313128054141998, + -0.029499055817723274, + 0.007187174167484045, + 0.0636785551905632, + 0.047712039202451706, + 0.20670579373836517, + 0.10999035090208054, + -0.1150810718536377, + 0.00879934523254633, + -0.009125287644565105, + -0.013732590712606907, + 0.04738131910562515, + 0.0549951009452343, + -0.014094026759266853, + -0.01195482350885868, + -0.017125386744737625, + -0.071754589676857, + -0.023961570113897324, + 0.013098018243908882, + 0.05972208455204964, + -0.032899752259254456, + -0.024354496970772743, + -0.013116234913468361, + -0.05865325778722763, + -0.006360829807817936, + 0.12809234857559204, + 0.14038555324077606, + -0.022946689277887344, + -0.039698828011751175, + 0.05144746974110603, + -0.025034509599208832, + 0.08764739334583282, + 0.24594412744045258, + 0.19307002425193787, + -0.04085381329059601, + -0.020323628559708595, + 0.022060081362724304, + 0.01799374632537365, + 0.09039195626974106, + 0.1681770235300064, + 0.0016234283102676272, + -0.23777234554290771, + -0.11634974926710129, + -0.014439117163419724, + -0.034799374639987946, + 0.0457066111266613, + 0.049919649958610535, + -0.1926913857460022, + -0.2680967450141907, + 0.0018220803467556834, + -0.012749310582876205, + -0.04389086738228798, + 0.0060565415769815445, + -0.012036234140396118, + -0.12737582623958588, + -0.05777670815587044, + 0.09932202100753784, + 0.09969642758369446, + -0.1296343356370926, + -0.2964152693748474, + -0.05487265810370445, + 0.12073978036642075, + 0.06634647399187088, + 0.004042446613311768, + -0.1586746722459793, + -0.6267098784446716, + -0.5184157490730286, + -0.032286129891872406, + 0.28023189306259155, + 0.12663227319717407, + -0.08828771114349365, + -0.2600027620792389, + -0.5287090539932251, + -0.0994620993733406, + 0.7820600271224976, + 0.9638882279396057, + 0.2193463146686554, + -0.13466303050518036, + 0.042050741612911224, + -0.02292742393910885, + 0.7523098587989807, + 1.7435946464538574, + 1.111282229423523, + -0.2104763388633728, + -0.35129284858703613, + 0.08224371820688248, + 0.11167984455823898, + 0.6513852477073669, + 0.9696454405784607, + -0.1501394510269165, + -1.1777327060699463, + -0.7738466262817383, + 0.01114045549184084, + 0.004884988535195589, + 0.2849186658859253, + 0.14232710003852844, + -1.0306764841079712, + -1.2078118324279785, + -0.14658716320991516, + 0.036605384200811386, + 0.0001495486794738099, + 0.12111346423625946, + -0.24653346836566925, + -0.7028710246086121, + -0.18977169692516327, + 0.5171932578086853, + -0.02514370158314705, + 0.0885375589132309, + -0.1023016944527626, + 0.023200739175081253, + 0.11839435249567032, + -0.09749021381139755, + 0.008283962495625019, + 0.0106261121109128, + -0.031724803149700165, + -0.1594654619693756, + 0.433218389749527, + -0.33944255113601685, + 0.14406877756118774, + -0.0339396670460701, + 0.09370072185993195, + -0.35916459560394287, + 0.7577320337295532, + -0.5531823635101318, + -0.016844574362039566, + 0.2994873523712158, + -0.21487002074718475, + -0.16125759482383728, + 0.35567227005958557, + 0.09099612385034561, + -1.3889282941818237, + 1.9466298818588257, + -1.2556309700012207, + 0.4389301836490631, + -0.010665428824722767, + 0.4707520306110382, + -1.4310415983200073, + 2.0986156463623047, + -1.5515614748001099, + 0.3905705511569977, + 0.01881679706275463, + 0.057307951152324677, + -0.29734691977500916, + 0.369127094745636, + -0.05115725100040436, + -0.44008156657218933, + 0.48642784357070923, + -0.13904061913490295, + -0.004375698510557413, + -0.06351548433303833, + 0.256020188331604, + -0.34121274948120117, + 0.22490821778774261, + 0.004067304544150829, + -0.059063635766506195, + -0.010710661299526691, + 0.03514768183231354, + -0.08577805012464523, + 0.05103181675076485, + 0.04276616871356964, + -0.10832246392965317, + 0.03325289487838745, + 0.06318283081054688, + -0.11063538491725922, + -0.062119144946336746, + 0.40978243947029114, + -0.5597845315933228, + 0.34106317162513733, + -0.030269838869571686, + 0.057014383375644684, + -0.44329890608787537, + 1.0965592861175537, + -1.0767146348953247, + 0.13287265598773956, + 0.517289400100708, + -0.310720294713974, + -0.15501761436462402, + 0.5854693055152893, + -0.12469431757926941, + -1.7694847583770752, + 2.6433238983154297, + -1.596714735031128, + 0.3888415992259979, + -0.02415616251528263, + 0.42178481817245483, + -1.8008503913879395, + 2.8845136165618896, + -1.7628657817840576, + 0.1951047033071518, + 0.11415407806634903, + 0.07305648922920227, + -0.34212157130241394, + 0.46562451124191284, + 0.03175807744264603, + -0.7942091226577759, + 0.6133171319961548, + -0.14596694707870483, + 0.010496735572814941, + -0.03459644690155983, + 0.2948842942714691, + -0.47654271125793457, + 0.2612597346305847, + 0.016025209799408913, + -0.05287598818540573, + -0.01606004498898983, + 0.022197037935256958, + 0.028397703543305397, + -0.0390767939388752, + 0.0037972000427544117, + -0.07010228931903839, + 0.10934390872716904, + 0.017220165580511093, + 0.02215729095041752, + -0.14772991836071014, + 0.2353552132844925, + -0.3846408724784851, + 0.23990634083747864, + -0.02300707995891571, + 0.12085225433111191, + -0.3576957881450653, + 0.6410096883773804, + -0.532350480556488, + -0.002389132045209408, + 0.41821879148483276, + -0.24739143252372742, + -0.10216745734214783, + 0.16793736815452576, + 0.16367803514003754, + -1.1304419040679932, + 1.676539421081543, + -1.064436435699463, + 0.26995453238487244, + -0.07634275406599045, + 0.3324422240257263, + -1.11312997341156, + 1.8095507621765137, + -1.2477567195892334, + 0.3605581820011139, + -0.06627745926380157, + 0.008511146530508995, + -0.19528241455554962, + 0.4320055842399597, + -0.22881783545017242, + -0.18463851511478424, + 0.3064245581626892, + -0.14437103271484375, + 0.02049900032579899, + 0.018321938812732697, + 0.14011529088020325, + -0.26683253049850464, + 0.2172057181596756, + -0.12119362503290176, + 0.025965997949242592, + -0.03424325957894325, + 0.0433838777244091, + 0.1072857677936554, + 0.1997794657945633, + 0.0648089200258255, + -0.06444115936756134, + -0.13146057724952698, + 0.02106364443898201, + -0.22582228481769562, + -0.007233713287860155, + 0.18876874446868896, + -0.5612399578094482, + 0.2632557451725006, + 0.44088244438171387, + 0.11389002948999405, + -0.2791701555252075, + -0.18004432320594788, + 0.8571203947067261, + -1.9517340660095215, + -1.4906251430511475, + 0.3436146676540375, + 0.31222787499427795, + -0.20083315670490265, + -0.217665895819664, + 3.801243782043457, + 1.2014728784561157, + -0.9149202704429626, + 0.6968244910240173, + 0.12756747007369995, + -0.06783506274223328, + -2.086660385131836, + 0.5455523133277893, + 0.49095916748046875, + -0.5991013050079346, + 0.7938552498817444, + -0.1335069239139557, + 0.4730406701564789, + -1.00951087474823, + -0.537578821182251, + -0.49764835834503174, + -1.2683815956115723, + -0.045739322900772095, + -0.16049732267856598, + 0.30239275097846985, + 0.035600025206804276, + 0.6344828605651855, + 0.8256548643112183, + -0.12940075993537903, + 0.09257010370492935, + -0.11000311374664307, + 0.003206665627658367, + -0.008585316129028797, + -0.14573170244693756, + 0.172541081905365, + 0.2107972949743271, + -0.05270108953118324, + -0.08480435609817505, + 0.1914149820804596, + 0.21630872786045074, + -0.23309426009655, + -0.29484814405441284, + -0.1899339109659195, + 0.02601807750761509, + -0.05416746065020561, + 0.20924429595470428, + 0.15566189587116241, + -0.1556546688079834, + -0.23387494683265686, + -0.5112816691398621, + 0.24130745232105255, + -0.049835484474897385, + -0.2685615122318268, + -0.024764614179730415, + 0.5458847880363464, + 0.9501044750213623, + 0.1328524947166443, + 0.21218529343605042, + 0.2524968683719635, + -0.5205130577087402, + -0.3361912667751312, + 1.1678112745285034, + -0.004513490945100784, + -0.9149109125137329, + 0.2125048041343689, + 0.22423015534877777, + -0.08384363353252411, + -0.2866036593914032, + -0.20210212469100952, + -1.2377471923828125, + -0.7704879641532898, + 0.365038126707077, + -0.08308980613946915, + -0.08326874673366547, + 0.456358402967453, + 0.35142943263053894, + 0.19268833100795746, + 0.3706081509590149, + -0.04951317980885506, + 0.10151109844446182, + 0.005193099845200777, + -0.1124582439661026, + -0.08353164792060852, + -0.18709596991539001, + -0.18975794315338135, + 0.17628741264343262, + 0.05536900460720062, + 0.008301885798573494, + -0.1890449970960617, + 0.056875281035900116, + 0.7981322407722473, + -0.05872391164302826, + -0.4860122501850128, + -0.08073797076940536, + 0.13145819306373596, + -0.03608228266239166, + -0.6600452661514282, + 2.243560314178467, + 1.9288626909255981, + -0.5698518753051758, + -0.2486664056777954, + 0.42693793773651123, + 0.2667267322540283, + -4.395429611206055, + -2.15342378616333, + 0.819127082824707, + -0.9362612962722778, + -0.3760467767715454, + 0.5671858787536621, + 2.468177080154419, + -1.6694080829620361, + -0.49952322244644165, + 1.502772569656372, + -1.0188850164413452, + -0.10419629514217377, + -0.36795151233673096, + 1.2645196914672852, + 0.7223924994468689, + 1.751431941986084, + 2.018704891204834, + -0.3197852671146393, + 0.22054125368595123, + -0.19326329231262207, + -0.5307535529136658, + -0.9362435936927795, + -1.0772119760513306, + -0.19870880246162415, + -0.0650869607925415, + -0.0796947032213211, + 0.15733301639556885, + 0.08798394352197647, + 0.0010860684560611844, + 0.05327683687210083, + 0.1107875183224678, + 0.13224183022975922, + 0.08979664742946625, + 0.004348093178123236, + -0.07060158997774124, + -0.19925491511821747, + -0.15811985731124878, + -0.08220887929201126, + -0.022623460739850998, + 0.08509720861911774, + 0.00792989507317543, + -0.14345014095306396, + -0.2720486521720886, + -0.18885627388954163, + -0.11063539236783981, + -0.0355350486934185, + 0.048891279846429825, + -0.12828074395656586, + -0.2712610363960266, + -0.20134924352169037, + -0.1863398402929306, + -0.19976121187210083, + -0.09535074234008789, + 0.009852319024503231, + -0.2776590585708618, + -0.3087778687477112, + -0.21431012451648712, + -0.19772370159626007, + -0.23412325978279114, + -0.11640459299087524, + 0.09514907747507095, + -0.17561811208724976, + -0.29451555013656616, + -0.2381855845451355, + -0.18296842277050018, + -0.18682444095611572, + -0.023345205932855606, + 0.1438502073287964, + 0.02504260651767254, + -0.1554802507162094, + -0.1477985382080078, + -0.07874225080013275, + -0.002977968193590641, + 0.1048416793346405, + -0.1779504120349884, + 0.13204343616962433, + 0.14215172827243805, + 0.049610622227191925, + 0.0888131782412529, + 0.07250366359949112, + 0.0696505531668663, + 0.009899160824716091, + 0.032067786902189255, + 0.08401404321193695, + -0.03567894548177719, + -0.004740188363939524, + -0.0021664693485945463, + -0.011156522668898106, + 0.0821070745587349, + 0.10295391082763672, + -0.0017653254326432943, + -0.16915833950042725, + -0.062223054468631744, + 0.004783258773386478, + 0.038355808705091476, + 0.10124270617961884, + -0.003437258303165436, + -0.18881437182426453, + -0.15905225276947021, + -0.12576808035373688, + -0.11059725284576416, + 0.021587060764431953, + 0.07237453758716583, + -0.1706620156764984, + -0.27434206008911133, + -0.23003827035427094, + -0.20530915260314941, + -0.20856624841690063, + -0.021966496482491493, + 0.13395215570926666, + -0.03810539469122887, + -0.2409798800945282, + -0.2515420913696289, + -0.1872486174106598, + -0.15951117873191833, + 0.04223426431417465, + 0.09909931570291519, + 0.12328703701496124, + -0.057749148458242416, + -0.1300545036792755, + -0.046062104403972626, + 0.019744107499718666, + 0.09484386444091797, + -0.2709728479385376, + 0.03540695831179619, + 0.1206774190068245, + 0.057636432349681854, + 0.10385740548372269, + 0.032486993819475174, + -0.020434774458408356, + -0.10122086852788925, + -0.0023329253308475018, + 0.16941140592098236, + 0.098082534968853, + 0.1250472217798233, + 0.06134447827935219, + -0.025240115821361542, + 0.004181401338428259, + 0.14425808191299438, + 0.17515034973621368, + 0.04739757999777794, + 0.1618604063987732, + 0.1751406490802765, + 0.09162088483572006, + 0.09512057155370712, + 0.13736343383789062, + 0.028775952756404877, + 0.042535409331321716, + 0.08839954435825348, + 0.09229374676942825, + 0.1658262014389038, + 0.09852072596549988, + 0.002680110279470682, + -0.05479496717453003, + -0.03634755313396454, + -0.002902726177126169, + -0.023990361019968987, + 0.1277875006198883, + 0.12727677822113037, + 0.1002269834280014, + -0.040967896580696106, + -0.07101184874773026, + -0.007902896963059902, + 0.019561029970645905, + 0.145268052816391, + 0.017638152465224266, + 0.19240263104438782, + 0.12857146561145782, + 0.05043037235736847, + 0.11596394330263138, + 0.12513381242752075, + 0.12088746577501297, + 0.04333524778485298, + 0.05500142276287079, + 0.05169082432985306, + -0.09941842406988144, + -0.005959822330623865, + -0.032586321234703064, + -0.03065132349729538, + -0.04826900362968445, + 0.14192889630794525, + 0.2543988823890686, + 0.09563885629177094, + -0.28965362906455994, + -0.1341734230518341, + 0.033991701900959015, + -0.22402706742286682, + -0.3190857768058777, + 0.011840387247502804, + 0.9620282053947449, + 1.0609054565429688, + -0.13429726660251617, + -0.20191268622875214, + 0.05324135720729828, + -0.16234318912029266, + -0.9101927280426025, + -1.7916113138198853, + 0.3981992304325104, + 1.3173034191131592, + 0.53525310754776, + 0.18472574651241302, + 0.3719426691532135, + 0.7792536020278931, + -0.027768991887569427, + -2.245561122894287, + -1.2211185693740845, + 0.22817185521125793, + -0.0023349972907453775, + -0.12598364055156708, + 0.06836964190006256, + 0.9917387366294861, + 1.1885775327682495, + -0.2851368486881256, + -0.7428704500198364, + -0.04798422381281853, + -0.00811613816767931, + -0.19619861245155334, + -0.28184008598327637, + 0.0828644260764122, + 0.44643187522888184, + 0.1461745798587799, + -0.005575121380388737, + -0.06604957580566406, + 0.011459077708423138, + 0.03927984461188316, + 0.0634538009762764, + -0.005732079967856407, + -0.01014732290059328, + 0.07607843726873398, + 0.06948187947273254, + -0.010600326582789421, + -0.056259915232658386, + -0.24602480232715607, + -0.01649448834359646, + 0.11143466085195541, + -0.0027401424013078213, + -0.012853104621171951, + 0.08452893793582916, + 0.639316201210022, + 0.5167437195777893, + -0.2775256335735321, + -0.22241903841495514, + -0.07067711651325226, + -0.06368192285299301, + -0.4687917232513428, + -1.1776493787765503, + 0.36015447974205017, + 0.9171182513237, + 0.1905054748058319, + -0.010661551728844643, + 0.10800722986459732, + 0.5352235436439514, + 0.18558207154273987, + -1.5184046030044556, + -0.8130561709403992, + 0.15417319536209106, + 0.0713079422712326, + -0.07369451224803925, + -0.09037846326828003, + 0.6168488264083862, + 0.9663773775100708, + -0.007113471627235413, + -0.33585548400878906, + -0.02738586813211441, + 0.061310965567827225, + -0.0955657884478569, + -0.23896107077598572, + -0.1107473075389862, + 0.1830059289932251, + 0.10748914629220963, + -0.040772341191768646, + -0.05803938955068588, + -0.0004895658930763602, + 0.07664632797241211, + 0.039049405604600906, + -0.002806248841807246, + -0.02642429992556572, + 0.05169009417295456, + -0.036710865795612335, + -0.1002974808216095, + -0.12001149356365204, + -0.08043934404850006, + 0.11466419696807861, + 0.12322796136140823, + 0.07564827799797058, + 0.10148002207279205, + 0.04720174893736839, + 0.14046646654605865, + -0.0819464847445488, + -0.30803975462913513, + -0.0838734582066536, + -0.0801682323217392, + 0.05861072987318039, + 0.04970559477806091, + -0.20592759549617767, + 0.2673366665840149, + 0.2431953400373459, + -0.10027645528316498, + -0.07884806394577026, + -0.09939537942409515, + 0.1181628480553627, + 0.25269386172294617, + -0.3439132571220398, + -0.11160463094711304, + 0.08640077710151672, + 0.07200870662927628, + -0.03449570760130882, + -0.17610406875610352, + -0.021308166906237602, + 0.30556705594062805, + 0.05186203494668007, + -0.004691269714385271, + -0.005278654862195253, + 0.06289899349212646, + 0.052224051207304, + -0.05927770212292671, + -0.1586783081293106, + -0.022610770538449287, + 0.03463536128401756, + 0.004338411148637533, + 0.01452699676156044, + -0.008622901514172554, + 0.010536444373428822, + -0.038111478090286255, + 0.013373414985835552, + 0.007125865668058395, + -0.003420598339289427, + 0.03533756732940674, + 0.0320388600230217, + 0.045789655297994614, + -0.08139114826917648, + -0.03447948023676872, + -0.01453007198870182, + -0.004573625046759844, + 0.10279268026351929, + 0.10881853848695755, + 0.07537791877985, + -0.10887791216373444, + -0.0980544164776802, + -0.06889445334672928, + 0.006558350287377834, + 0.197514146566391, + 0.17890937626361847, + 0.07630149275064468, + -0.16081148386001587, + -0.16685302555561066, + -0.11421715468168259, + -0.013679573312401772, + 0.22477784752845764, + 0.20761631429195404, + 0.07321957498788834, + -0.17697854340076447, + -0.17810045182704926, + -0.1579347848892212, + -0.02679254300892353, + 0.1408146619796753, + 0.15144851803779602, + 0.08801613748073578, + -0.13237154483795166, + -0.13181765377521515, + -0.1279487907886505, + -0.01779216341674328, + 0.08145096898078918, + 0.05625852569937706, + 0.07724357396364212, + -0.04653938114643097, + -0.07479449361562729, + -0.06189379468560219, + -0.04310920089483261, + 0.02028634026646614, + -0.006228619255125523, + 0.03549303859472275, + -0.043929651379585266, + 0.007818001322448254, + 0.00874761026352644, + -0.017027731984853745, + 0.11014463752508163, + 0.0841977447271347, + 0.05960552394390106, + -0.12814101576805115, + -0.0544624924659729, + -0.045333195477724075, + 0.02336869016289711, + 0.22365787625312805, + 0.18523427844047546, + 0.09366372227668762, + -0.20144090056419373, + -0.16367222368717194, + -0.13003699481487274, + 0.0590205080807209, + 0.3301562964916229, + 0.26524844765663147, + 0.09425198286771774, + -0.26156124472618103, + -0.28513699769973755, + -0.21749621629714966, + 0.04356053099036217, + 0.35879984498023987, + 0.29898661375045776, + 0.0977487862110138, + -0.28175386786460876, + -0.2964495122432709, + -0.249031201004982, + 0.028877725824713707, + 0.26395633816719055, + 0.23059280216693878, + 0.09593978524208069, + -0.22489066421985626, + -0.2248908430337906, + -0.19214706122875214, + 0.007535146549344063, + 0.15299226343631744, + 0.09148521721363068, + 0.06946425884962082, + -0.1445557326078415, + -0.11587042361497879, + -0.0978587418794632, + -0.00984917301684618, + -0.012626220472157001, + -0.02837960794568062, + 0.02399199828505516, + -0.005340439733117819, + 0.023224178701639175, + 0.011642432771623135, + 0.003958537708967924, + 0.042965203523635864, + 0.01099414099007845, + 0.024063799530267715, + -0.0702008455991745, + 0.007805663626641035, + 0.0050195748917758465, + 0.017281856387853622, + 0.10123670846223831, + 0.06401767581701279, + 0.02626805007457733, + -0.1073761060833931, + -0.03802435100078583, + -0.014407800510525703, + -0.0006281707319431007, + 0.15516239404678345, + 0.12629136443138123, + 0.033691491931676865, + -0.17609107494354248, + -0.15251316130161285, + -0.07914211601018906, + -0.015578335151076317, + 0.18422608077526093, + 0.1740245372056961, + 0.06139932945370674, + -0.17213505506515503, + -0.1602732092142105, + -0.08922445774078369, + -0.012822975404560566, + 0.13543544709682465, + 0.12543149292469025, + 0.07651004195213318, + -0.13805902004241943, + -0.09661149233579636, + -0.052669934928417206, + -0.03268992528319359, + 0.0391642227768898, + 0.01116940937936306, + 0.04585625231266022, + -0.06474924832582474, + -0.023607701063156128, + -0.007017284631729126, + -0.026150476187467575, + 0.05729387328028679, + -0.10095079243183136, + 0.16617903113365173, + -0.13664309680461884, + 0.026482274755835533, + 0.008411461487412453, + -0.03410203382372856, + 0.022963764145970345, + 0.008903563022613525, + 0.11244194954633713, + -0.20863348245620728, + 0.11064451932907104, + -0.024916114285588264, + 0.009591493755578995, + -0.26092270016670227, + 0.5717483758926392, + -0.38539814949035645, + 0.035056713968515396, + 0.08623965084552765, + -0.016184961423277855, + 0.11129201203584671, + -0.6138678789138794, + 1.3646206855773926, + -1.4969615936279297, + 0.8465064764022827, + -0.2794847786426544, + 0.05826558917760849, + 0.07709132134914398, + -0.5444677472114563, + 1.3013663291931152, + -1.5686073303222656, + 0.9930508732795715, + -0.39188963174819946, + 0.08085884898900986, + -0.05875617265701294, + 0.03498996049165726, + 0.23967482149600983, + -0.3468690514564514, + 0.19146253168582916, + 0.019604403525590897, + -0.027150027453899384, + -0.024670494720339775, + 0.09944183379411697, + -0.11718503385782242, + 0.09772855788469315, + -0.11857263743877411, + 0.09660946577787399, + -0.03638811036944389, + -0.0295167975127697, + 0.1032838523387909, + -0.12557579576969147, + 0.11812210828065872, + -0.08446288853883743, + 0.027706580236554146, + 0.010997293516993523, + -0.06348618865013123, + 0.09578556567430496, + -0.0165568757802248, + -0.014778072014451027, + -0.07772849500179291, + 0.11245536059141159, + -0.043248821049928665, + 0.013345679268240929, + -0.22149333357810974, + 0.6456363797187805, + -0.7280437350273132, + 0.3046833574771881, + 0.06304280459880829, + -0.07310052216053009, + 0.08824795484542847, + -0.65179842710495, + 1.6453673839569092, + -2.046448230743408, + 1.3267604112625122, + -0.42399832606315613, + 0.0010522910160943866, + 0.07953720539808273, + -0.5960973501205444, + 1.5601089000701904, + -2.084894895553589, + 1.4612183570861816, + -0.5491638779640198, + 0.13709494471549988, + -0.09170618653297424, + 0.07287970930337906, + 0.24422486126422882, + -0.4581631124019623, + 0.29479551315307617, + -0.07515113800764084, + -0.012292998842895031, + -0.04451148584485054, + 0.14961428940296173, + -0.15577177703380585, + 0.06323063373565674, + -0.07806269824504852, + 0.07061618566513062, + -0.026793144643306732, + -0.051938362419605255, + 0.13946141302585602, + -0.14129231870174408, + 0.11092118173837662, + -0.08889970183372498, + 0.034787945449352264, + -0.008983314968645573, + -0.04930088296532631, + 0.09856640547513962, + -0.09350966662168503, + 0.07015673816204071, + -0.06468848884105682, + 0.08028972148895264, + -0.02378295361995697, + 0.004251216538250446, + -0.11239825189113617, + 0.2660067081451416, + -0.367576539516449, + 0.2212517410516739, + -0.035011082887649536, + -0.037866897881031036, + 0.11835235357284546, + -0.4868132174015045, + 0.9402765035629272, + -1.0933791399002075, + 0.9518744349479675, + -0.5096855759620667, + 0.12277142703533173, + 0.12916085124015808, + -0.4648635983467102, + 0.8895858526229858, + -1.0776352882385254, + 1.023865818977356, + -0.5914785861968994, + 0.1682877242565155, + -0.05646277964115143, + 0.04132156819105148, + -0.01790236309170723, + -0.059831030666828156, + 0.10092897713184357, + -0.1268356889486313, + 0.013669619336724281, + -0.02746082842350006, + 0.11544085294008255, + -0.2124193012714386, + 0.2733248472213745, + -0.1360178142786026, + 0.025302443653345108, + 0.01249375008046627, + -0.015119954012334347, + 0.017966970801353455, + 0.00269943755120039, + 0.014392177574336529, + 0.007648292928934097, + 0.011665135622024536, + -0.006192799191921949, + 0.004215092398226261, + 0.017718149349093437, + 0.046436555683612823, + 0.044417623430490494, + 0.01518242433667183, + -0.0020157198887318373, + -0.01828707568347454, + -0.029163505882024765, + -0.03131464868783951, + -0.004393945913761854, + 0.048599082976579666, + 0.015757638961076736, + -0.015650734305381775, + -0.002684049541130662, + -0.0697445422410965, + -0.25050923228263855, + -0.4758685231208801, + -0.5382962822914124, + -0.38907238841056824, + -0.12599025666713715, + -0.00266047241166234, + 0.0758173018693924, + 0.26593172550201416, + 0.4203726053237915, + 0.4958920478820801, + 0.3697706162929535, + 0.12434400618076324, + 0.026325728744268417, + 0.022295912727713585, + 0.08135133236646652, + 0.2627769708633423, + 0.26325660943984985, + 0.12326934933662415, + 0.058665141463279724, + 0.04346219077706337, + -0.0013142779935151339, + -0.10037153959274292, + -0.27075886726379395, + -0.28071707487106323, + -0.17300420999526978, + -0.06914675980806351, + 0.004067219793796539, + -0.020674005150794983, + 0.02103183977305889, + 0.0033879741095006466, + 0.013523808680474758, + -0.007318845018744469, + -0.009975744411349297, + -0.02981705591082573, + 0.023193644359707832, + 0.09624253213405609, + 0.1077117845416069, + 0.11186518520116806, + 0.07592211663722992, + 0.04614634811878204, + 0.015908582136034966, + -0.05212458223104477, + -0.1262977123260498, + -0.10974782705307007, + -0.07645918428897858, + -0.06987964361906052, + -0.08783216774463654, + -0.046172842383384705, + -0.22593465447425842, + -0.5281140804290771, + -0.8424770832061768, + -0.9608982801437378, + -0.7363743185997009, + -0.3312055170536041, + -0.10426472127437592, + 0.24067367613315582, + 0.5504152178764343, + 0.81276935338974, + 0.9592635035514832, + 0.7479950785636902, + 0.32608768343925476, + 0.14525265991687775, + 0.15008939802646637, + 0.32246851921081543, + 0.5287250876426697, + 0.5817036032676697, + 0.37340155243873596, + 0.20366452634334564, + 0.1546182781457901, + -0.11224830150604248, + -0.29856279492378235, + -0.5281672477722168, + -0.5890122056007385, + -0.4024880528450012, + -0.23706914484500885, + -0.0641399398446083, + -0.0025121152866631746, + 0.0051757702603936195, + -0.014290476217865944, + 0.0043721878901124, + -0.004783981014043093, + 0.021787043660879135, + -0.004969750996679068, + -0.022116241976618767, + 0.05208030343055725, + 0.07022145390510559, + 0.03730607405304909, + 0.03242917358875275, + 0.04344351217150688, + -0.01189794484525919, + -0.0418211966753006, + -0.059125497937202454, + -0.014576594345271587, + 0.01294493954628706, + -0.011262460611760616, + -0.059920165687799454, + -0.04733816161751747, + -0.12665517628192902, + -0.29677024483680725, + -0.5247481465339661, + -0.6474934816360474, + -0.4751538038253784, + -0.1937171369791031, + -0.05117221921682358, + 0.14646948873996735, + 0.32891425490379333, + 0.5415402054786682, + 0.6071264147758484, + 0.4653589427471161, + 0.18045872449874878, + 0.09937354922294617, + 0.1264665126800537, + 0.18507222831249237, + 0.31783968210220337, + 0.3545042872428894, + 0.22468777000904083, + 0.09973976761102676, + 0.1227618008852005, + -0.07824759930372238, + -0.20465101301670074, + -0.36476215720176697, + -0.38243186473846436, + -0.2540777623653412, + -0.13525226712226868, + -0.03621843457221985, + -0.012233156710863113, + -0.01481863297522068, + -0.04313792288303375, + 0.002874002791941166, + -0.028444716706871986, + -0.04687628522515297, + -0.026806645095348358, + -0.0228339321911335, + -0.015892738476395607, + -0.015550780110061169, + 0.07011140882968903, + 0.0017389585264027119, + -0.05721491947770119, + -0.017484690994024277, + -0.03954736143350601, + -0.006339249666780233, + 0.08166316151618958, + 0.37439921498298645, + 0.2830294966697693, + 0.00668215099722147, + -0.038873329758644104, + -0.012295035645365715, + 0.04932165890932083, + 0.31826695799827576, + 0.8449289202690125, + 0.7123299241065979, + 0.2574000954627991, + 0.04747961834073067, + -0.04416817054152489, + -0.005029442720115185, + 0.2027042657136917, + 0.6639980673789978, + 0.6243636012077332, + 0.21359916031360626, + 0.027929672971367836, + -0.05395142361521721, + -0.04981911554932594, + -0.006375179626047611, + 0.23660773038864136, + 0.2155737280845642, + 0.020577391609549522, + -0.032118700444698334, + -0.02332071214914322, + -0.009217707440257072, + -0.038096409291028976, + 0.05811609327793121, + 0.03776064142584801, + -0.03570764884352684, + -0.042420413345098495, + 0.017812976613640785, + 0.019242385402321815, + 0.030057156458497047, + 0.003040613606572151, + 0.02378096617758274, + 0.04043402150273323, + 0.0243258997797966, + 0.014026327058672905, + 0.005650558043271303, + -0.002831381279975176, + -0.0645776093006134, + -0.03761167451739311, + 0.043774381279945374, + 0.010685136541724205, + 0.031011218205094337, + -0.0025828774087131023, + -0.11959855258464813, + -0.3524792194366455, + -0.30037227272987366, + -0.053334690630435944, + 0.009859252721071243, + 0.0010005333460867405, + -0.04819931834936142, + -0.3154168128967285, + -0.7240553498268127, + -0.6380828022956848, + -0.25695785880088806, + -0.06639125943183899, + 0.03295261785387993, + -0.012727363035082817, + -0.24232468008995056, + -0.6055921912193298, + -0.5679556727409363, + -0.20067356526851654, + -0.03628019988536835, + 0.04774145409464836, + 0.029560575261712074, + -0.038632482290267944, + -0.24032950401306152, + -0.2095729559659958, + -0.006905315909534693, + 0.02563827484846115, + 0.03053808957338333, + 0.0012747920118272305, + 0.004095789045095444, + -0.07932732999324799, + -0.046672020107507706, + 0.02153847925364971, + 0.019504766911268234, + -0.006118285935372114, + 0.0026654782705008984, + 0.013819373212754726, + -0.01078135147690773, + 0.0070082321763038635, + 0.00906399916857481, + 0.010149766691029072, + 0.000516490894369781, + 0.00034157291520386934, + 0.02412085421383381, + 0.006926041562110186, + 0.023299943655729294, + 0.01129852794110775, + -0.0018704778049141169, + 0.016042279079556465, + 0.023886069655418396, + 0.04207555204629898, + -0.0021778997033834457, + 0.041684601455926895, + 0.05059140920639038, + 0.03518521040678024, + -0.0032736151479184628, + -0.0007146652205847204, + 0.015503454953432083, + -0.11896659433841705, + -0.07006713002920151, + 0.007565992418676615, + 0.012584990821778774, + 0.00843358226120472, + 0.017024952918291092, + 0.0359124094247818, + -0.05997823178768158, + -0.04116949439048767, + -0.016472430899739265, + 0.002696823561564088, + 0.00829327292740345, + 0.016238784417510033, + 0.0455794483423233, + 0.0019872160628437996, + -0.005927432328462601, + -0.003552153240889311, + 0.020063765347003937, + 0.00010026743984781206, + 0.01045019831508398, + 0.034689340740442276, + 0.014206668362021446, + 0.015128945000469685, + 0.00972809735685587, + 0.019944868981838226, + 0.020581791177392006, + 0.02938947267830372, + 0.03923909366130829, + 0.03601628914475441, + 0.030168617144227028, + 0.05403255671262741, + 0.03985666483640671, + 0.020015308633446693, + 0.0285494402050972, + 0.013555807992815971, + -0.04409409686923027, + -0.07503483444452286, + 0.01716756261885166, + 0.02053452841937542, + 0.057520389556884766, + 0.02973104454576969, + -0.04563397541642189, + -0.2676408588886261, + -0.30933722853660583, + -0.11671236902475357, + 0.0020135289523750544, + 0.022801443934440613, + -0.03161352127790451, + -0.2704106271266937, + -0.5803710222244263, + -0.5762420296669006, + -0.30449461936950684, + -0.0780220776796341, + 0.017343536019325256, + -0.05319945886731148, + -0.2906038463115692, + -0.598426342010498, + -0.5925986766815186, + -0.31852787733078003, + -0.09950074553489685, + 0.05888299271464348, + 0.01939479075372219, + -0.1060815081000328, + -0.3505017161369324, + -0.3200446665287018, + -0.10609738528728485, + 0.03659524768590927, + 0.056114207953214645, + 0.03447861596941948, + 0.014380007050931454, + -0.09436371922492981, + -0.07562272250652313, + 0.04223132133483887, + 0.06327345967292786, + -0.03735652193427086, + -0.052881840616464615, + -0.058017320930957794, + -0.02474917098879814, + -0.02431381866335869, + -0.0629878118634224, + -0.05212349444627762, + -0.03820814937353134, + -0.0034579068887978792, + -0.004930540919303894, + 0.07968354970216751, + 0.07278168946504593, + 0.015167324803769588, + -0.013638288713991642, + -0.05875609815120697, + -0.008851750753819942, + 0.10708516091108322, + 0.33075177669525146, + 0.3502756953239441, + 0.14791442453861237, + 0.03131852671504021, + -0.028764141723513603, + 0.07454497367143631, + 0.3000347316265106, + 0.6147283315658569, + 0.6289594173431396, + 0.3398674726486206, + 0.13494613766670227, + -0.03705109655857086, + 0.0633230209350586, + 0.3147434592247009, + 0.595033586025238, + 0.594217836856842, + 0.33864542841911316, + 0.11264053732156754, + -0.059276629239320755, + 0.005206871312111616, + 0.14524762332439423, + 0.37473905086517334, + 0.34477534890174866, + 0.12632343173027039, + 0.011062734760344028, + -0.06149457022547722, + -0.028670497238636017, + 0.011082210578024387, + 0.13112866878509521, + 0.1106843650341034, + -0.0025933771394193172, + -0.03781202808022499, + 0.030325254425406456, + 0.017758814617991447, + 0.01635698974132538, + -0.008786264806985855, + -0.0005018062074668705, + 0.005934061016887426, + 0.020206287503242493, + 0.019497420638799667, + -0.01290479488670826, + -0.010817185044288635, + -0.032760608941316605, + -0.026973316445946693, + -0.0021766452118754387, + -0.012848617509007454, + -0.0002560729335527867, + -0.02383977733552456, + -0.05322824791073799, + -0.05382781848311424, + -0.04459262639284134, + -0.04581240937113762, + -0.03465775027871132, + 0.0026904877740889788, + -0.026097090914845467, + -0.05170493200421333, + -0.04981262609362602, + -0.05221042037010193, + -0.05268307775259018, + -0.04735802114009857, + 0.019142162054777145, + -0.019374292343854904, + -0.03312355652451515, + -0.04133244976401329, + -0.033129844814538956, + -0.01844680868089199, + -0.024726904928684235, + 0.0012146441731601954, + -0.025521529838442802, + -0.03120318427681923, + -0.04863203689455986, + -0.021450525149703026, + -0.04190714284777641, + -0.02833862416446209, + 0.017827404662966728, + -0.010181388817727566, + -0.020994380116462708, + -0.04290826618671417, + -0.031555648893117905, + -0.030525390058755875, + -0.024981478229165077, + -0.017512500286102295, + 0.019927235320210457, + 0.00433371402323246, + -0.009276121854782104, + -0.03990143537521362, + -0.021251117810606956, + 0.017825132235884666, + -0.02313065528869629, + 0.012881814502179623, + 0.0009175563463941216, + -0.0656605213880539, + -0.007037178613245487, + 0.023603176698088646, + 0.04873553663492203, + 0.013912673108279705, + 9.78652315097861e-05, + -0.03166677802801132, + -0.11772678792476654, + -0.034320034086704254, + 0.04952533170580864, + 0.10113520920276642, + 0.030472615733742714, + -0.05131377652287483, + -0.1371452510356903, + -0.2326214611530304, + -0.0629519522190094, + 0.12444627285003662, + 0.15845368802547455, + 0.014535457827150822, + -0.06888624280691147, + -0.18798232078552246, + -0.24720685184001923, + -0.04858007654547691, + 0.26889580488204956, + 0.2433905005455017, + -0.01772989332675934, + -0.06027546152472496, + -0.12164203822612762, + -0.20018024742603302, + 0.0035393801517784595, + 0.27190765738487244, + 0.1929154396057129, + -0.012923460453748703, + -0.013931642286479473, + -0.043986693024635315, + -0.0655391663312912, + 0.04751605913043022, + 0.13482201099395752, + 0.06690078228712082, + -0.01862635649740696, + 0.02938506379723549, + 0.01789080537855625, + -0.006509440019726753, + -0.029202938079833984, + -0.023693149909377098, + 0.01042762491852045, + -0.0035929735749959946, + 0.024952176958322525, + -0.013459124602377415, + -0.10798560827970505, + -0.020217353478074074, + 0.017876077443361282, + 0.07628928124904633, + 0.04444783553481102, + 0.012667268514633179, + -0.09012818336486816, + -0.22452381253242493, + -0.07556752860546112, + 0.07942477613687515, + 0.17035256326198578, + 0.0396822914481163, + -0.08236342668533325, + -0.23916372656822205, + -0.3645225763320923, + -0.10748416185379028, + 0.1996970921754837, + 0.3076043725013733, + -0.0033923503942787647, + -0.13259321451187134, + -0.28894615173339844, + -0.3605952262878418, + -0.07969008386135101, + 0.3583948314189911, + 0.4267900586128235, + -0.02228585258126259, + -0.11386624723672867, + -0.21445821225643158, + -0.26956692337989807, + 0.026791207492351532, + 0.37918713688850403, + 0.37130093574523926, + -0.05172214284539223, + -0.05132569745182991, + -0.07469630241394043, + -0.11400169134140015, + 0.07863093167543411, + 0.24061299860477448, + 0.19393151998519897, + -0.03217098489403725, + 0.013085477985441685, + 0.032348379492759705, + 0.03207695484161377, + 0.010604938492178917, + -0.026534704491496086, + -0.018284842371940613, + -0.01768680103123188, + -0.001516501884907484, + 0.013829287141561508, + -0.034318119287490845, + 0.015753330662846565, + -0.0018936718115583062, + 0.014737343415617943, + 0.03306088596582413, + 0.020835628733038902, + -0.03396771103143692, + -0.10758449137210846, + -0.03052518330514431, + 0.020080547779798508, + 0.06180800125002861, + 0.03735671192407608, + -0.037925880402326584, + -0.09720461815595627, + -0.21495617926120758, + -0.06842153519392014, + 0.08532039076089859, + 0.13350333273410797, + 0.03649023920297623, + -0.03904158994555473, + -0.1483580619096756, + -0.2068314403295517, + -0.05687328055500984, + 0.21108660101890564, + 0.21018920838832855, + 0.009318819269537926, + -0.037683792412281036, + -0.09845960140228271, + -0.1535443514585495, + 0.004504916723817587, + 0.20256847143173218, + 0.1799001693725586, + -0.03175490349531174, + -0.020391397178173065, + -0.007309200707823038, + -0.06765769422054291, + 0.013149870559573174, + 0.08469820767641068, + 0.04147877171635628, + -0.0027241194620728493, + 0.008016721345484257, + 0.001382349175401032, + 0.0001219741752720438, + -0.059255484491586685, + -0.03761141747236252, + 0.0381690077483654, + -0.01603613793849945, + 0.0017731477273628116, + -0.016544193029403687, + 0.09518970549106598, + 0.1735895872116089, + 0.005558829288929701, + -0.13464735448360443, + -0.0703420490026474, + 0.001990854274481535, + -0.03426021337509155, + -0.4390500485897064, + -0.11292288452386856, + 0.20430812239646912, + 0.14832687377929688, + 0.06074441969394684, + -0.03749264031648636, + 0.408058226108551, + 0.43119552731513977, + -0.3804298937320709, + -0.3694773018360138, + -0.03696960583329201, + 0.04022200033068657, + -0.0812998041510582, + -0.4322642385959625, + 0.19638888537883759, + 0.7809834480285645, + 0.11584538966417313, + -0.04975399747490883, + -0.015579828992486, + 0.1362757831811905, + 0.027220597490668297, + -0.4703449606895447, + -0.3726261258125305, + 0.11754196882247925, + -0.01204066164791584, + -0.00118898821529001, + -0.05152498185634613, + 0.08767394721508026, + 0.14183296263217926, + 0.01692730002105236, + -0.04587334021925926, + 0.011115594767034054, + 0.021572716534137726, + -0.021584773436188698, + -0.012763801962137222, + 0.05708793178200722, + 0.021982798352837563, + -0.02731800265610218, + 0.03000856563448906, + 0.006653181277215481, + -0.02485630102455616, + -0.20296195149421692, + -0.10483214259147644, + 0.20483383536338806, + 0.1350196748971939, + -0.08543248474597931, + 0.02644401416182518, + 0.26855263113975525, + 0.1071053072810173, + -0.8168368935585022, + -0.6617473363876343, + 0.02877889946103096, + 0.21807144582271576, + -0.02164696715772152, + -0.03712613880634308, + 0.9743875861167908, + 1.1631361246109009, + -0.45643851161003113, + -0.8180081844329834, + -0.28109386563301086, + -0.09115415811538696, + -0.4352502226829529, + -0.7433719038963318, + 0.5383746027946472, + 1.7271664142608643, + 0.509749174118042, + -0.0689467042684555, + 0.010011479258537292, + 0.11752951890230179, + -0.28825971484184265, + -1.113126277923584, + -0.6029489636421204, + 0.357056587934494, + 0.19766344130039215, + 0.023361098021268845, + 0.04305602237582207, + 0.24867205321788788, + 0.16359609365463257, + -0.2485191822052002, + -0.2251967489719391, + 0.030422789976000786, + 0.0049157580360770226, + -0.05497031658887863, + -0.030760835856199265, + 0.034536562860012054, + 0.019565051421523094, + -0.00933124776929617, + 0.01611645519733429, + 0.07988770306110382, + -0.021982649341225624, + -0.21876110136508942, + -0.10555483400821686, + 0.1893070936203003, + 0.14684906601905823, + -0.031080693006515503, + 0.09768003225326538, + 0.3261844515800476, + 0.1466774046421051, + -0.6738073825836182, + -0.5424039363861084, + 0.04689512774348259, + 0.22039148211479187, + -0.07084018737077713, + -0.07436021417379379, + 0.8260523080825806, + 1.0253428220748901, + -0.38162854313850403, + -0.727206289768219, + -0.2605172097682953, + -0.0996573269367218, + -0.3653049170970917, + -0.6791687607765198, + 0.43514078855514526, + 1.4186147451400757, + 0.38797008991241455, + -0.12675431370735168, + 0.02766786515712738, + 0.14237603545188904, + -0.2306709885597229, + -0.9204807877540588, + -0.5071616172790527, + 0.32662850618362427, + 0.20703284442424774, + -0.020968681201338768, + 0.014105334877967834, + 0.24642448127269745, + 0.20103473961353302, + -0.15519124269485474, + -0.22072142362594604, + 0.049920063465833664, + -0.05465548485517502, + 0.018651481717824936, + 0.030082669109106064, + 0.05234164372086525, + 0.10243640840053558, + 0.03569166734814644, + 0.038984544575214386, + 0.05248976871371269, + 0.24501988291740417, + 0.4674161374568939, + 0.7142530083656311, + 0.7423628568649292, + 0.6262048482894897, + 0.4019012451171875, + -0.010997634381055832, + 0.17266513407230377, + 0.4467124342918396, + 0.7795005440711975, + 0.8282667994499207, + 0.6824804544448853, + 0.3955397605895996, + 0.009771074168384075, + 0.10707246512174606, + 0.23039454221725464, + 0.33151063323020935, + 0.36120596528053284, + 0.3240644633769989, + 0.17939962446689606, + -0.01115038525313139, + -0.11081521213054657, + -0.2146066278219223, + -0.3572347164154053, + -0.44021451473236084, + -0.38320258259773254, + -0.24643990397453308, + 0.031578775495290756, + -0.21325217187404633, + -0.4312629997730255, + -0.7276368141174316, + -0.8273008465766907, + -0.718246340751648, + -0.4161607027053833, + -0.06636986136436462, + -0.28078269958496094, + -0.476252943277359, + -0.734549880027771, + -0.7796792984008789, + -0.6637035608291626, + -0.41896238923072815, + 0.021693198010325432, + 0.006199972704052925, + -0.016619624570012093, + -0.010678192600607872, + 0.012267512269318104, + 0.004102918319404125, + -0.004080160986632109, + -0.0029241242446005344, + -0.027252744883298874, + -0.0772257149219513, + -0.09107967466115952, + -0.11302012205123901, + -0.08569496124982834, + -0.07242150604724884, + -0.016465697437524796, + -0.04874062165617943, + -0.09103028476238251, + -0.09025602042675018, + -0.07523388415575027, + -0.06320428103208542, + -0.048220545053482056, + -0.028701437637209892, + -0.008647853508591652, + -0.022354092448949814, + -0.06076030433177948, + -0.030872423201799393, + -0.045786645263433456, + -0.04190178960561752, + 0.03718986362218857, + 0.021405767649412155, + 0.007675759959965944, + 0.02794131636619568, + 0.030316906049847603, + 0.007403802592307329, + 0.04861852154135704, + 0.023217258974909782, + 0.04545973241329193, + 0.07504793256521225, + 0.06824314594268799, + 0.07417462021112442, + 0.0769289955496788, + 0.0766506940126419, + -0.0028638055082410574, + 0.05911175534129143, + 0.055706772953271866, + 0.10735032707452774, + 0.10494870692491531, + 0.11092723160982132, + 0.09338293969631195, + 0.04235343262553215, + -0.022347571328282356, + -0.026347652077674866, + -0.06954608112573624, + -0.06944439560174942, + -0.05570404976606369, + -0.042987462133169174, + -0.056951191276311874, + -0.2151203453540802, + -0.3603246510028839, + -0.5899456143379211, + -0.6453464031219482, + -0.5338351726531982, + -0.31790611147880554, + 0.049492284655570984, + -0.12898015975952148, + -0.40155911445617676, + -0.6737278699874878, + -0.7170611619949341, + -0.5817899703979492, + -0.32979026436805725, + -0.005899591837078333, + -0.07673019915819168, + -0.190496027469635, + -0.34019437432289124, + -0.3314637243747711, + -0.2796767055988312, + -0.1381818801164627, + -0.008025999180972576, + 0.08429048955440521, + 0.2105528861284256, + 0.3415210545063019, + 0.4151126444339752, + 0.34003961086273193, + 0.21059827506542206, + -0.03514896333217621, + 0.1792585551738739, + 0.3903186321258545, + 0.6413942575454712, + 0.7557680010795593, + 0.6069726943969727, + 0.3415443003177643, + 0.03447553142905235, + 0.21517080068588257, + 0.4215562045574188, + 0.6151171922683716, + 0.6550290584564209, + 0.5680058002471924, + 0.33561068773269653, + -0.12205997854471207, + -0.0038300298620015383, + 0.3281119763851166, + -0.2328944057226181, + -0.03834507241845131, + 0.05432930961251259, + -0.014430212788283825, + 0.006271198857575655, + 0.32864242792129517, + 0.47277259826660156, + -0.5593215227127075, + -0.14971251785755157, + 0.13066314160823822, + -0.09738356620073318, + 0.2966129779815674, + 0.5606555342674255, + -0.3184640407562256, + -2.022890090942383, + -0.361995667219162, + 0.5496177673339844, + 0.02796279452741146, + -0.21818380057811737, + -0.5373459458351135, + -1.9538941383361816, + -1.9984712600708008, + 1.6747761964797974, + 1.5063239336013794, + -0.24534250795841217, + -0.040306344628334045, + -0.16963164508342743, + -0.40690454840660095, + 1.3548375368118286, + 3.922116279602051, + 0.8723023533821106, + -0.8986141681671143, + 0.06912416964769363, + 0.2192920595407486, + 0.352949321269989, + 1.2243634462356567, + 1.1395865678787231, + -1.5146961212158203, + -1.1557590961456299, + -0.05440744385123253, + -0.04629289731383324, + -0.002693743444979191, + -0.21906790137290955, + -0.5464610457420349, + -1.1933224201202393, + 0.01913866586983204, + 0.09363497048616409, + -0.06080613285303116, + -0.049100056290626526, + 0.04482033848762512, + -0.04087500274181366, + -0.009318803437054157, + 0.009458474814891815, + -0.09565524011850357, + -0.2264278084039688, + -0.0698866918683052, + 0.13825084269046783, + 0.014815542846918106, + -0.05801662430167198, + 0.012776852585375309, + -0.0753035843372345, + -0.07555855065584183, + 0.484436959028244, + 0.6397283673286438, + 0.12687323987483978, + -0.01779526099562645, + 0.05689511448144913, + 0.06747376173734665, + 0.26353734731674194, + 0.5908273458480835, + 0.4315526783466339, + -0.5426794290542603, + -0.44501280784606934, + -0.019558124244213104, + -0.03320806100964546, + -0.025809556245803833, + 0.17376014590263367, + -0.5201969742774963, + -1.2842578887939453, + -0.3674038052558899, + 0.0882175862789154, + -0.030023137107491493, + -0.1173325777053833, + 0.02555503323674202, + -0.39882710576057434, + -0.37364596128463745, + 0.3550366163253784, + 0.3903135359287262, + 0.04022252932190895, + 0.016731394454836845, + 0.11207644641399384, + -0.020967213436961174, + -0.028497911989688873, + 0.37590932846069336, + 0.14920172095298767, + 0.029958104714751244, + 0.039632707834243774, + -0.24969367682933807, + 0.16809938848018646, + 0.07703239470720291, + -0.03522319719195366, + -0.007072617299854755, + 0.07751759141683578, + -0.06782346963882446, + -0.4010501801967621, + 0.41269779205322266, + 0.1311105638742447, + -0.07331988960504532, + 0.08240311592817307, + -0.20034979283809662, + -0.4718745946884155, + -0.178948312997818, + 1.3285318613052368, + 0.20384186506271362, + -0.48546233773231506, + -0.09941625595092773, + 0.13249020278453827, + 0.29977336525917053, + 1.2681238651275635, + 1.5725642442703247, + -1.0834472179412842, + -1.0335719585418701, + 0.25975045561790466, + 0.06584863364696503, + 0.1609305590391159, + 0.25940945744514465, + -0.8426372408866882, + -2.590407609939575, + -0.4723183214664459, + 0.7581043243408203, + -0.03634117543697357, + -0.10199672728776932, + -0.3744191527366638, + -0.7823801636695862, + -0.7062401175498962, + 1.116550087928772, + 0.7735803127288818, + 0.012776976451277733, + 0.034575968980789185, + -0.10188565403223038, + 0.2212170958518982, + 0.5182898044586182, + 0.8056022524833679, + -0.1897655427455902, + -0.005556725896894932, + -0.003909373190253973, + -0.02175678312778473, + -0.04085654392838478, + -0.03573022410273552, + -0.0038509985897690058, + 0.02454996667802334, + 0.039437733590602875, + 0.02077251859009266, + 0.02166259102523327, + 0.17245841026306152, + 0.09513862431049347, + -0.10491111874580383, + -0.08084940910339355, + -0.026179829612374306, + 0.0215831957757473, + -0.16602416336536407, + -0.2803819179534912, + 0.23894084990024567, + 0.3269801735877991, + 0.04504352807998657, + 0.0009768904419615865, + 0.01959501951932907, + 0.24426960945129395, + -0.1451571136713028, + -0.5944203734397888, + -0.17875447869300842, + 0.028336334973573685, + 0.004323791246861219, + -0.045389141887426376, + 0.0343034490942955, + 0.46665430068969727, + 0.3707427978515625, + -0.114569291472435, + 0.04335101321339607, + -0.018011711537837982, + -0.021181274205446243, + -0.19074901938438416, + -0.20113815367221832, + 0.048786211758852005, + 0.08533122390508652, + -0.06084573268890381, + 0.01217757910490036, + 0.030666939914226532, + 0.05272842198610306, + 0.010849648155272007, + -0.05913804844021797, + -0.04202868044376373, + -0.0015147016383707523, + -0.03421122953295708, + 0.015080726705491543, + 0.12191007286310196, + 0.10450142621994019, + -0.04972418025135994, + -0.07557133585214615, + -0.02221665158867836, + -0.0861242413520813, + -0.14919178187847137, + -0.04388582333922386, + 0.4605262875556946, + 0.5697804093360901, + 0.1583399623632431, + -0.045628566294908524, + -0.05220475420355797, + -0.13630147278308868, + -0.7103163599967957, + -1.0178179740905762, + 0.1927143931388855, + 0.7479860186576843, + 0.47013771533966064, + 0.16943301260471344, + 0.2398149073123932, + 0.4710526168346405, + -0.5974176526069641, + -1.8564051389694214, + -0.7726883292198181, + 0.05584309995174408, + 0.08902852982282639, + 0.0931839719414711, + 0.46213099360466003, + 1.2080260515213013, + 0.6001025438308716, + -0.590207576751709, + -0.4145379662513733, + -0.04529324173927307, + -0.08303339034318924, + -0.2470429688692093, + -0.03481363505125046, + 0.4808541238307953, + 0.4001348614692688, + -0.1292688548564911, + -0.03635162487626076, + -0.006270444020628929, + -0.0314505510032177, + -0.13043232262134552, + -0.10837803781032562, + 0.10718243569135666, + 0.07523836195468903, + -0.00597786670550704, + 0.06580565124750137, + 0.11166563630104065, + 0.021869506686925888, + -0.10510984063148499, + -0.07651247084140778, + 0.01229890063405037, + -0.08976037800312042, + -0.14929910004138947, + -0.018859578296542168, + 0.4408939778804779, + 0.4029107689857483, + -0.05015433207154274, + -0.13887189328670502, + -0.04514491930603981, + -0.07346425950527191, + -0.5277182459831238, + -0.7335640788078308, + 0.24182197451591492, + 0.626846432685852, + 0.23399080336093903, + 0.09675730019807816, + 0.15529058873653412, + 0.42680656909942627, + -0.4012089967727661, + -1.3605350255966187, + -0.4793834686279297, + 0.10987094044685364, + 0.07592830061912537, + 0.003319029463455081, + 0.24004696309566498, + 0.9590277671813965, + 0.4946591258049011, + -0.4889579117298126, + -0.34744441509246826, + -0.020535729825496674, + -0.026767954230308533, + -0.2090117186307907, + -0.11841326951980591, + 0.37452432513237, + 0.39960840344429016, + -0.07025045901536942, + -0.022984744980931282, + 0.022319970652461052, + -0.0027356306090950966, + -0.13681942224502563, + -0.09797768294811249, + 0.09914079308509827, + 0.10856777429580688, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 3, 7, 7)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_501") + + initializers.append(tensor) + + list_value = [ + 3.085598945617676, + 2.2436060905456543, + 4.244357585906982, + 1.4069645404815674, + -4.00622034072876, + 2.595770835876465, + 2.7202603816986084, + 2.4405417442321777, + 1.1759933233261108, + 2.021026372909546, + 2.6628992557525635, + 6.445226192474365, + -7.029932498931885, + 1.1305793523788452, + 2.537140369415283, + 5.456772327423096, + 4.780154705047607, + 10.039976119995117, + 2.912492275238037, + 15.781542778015137, + 2.5154318809509277, + 2.628824472427368, + 2.2992050647735596, + 2.0950584411621094, + -7.93365478515625, + 2.067786931991577, + 4.094852447509766, + 1.673399806022644, + 3.1814424991607666, + 22.49496078491211, + 2.232640027999878, + 2.6427979469299316, + -9.418174743652344, + 1.790976643562317, + 2.3774726390838623, + 2.5836219787597656, + 2.5608203411102295, + 2.287343978881836, + 2.6439085006713867, + 16.859027862548828, + 1.8699607849121094, + -3.6987526416778564, + 2.6861538887023926, + 2.8997464179992676, + 2.689293384552002, + 2.6654043197631836, + 2.3799915313720703, + 2.5603086948394775, + 3.146122694015503, + 2.715951681137085, + 2.889486789703369, + 2.966134548187256, + -4.960191249847412, + 2.6123547554016113, + 1.3074164390563965, + 2.2033026218414307, + 2.2114620208740234, + 4.132844924926758, + 4.893764495849609, + 2.6469600200653076, + 2.654136896133423, + 1.9311997890472412, + 2.881012439727783, + 2.6991193294525146, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_502") + + initializers.append(tensor) + + list_value = [ + 0.057212892919778824, + 0.06299274414777756, + -0.018499961122870445, + -0.06501776725053787, + -0.015820641070604324, + 0.024293724447488785, + 0.05624663084745407, + -0.025112055242061615, + 0.043546054512262344, + 0.08439744263887405, + 0.005678815301507711, + 0.0034800865687429905, + 0.030301403254270554, + -0.011669250205159187, + -0.005434689112007618, + -0.1591511219739914, + 0.02324092946946621, + -0.018942436203360558, + 0.025366367772221565, + -0.07414374500513077, + 0.03468436002731323, + -0.003742520697414875, + -0.06651683896780014, + 0.005561002530157566, + 0.04527103528380394, + -0.13710148632526398, + 0.0025444801431149244, + 0.03583350405097008, + 0.015219246037304401, + -0.053635064512491226, + 0.004856681916862726, + -0.07223699986934662, + 0.016770021989941597, + 0.0012010147329419851, + 0.014582094736397266, + -0.005172556731849909, + 0.02009868621826172, + -0.0064261858351528645, + -0.029086023569107056, + 0.001915874076075852, + 0.0008194410474970937, + 0.01620865799486637, + 0.03067426010966301, + -0.0018463254673406482, + 0.05358384922146797, + -0.003966080490499735, + -0.05991416424512863, + -0.06455761194229126, + 0.01634763367474079, + -0.013959774747490883, + 0.03615918383002281, + 0.004434086848050356, + 0.02086004987359047, + -0.004025993403047323, + -0.8869641423225403, + 0.05558132007718086, + 0.024729542434215546, + -0.005809253081679344, + -0.025079259648919106, + 0.04757235199213028, + 0.0023902510292828083, + 0.01522061601281166, + 0.011692625470459461, + 0.023033330217003822, + -0.012664714828133583, + -0.29325294494628906, + -0.006855700630694628, + -0.243958979845047, + 0.0024398649111390114, + -0.060877203941345215, + -0.21996521949768066, + -0.008708474226295948, + -0.06639625877141953, + -0.03170674294233322, + -0.09708897024393082, + 0.013403226621448994, + 0.024766888469457626, + 0.2594103217124939, + -0.02221749909222126, + 0.0662861093878746, + -0.15123076736927032, + -0.010314224287867546, + -0.0029192541260272264, + 0.05985910817980766, + 0.021665453910827637, + 0.003247617743909359, + -0.006802591495215893, + 0.00772367138415575, + 0.0399332195520401, + 0.005198766943067312, + 0.006013805978000164, + -0.04212838411331177, + -0.03166411817073822, + 0.13363900780677795, + 0.006383878644555807, + -0.05536859482526779, + 0.02053261175751686, + 0.015062958002090454, + 0.03352641686797142, + -0.2944328486919403, + 0.019855381920933723, + -0.15567174553871155, + -0.06759943068027496, + 0.07467031478881836, + 0.01674237661063671, + 0.004549413453787565, + -0.0032498433720320463, + -0.1837870180606842, + -0.04725493863224983, + -0.111307792365551, + 0.022237055003643036, + 0.004200428258627653, + 0.00970534235239029, + -0.045657914131879807, + -0.024577995762228966, + 0.0035376595333218575, + 0.008936531841754913, + -0.03904002904891968, + 0.05013228952884674, + -0.011168933473527431, + -0.008444730192422867, + 0.0035155978985130787, + -0.023502476513385773, + 0.005275514908134937, + -0.09448224306106567, + -0.009177467785775661, + -0.010720008052885532, + 0.004110944457352161, + -0.0060218218713998795, + 0.058124978095293045, + -0.0016586220590397716, + 0.15812785923480988, + -0.049118027091026306, + -0.007983109913766384, + -0.04265601187944412, + -0.01627231575548649, + 0.33705562353134155, + 0.01555223111063242, + 0.035853929817676544, + 0.0005046340520493686, + 0.054810188710689545, + -0.08808254450559616, + -0.0013819067971780896, + -0.14938786625862122, + -0.019771935418248177, + 0.004152575507760048, + 0.021979758515954018, + 0.1985529363155365, + -0.07694264501333237, + 0.013187955133616924, + -0.016572976484894753, + -0.03094586730003357, + -0.03673199936747551, + -0.03916170820593834, + -0.003836784977465868, + -0.012262578122317791, + 0.005559554789215326, + 0.1488093137741089, + -0.01842501200735569, + -0.004847189411520958, + -0.02391587756574154, + 0.015824301168322563, + 0.012022596783936024, + 0.06724318116903305, + -0.032682593911886215, + 0.00450896704569459, + -0.0024625889491289854, + 0.00933725107461214, + -0.04473242908716202, + 0.06270455569028854, + -0.02062271721661091, + -0.01071448065340519, + -0.017757099121809006, + 0.01575278490781784, + -0.06489317119121552, + -0.01519051194190979, + 0.0028058059979230165, + 0.00917835533618927, + -0.01291860081255436, + -0.009537308476865292, + 0.041757628321647644, + 0.03203853219747543, + -0.10918509215116501, + -0.007152496371418238, + -0.06777876615524292, + 0.03223242610692978, + 0.01780836284160614, + -0.09791012853384018, + -0.009385241195559502, + 0.013184775598347187, + 0.0031673219054937363, + -0.010640445165336132, + 0.024713385850191116, + -0.026738369837403297, + -0.004191657993942499, + -0.13764967024326324, + -0.003720735665410757, + 0.01737186871469021, + 0.015459887683391571, + 0.033229030668735504, + 0.008042111992835999, + -0.007184108253568411, + 0.008226306177675724, + 0.0031303109135478735, + 0.0406314842402935, + -0.8669105768203735, + 0.02079751342535019, + -0.17030003666877747, + -0.03849703446030617, + 0.034153200685977936, + -0.007219486869871616, + 0.11227627843618393, + -0.2681085467338562, + 0.015872526913881302, + 0.10855260491371155, + -0.008631505072116852, + 0.02556358277797699, + 0.06043418496847153, + -0.012900532223284245, + -0.08834894001483917, + 0.028099440038204193, + -0.05156330019235611, + 0.032628703862428665, + 0.044928934425115585, + 0.006176372990012169, + 0.007333829998970032, + -0.037409231066703796, + -0.046724822372198105, + -0.011172871105372906, + 0.04603327810764313, + 0.03288746625185013, + -0.20848578214645386, + 0.0028185085393488407, + -0.032673876732587814, + 0.061944279819726944, + 0.016787173226475716, + 0.02703898213803768, + -0.0060023171827197075, + 0.06870592385530472, + 0.03154531493782997, + 0.02784041129052639, + 0.007780189625918865, + 0.02033168077468872, + 0.0019289497286081314, + 0.02545374445617199, + 0.04262726008892059, + 0.01301807351410389, + -0.023882156237959862, + 0.027872221544384956, + -0.013518108054995537, + -0.0031075032893568277, + 0.03753834590315819, + 0.0369209349155426, + -0.014378191903233528, + 0.004397932440042496, + -0.030286893248558044, + -0.007679021451622248, + -0.045032769441604614, + 0.032050322741270065, + -0.03373495861887932, + -0.04363032802939415, + 0.034301597625017166, + -0.07021668553352356, + 0.03942524269223213, + -0.11061309278011322, + 0.049139462411403656, + 0.04161922261118889, + -0.01507576834410429, + -0.012748259119689465, + 0.06599434465169907, + 0.007602245546877384, + -0.03973209857940674, + -0.06923151016235352, + 0.026153067126870155, + -0.04221056029200554, + -0.4828230142593384, + 0.03360651433467865, + 0.01847662217915058, + -0.08594681322574615, + 0.04071836546063423, + -0.0035729086957871914, + 0.0049045816995203495, + -0.036198534071445465, + 0.03046257793903351, + 0.013275806792080402, + 0.09266786277294159, + -0.03625647351145744, + -0.059672992676496506, + 0.050213005393743515, + -0.018153885379433632, + -0.0858495831489563, + 0.01621098257601261, + -0.03029749169945717, + 0.02193332649767399, + 0.0422661192715168, + 0.6109512448310852, + -0.01068826112896204, + -0.02184930257499218, + -0.03213764354586601, + -0.03148162364959717, + -0.055331334471702576, + 0.006972005590796471, + -0.00815682765096426, + 0.014874683693051338, + -0.012943249195814133, + -0.03318992629647255, + -0.0010484680533409119, + 0.005414161365479231, + -0.013610370457172394, + 0.008836873807013035, + -0.05890084058046341, + -0.022663919255137444, + -0.018899116665124893, + -0.01037894282490015, + 0.005064660683274269, + 0.08522599190473557, + 0.0075323861092329025, + 0.013720778748393059, + 0.032096460461616516, + -0.008450351655483246, + 0.020377663895487785, + 0.04537765309214592, + 0.014030816033482552, + 0.024340089410543442, + 0.0231801588088274, + -0.10347768664360046, + 0.041163086891174316, + -0.060614243149757385, + -0.09241361171007156, + 0.05831432715058327, + -0.16008608043193817, + -0.04505622759461403, + 0.04866329953074455, + -0.0656094029545784, + 0.09627313911914825, + 0.1153625100851059, + 0.008151216432452202, + 0.03813345730304718, + 0.05990723893046379, + 0.24788673222064972, + 0.06294118613004684, + 0.11761849373579025, + -0.0722033903002739, + -0.013892017304897308, + -0.016778236255049706, + 0.038522012531757355, + -0.015539593063294888, + 0.01263216882944107, + 0.0003969807003159076, + -0.0224238783121109, + -0.005919966846704483, + 0.031987495720386505, + -0.014712700620293617, + 0.03508169203996658, + 0.07568854838609695, + -0.011961974203586578, + 0.027983952313661575, + -0.03512958809733391, + -0.010324078612029552, + -0.2895449995994568, + 0.007338976487517357, + -0.042290836572647095, + -0.1640917807817459, + -0.034807007759809494, + -0.1268443465232849, + 0.18418198823928833, + -0.3867812156677246, + -0.14214494824409485, + 0.001021744217723608, + 0.11288078874349594, + 0.006741920951753855, + -0.006421610247343779, + 0.021150892600417137, + 0.02486848644912243, + 0.002660338068380952, + 0.03732302784919739, + 0.10844919830560684, + -0.032568808645009995, + 0.009477612562477589, + 0.053578171879053116, + -0.07421902567148209, + 0.05660263076424599, + 0.03038308583199978, + 0.049440011382102966, + 0.0395139642059803, + 0.0217339675873518, + 0.028231965377926826, + 0.1661153882741928, + -0.02168717049062252, + 0.055143170058727264, + -0.14159196615219116, + 0.05894732475280762, + 0.006888065952807665, + -0.06988262385129929, + 0.017527412623167038, + -0.007171930745244026, + -0.00448343763127923, + 0.02932717651128769, + -0.00652179354801774, + -0.002897858154028654, + 0.020487705245614052, + -0.027063967660069466, + -0.02539752423763275, + -0.1066114604473114, + -0.10011029988527298, + -0.03331710025668144, + -0.003807300003245473, + -0.010441976599395275, + -0.005605363752692938, + 0.09679440408945084, + 0.020033519715070724, + -0.010188378393650055, + -0.030630890280008316, + -0.00955540407449007, + 0.02825581096112728, + -0.4307324290275574, + 0.012557203881442547, + 0.043258048593997955, + 0.09386534243822098, + -0.009555542841553688, + 0.05304868891835213, + 0.014706632122397423, + -0.012911850586533546, + 0.0981304720044136, + -0.010722141712903976, + -0.027317194268107414, + 0.0893903523683548, + -0.19983792304992676, + -0.15778200328350067, + -0.1012115329504013, + -0.3758164644241333, + -0.05782865360379219, + -0.01230492815375328, + -0.37126046419143677, + -0.01596723683178425, + 0.0020407456904649734, + -0.017498979344964027, + 0.005369496997445822, + -0.023121315985918045, + 0.022279681637883186, + -0.006232256535440683, + 0.05115891620516777, + 0.006679570768028498, + 0.0026316209696233273, + 0.04291496425867081, + 0.04381528124213219, + -0.05994122102856636, + 0.007081915624439716, + -0.04571640491485596, + 0.07592425495386124, + -0.00836833007633686, + 0.008123279549181461, + -0.008003163151443005, + -0.003938044421374798, + 0.005643180105835199, + 0.016194086521863937, + -0.004063089843839407, + 0.012334472499787807, + 0.017072021961212158, + 0.005761854816228151, + 0.004702428821474314, + 0.005736868362873793, + 0.0017962371930480003, + 0.059996701776981354, + 0.19533602893352509, + 0.02649352326989174, + -0.06493135541677475, + -0.05955052375793457, + 0.015692468732595444, + -0.10623155534267426, + 0.07290898263454437, + 0.036108434200286865, + -0.01248949021100998, + 0.16444285213947296, + -0.005899128969758749, + 0.07875277101993561, + 0.0014204353792592883, + 0.03381470963358879, + -0.09680792689323425, + 0.002102318685501814, + 0.026962973177433014, + 0.031665392220020294, + -0.18168538808822632, + 0.11163855344057083, + -0.5409999489784241, + 0.07833191007375717, + -0.005324948113411665, + 0.0267564058303833, + 0.02250477857887745, + 0.03249068558216095, + -0.18441715836524963, + -0.006447427906095982, + 0.037927329540252686, + 0.0005173985846340656, + -0.02617005631327629, + 0.05929232016205788, + -0.028510913252830505, + 0.05447050556540489, + 0.012390155345201492, + 0.00046797769027762115, + -0.008598590269684792, + -0.17247197031974792, + -0.02855759859085083, + 0.033968932926654816, + -0.09011702984571457, + 0.05276056379079819, + 0.03299655020236969, + -0.005699596833437681, + -0.1954648792743683, + 0.011109501123428345, + -0.0013570536393672228, + -0.6543989181518555, + 0.009102803654968739, + 0.0407538004219532, + 0.04312055557966232, + 0.027609223499894142, + -0.035538043826818466, + 0.027167823165655136, + -0.024043193086981773, + 0.0047575319185853004, + -0.006788836792111397, + 0.025714389979839325, + 0.007848678156733513, + -0.07680192589759827, + 0.009700766764581203, + -0.0097329281270504, + 0.00586724653840065, + 0.022815868258476257, + -0.023448282852768898, + -0.05608998239040375, + 0.10786863416433334, + -0.02803603559732437, + 0.012898198328912258, + -0.009270391426980495, + -0.021972229704260826, + 0.26533082127571106, + -0.01021308358758688, + -0.01972626894712448, + 0.062940314412117, + 0.022569671273231506, + 0.027042347937822342, + -0.05669092759490013, + -0.01200617104768753, + -0.006279367487877607, + -0.009608528576791286, + -0.013600943610072136, + -0.02187415212392807, + 0.0351138636469841, + 0.006282923277467489, + -0.011123511008918285, + -0.009205769747495651, + 0.001010146806947887, + -0.4796978235244751, + -0.0030205894727259874, + -0.011987377889454365, + -0.027548225596547127, + 0.009372347965836525, + -0.005388603545725346, + -0.006444129627197981, + -0.02501147985458374, + 0.027465635910630226, + 0.027784524485468864, + 0.006878893356770277, + -0.027763860300183296, + -0.0047700353898108006, + -0.018965192139148712, + 0.027898501604795456, + 0.022454144433140755, + 0.02973407506942749, + 0.03505602851510048, + 0.04003170132637024, + -0.004336829297244549, + -0.01998550072312355, + -0.06097743660211563, + -0.07844759523868561, + 0.0013787010684609413, + 0.0066132270731031895, + -0.03124997951090336, + 0.0313432514667511, + 0.047656893730163574, + 0.06175797060132027, + -0.02077358029782772, + -0.004535601008683443, + -0.10219905525445938, + -0.07125344127416611, + -0.06927482783794403, + -0.04813461750745773, + -0.02618095651268959, + -0.01255929097533226, + -0.009180150926113129, + -0.005838831886649132, + 0.09108023345470428, + -0.032710760831832886, + 0.03091445378959179, + -0.01955563761293888, + 0.0959300771355629, + -0.09353741258382797, + -0.0761636272072792, + -0.023445438593626022, + -0.012328366748988628, + 0.05850536748766899, + -0.052494827657938004, + 0.0025638933293521404, + -0.017152179032564163, + -0.004435579292476177, + 0.12312240898609161, + -0.007241012528538704, + 0.09605048596858978, + 0.03355967625975609, + -0.015987426042556763, + -0.03470349311828613, + -0.02499505691230297, + -0.015004142187535763, + -0.018609771504998207, + -0.06654462963342667, + 0.013861652463674545, + -0.005973289255052805, + -0.04734775796532631, + 0.08755116909742355, + 0.03012942522764206, + 0.07887610793113708, + -0.01827712170779705, + 0.10793066769838333, + 0.10793614387512207, + -0.01075535174459219, + 0.03439560532569885, + 0.011567444540560246, + 0.0016386889619752765, + -0.031207261607050896, + -0.01707504875957966, + 0.20471863448619843, + 0.0025428179651498795, + 0.004082779865711927, + -0.012389302253723145, + 0.0400562584400177, + -0.21075034141540527, + 0.012872264720499516, + -0.01639414019882679, + 0.016652485355734825, + 0.0016037120949476957, + -0.006540367379784584, + -0.0068405005149543285, + -0.2484254390001297, + 0.0008089764742180705, + -0.022340824827551842, + -0.005441636312752962, + 0.002882100408896804, + 0.008654038421809673, + 0.07159754633903503, + -0.02537086047232151, + 0.011997461318969727, + -0.49913132190704346, + -0.02300887741148472, + 0.044442202895879745, + 0.001787978457286954, + 0.010291379876434803, + 0.009601960889995098, + -0.5312613248825073, + -0.014247804880142212, + 0.06685849279165268, + 0.035772595554590225, + 0.03432310372591019, + 0.03151272237300873, + -0.10318460315465927, + -0.030476456508040428, + -0.004469831008464098, + -0.16645164787769318, + -0.021104637533426285, + 0.013934006914496422, + -0.011767406016588211, + 0.008054615929722786, + 0.06089277192950249, + 0.0003409573109820485, + -0.0053401123732328415, + 0.05970478057861328, + -0.004363172687590122, + 0.014423285610973835, + -0.002795026171952486, + -0.019875092431902885, + -0.07540513575077057, + -0.09043378382921219, + 0.00750827556475997, + -0.045314721763134, + -0.00724808732047677, + 0.005193864461034536, + -0.020468784496188164, + -0.01098695583641529, + -0.0003122477210126817, + -0.007263806648552418, + -0.03325646370649338, + 0.021689830347895622, + -0.13272541761398315, + 0.02332465350627899, + -0.019292252138257027, + 0.05533658340573311, + -0.018616480752825737, + -0.015228793025016785, + -0.28432801365852356, + -0.29721561074256897, + 0.04648810625076294, + -0.014750649221241474, + -0.15370936691761017, + -0.1497083604335785, + 0.013243601657450199, + 0.042343802750110626, + -0.017519792541861534, + -0.0161418616771698, + 0.00807454064488411, + -0.023562468588352203, + -0.0315413773059845, + 0.03386805206537247, + 0.2854529917240143, + 0.0191020630300045, + -0.49126777052879333, + 0.052687134593725204, + -0.023298051208257675, + -0.009119837544858456, + 0.05149759724736214, + -0.8527837991714478, + 0.08062390983104706, + 0.057379938662052155, + -0.020724931731820107, + -0.006624895613640547, + 0.05322050303220749, + 0.017887847498059273, + 0.04229281470179558, + 0.04171830415725708, + 0.029683062806725502, + -0.00028416322311386466, + 0.1112222746014595, + -0.0448714978992939, + -0.005255761090666056, + 0.017773712053894997, + -0.0016064767260104418, + -0.013840594328939915, + -0.00398495327681303, + -4.32919041486457e-05, + 0.040796443819999695, + 0.018185198307037354, + -0.018671950325369835, + 0.0028256692457944155, + -0.020582057535648346, + 0.05567716807126999, + -0.056062404066324234, + 0.01614757999777794, + -0.0029299987945705652, + 0.048686008900403976, + 0.04299888014793396, + 0.12249592691659927, + 0.01469603180885315, + -0.1254546344280243, + -0.18532024323940277, + -0.003263876074925065, + 0.014804725535213947, + 0.004450956825166941, + -0.013681051321327686, + -0.0030781759414821863, + -0.03433656692504883, + -0.0035507124848663807, + 0.1600082814693451, + -0.028547707945108414, + -0.00989136379212141, + -0.012126478366553783, + -0.12963305413722992, + 0.008547360077500343, + 0.017959514632821083, + -0.012571084313094616, + 0.0008666724897921085, + -0.010519342496991158, + -0.009684977121651173, + -0.04285729303956032, + 0.015031769871711731, + -0.030043724924325943, + 0.018907636404037476, + 0.08019450306892395, + -0.04836742579936981, + 0.01025464478880167, + -0.004908542148768902, + -0.10327022522687912, + -0.10163667798042297, + -0.03403499722480774, + -0.019678063690662384, + -0.043049123138189316, + 0.0384567566215992, + -0.05596519634127617, + -0.09381429851055145, + -0.18688108026981354, + -0.09762943536043167, + -0.03164997324347496, + -0.006416287273168564, + 0.07003920525312424, + -0.016646990552544594, + -0.025972194969654083, + -0.028768088668584824, + -0.06332779675722122, + 0.045144014060497284, + -0.03735211119055748, + -0.010442189872264862, + 0.10948455333709717, + 0.14629514515399933, + -0.023416690528392792, + -0.01347778458148241, + 0.020830679684877396, + 0.0003131759003736079, + 0.007049075793474913, + 0.06547018885612488, + 0.03152740001678467, + 0.08380027115345001, + 0.03185325488448143, + -0.015359007753431797, + 0.08864206075668335, + 0.032676901668310165, + -0.002908645663410425, + 0.053111132234334946, + 0.0026159954722970724, + -0.05177146941423416, + -0.033048152923583984, + -0.0020293137058615685, + -0.07363513857126236, + -0.17662747204303741, + 0.004798125941306353, + 0.07139395922422409, + 0.019802849739789963, + 0.009199771098792553, + -0.009043877013027668, + -0.07681646943092346, + -0.06748555600643158, + 0.05094710737466812, + 0.0014789587585255504, + -0.0166088305413723, + -0.27988284826278687, + 0.03634800389409065, + 0.05322619527578354, + -0.15566207468509674, + -0.019964642822742462, + -0.010204506106674671, + -0.011832086369395256, + -0.0680927112698555, + -0.05793820694088936, + 0.0020100779365748167, + -0.24647225439548492, + 0.04904041066765785, + -0.05589786171913147, + -0.030167482793331146, + 0.023974033072590828, + -0.22719347476959229, + 0.019620347768068314, + -0.18078163266181946, + -0.11321499198675156, + -0.023790234699845314, + -0.1266157031059265, + 0.01117659267038107, + 0.13824795186519623, + -0.024211348965764046, + -0.0548308864235878, + 0.04849318787455559, + -0.0016174454940482974, + -0.01826266385614872, + 0.006709347013384104, + -0.350631982088089, + 0.03139018639922142, + 0.021502504125237465, + -0.12596893310546875, + 0.04311670735478401, + -0.005905786994844675, + -0.0807335153222084, + -0.07214773446321487, + -0.2054852843284607, + -0.04526854678988457, + -0.09145382046699524, + 0.002603817731142044, + -0.01951524056494236, + -0.0028278473764657974, + -0.03270411863923073, + -0.0003385065938346088, + -0.019816655665636063, + -0.003430107608437538, + 0.010664679110050201, + 0.030127109959721565, + 0.02611778862774372, + 0.030213139951229095, + 0.04682943969964981, + 0.010338326916098595, + -0.02618880569934845, + 0.014982170425355434, + -0.06979402899742126, + 0.06403722614049911, + 0.025545112788677216, + -0.11981001496315002, + 0.004320457112044096, + 0.008849565871059895, + 0.07450827211141586, + -0.04322020336985588, + -0.07648278027772903, + 0.009221173822879791, + -0.12771189212799072, + 0.027474528178572655, + -0.1637975573539734, + -0.022587651386857033, + 0.0713210329413414, + -0.09652210026979446, + -0.04942077025771141, + -0.08977267891168594, + -0.004629603121429682, + -0.09891843795776367, + 0.0004028059483971447, + 0.12999524176120758, + 0.009417874738574028, + -0.012465995736420155, + 0.09959464520215988, + 0.012048770673573017, + 0.00529639283195138, + -0.1231047734618187, + -0.010156300850212574, + -0.0067022680304944515, + 0.09231371432542801, + 0.1372271031141281, + 0.01140755694359541, + -0.014376018196344376, + 0.009014246053993702, + -0.0558021254837513, + 0.009297777898609638, + -0.023461824283003807, + 0.12312523275613785, + 0.0013492326252162457, + -0.10130659490823746, + 0.07867099344730377, + -0.04363301396369934, + -0.05203291028738022, + 0.010715829208493233, + 0.2679101228713989, + 0.047242000699043274, + 0.009700302965939045, + -0.004188477993011475, + 0.04595324397087097, + -0.10256988555192947, + 0.013266253285109997, + 0.13415516912937164, + -0.06461263447999954, + -0.04262775555253029, + 0.014638054184615612, + -0.020396970212459564, + 0.016008291393518448, + 0.012964261695742607, + 0.030219901353120804, + -0.03906702250242233, + -0.009459082037210464, + -0.006880247965455055, + 0.009383107535541058, + 0.0591101311147213, + -0.049882922321558, + -0.014105924405157566, + -0.04896679148077965, + 0.021726086735725403, + -0.013863577507436275, + -0.05801064148545265, + -0.031143831089138985, + 0.0010298469569534063, + -0.03104572743177414, + 0.1193046048283577, + 0.00880056619644165, + -0.01678626798093319, + 0.0014990485506132245, + -0.001967367948964238, + -0.0053575835190713406, + -0.006879259832203388, + -0.008937212638556957, + 0.014141763560473919, + 0.00687083275988698, + -0.0012949275551363826, + 0.017160816118121147, + -0.035110652446746826, + -0.00976842176169157, + 0.026605995371937752, + 0.004003277514129877, + 0.010927689261734486, + 0.002173327375203371, + -0.05133439600467682, + -0.04658171907067299, + 0.03023359179496765, + -0.015038624405860901, + 0.016580749303102493, + 0.02393144741654396, + 0.004817661829292774, + -0.008468102663755417, + 0.017239807173609734, + 0.019924553111195564, + 0.02557404898107052, + 0.01985766738653183, + -0.01881517469882965, + -0.14637643098831177, + -0.005403783638030291, + -0.013156545348465443, + -0.3882855176925659, + 0.01537711638957262, + 0.005061861593276262, + 0.018044542521238327, + 0.00010373388067819178, + -0.01769324019551277, + -0.020439250394701958, + 0.01761222817003727, + 0.017716309055685997, + -0.01828574948012829, + 0.0059916484169662, + 0.006117791403084993, + -0.0025541253853589296, + 0.01598154753446579, + 0.0015296537894755602, + 0.006711189169436693, + -0.005831963382661343, + 0.024547481909394264, + 0.011665170080959797, + 0.013990279287099838, + -0.009193074889481068, + -0.0014407691778615117, + 0.0025373499374836683, + -0.001535113900899887, + 0.022016262635588646, + 0.002165747107937932, + -0.00010288839985150844, + -0.01185672264546156, + 0.3959958255290985, + -0.06701132655143738, + 0.024550342932343483, + -0.007259713020175695, + 0.00011224728223169222, + 0.08959072828292847, + 0.006745494436472654, + -0.007461291737854481, + -0.0010788652580231428, + -0.003997487016022205, + 0.0023250498343259096, + 0.005845727398991585, + 0.002441686810925603, + 0.0010628585005179048, + 0.004687050357460976, + 0.03825820982456207, + 0.0027951127849519253, + 0.004356732591986656, + 0.0036379920784384012, + -0.00048690394032746553, + -0.31681910157203674, + 0.01621195860207081, + 0.009373913519084454, + -0.005099120549857616, + 0.004866141825914383, + 0.008112045004963875, + -0.009933174587786198, + -0.006929770577698946, + 0.005561198107898235, + -0.2225065976381302, + -0.00019208311277907342, + -0.003284667618572712, + 0.010527989827096462, + -0.010160842910408974, + -0.008410060778260231, + 0.004605174530297518, + 0.01542133092880249, + 0.013958578929305077, + 0.0021779180970042944, + 0.002810562262311578, + 0.001369283301755786, + -0.0003347232413943857, + 0.013902815990149975, + -0.0022218015510588884, + 0.00024955783737823367, + -0.0019350153161212802, + 0.0025213193148374557, + -0.0054915109649300575, + -0.00011564489977899939, + -0.0037644850090146065, + -0.002863431815057993, + -0.0025196163915097713, + 0.02352992817759514, + 0.00354134407825768, + -0.010700036771595478, + -0.03428381308913231, + 0.008170859888195992, + 0.005420713219791651, + -0.0013479178305715322, + 0.0015741022070869803, + -0.18286381661891937, + 0.03189067915081978, + 0.0014371845172718167, + -4.885893940809183e-05, + -0.004666821099817753, + -0.026595929637551308, + -0.0064376350492239, + 0.01583540253341198, + -0.085715651512146, + -0.00916224904358387, + -0.3605174124240875, + 0.019973354414105415, + 0.05533794313669205, + 0.053907446563243866, + 0.030877795070409775, + -0.919844925403595, + 8.968543988885358e-05, + -0.02068270742893219, + 0.012602192349731922, + 0.03245612978935242, + 0.06622699648141861, + 0.00882122665643692, + -0.03616628423333168, + -0.02428283728659153, + 0.003318701172247529, + -0.0007259293342940509, + -0.026197656989097595, + -0.059503961354494095, + 0.029495801776647568, + -0.006955073680728674, + -0.01926456019282341, + 0.009927013888955116, + 0.059641581028699875, + 0.0016886347439140081, + -0.029346982017159462, + 0.01948450319468975, + -0.04397860914468765, + 0.025248751044273376, + 0.04597266763448715, + 0.009454794228076935, + -0.018872544169425964, + -0.039650529623031616, + 0.026324709877371788, + -0.01808176562190056, + 0.028935831040143967, + 0.009501701220870018, + -0.05183069407939911, + -0.005787428934127092, + -0.021436212584376335, + 0.029735956341028214, + 0.0350160151720047, + 0.033825185149908066, + 0.03185566887259483, + 0.018431033939123154, + 0.02450188808143139, + 0.03271135315299034, + -0.0027792940381914377, + -0.0004625302099157125, + 0.01268392987549305, + 0.045023106038570404, + 0.05562014505267143, + 0.029052015393972397, + -0.002513203304260969, + -0.08349838852882385, + 7.017837560852058e-06, + -0.0014392733573913574, + 0.016982918605208397, + 0.016358936205506325, + -0.024013325572013855, + -0.004375616554170847, + -0.03734249249100685, + 0.04336351156234741, + 0.07323610782623291, + -0.0243068914860487, + 0.009403819218277931, + 0.02663031965494156, + 0.01930687017738819, + 0.02175578847527504, + 0.01639295555651188, + 0.024892140179872513, + 0.031219134107232094, + 0.02986173704266548, + -0.002100786194205284, + 0.05054357647895813, + 0.04015854373574257, + 0.0048207067884504795, + -0.03244275599718094, + 0.027246609330177307, + 0.00409608893096447, + -0.0054193479008972645, + 0.07014931738376617, + 0.009954879060387611, + 0.022472694516181946, + -0.47738370299339294, + -0.019097158685326576, + 0.028984038159251213, + -0.042564358562231064, + -0.006040808744728565, + 0.04094231128692627, + -0.007740774191915989, + -0.07854597270488739, + 0.003920051269233227, + -0.050799619406461716, + 0.023691626265645027, + 0.019952887669205666, + 0.00716764759272337, + -0.0046928380616009235, + 0.00041822553612291813, + 0.006359069608151913, + 0.017860781401395798, + -0.22999149560928345, + -0.02180831879377365, + -0.024055887013673782, + -0.0226126741617918, + -0.01795077696442604, + 0.015591473318636417, + -0.004053472075611353, + 0.016760380938649178, + 0.03378744795918465, + -0.0027090508956462145, + 0.00999806821346283, + 0.019252799451351166, + 0.0027550198137760162, + 0.03454355522990227, + -0.0295003242790699, + -0.007663591764867306, + 0.061172280460596085, + 0.049142658710479736, + -0.00858291145414114, + -0.0035321018658578396, + -0.7689260244369507, + 0.0004916944890283048, + 0.02915046364068985, + 0.017000442370772362, + -0.003298018593341112, + -0.0405484102666378, + 0.021160880103707314, + 0.0013289587805047631, + -0.07510386407375336, + 0.03890690207481384, + 0.03729970380663872, + -0.04906352981925011, + -0.10020274668931961, + 0.01506283599883318, + -0.053726132959127426, + 0.016631007194519043, + 0.03425036743283272, + 0.03358260169625282, + -0.023937245830893517, + -0.13656578958034515, + -0.13947314023971558, + 0.012915699742734432, + 0.02431132085621357, + -0.03089652583003044, + 0.1382707953453064, + 0.056695129722356796, + -0.09263960272073746, + 0.10406216233968735, + 0.02619105577468872, + -0.01678614132106304, + -0.16045455634593964, + 8.974489173851907e-05, + -0.03521093726158142, + -0.028908027336001396, + 0.21234789490699768, + -0.02046572044491768, + -0.09703273326158524, + 0.05248226970434189, + 0.011973158456385136, + 0.004557646345347166, + -0.018632734194397926, + -0.1649131029844284, + -0.00682018743827939, + -0.12712189555168152, + 0.10513507574796677, + 0.020745709538459778, + 0.02996259182691574, + -0.15409024059772491, + -0.08719073981046677, + -0.14634187519550323, + -0.16255779564380646, + -0.15963757038116455, + -0.1324772834777832, + -0.022830091416835785, + -0.06426219642162323, + -0.025459224358201027, + 0.00281702633947134, + 0.03255268186330795, + -0.05778049677610397, + -0.30381152033805847, + -0.06582051515579224, + -0.033722274005413055, + 0.014956191182136536, + 0.004153797868639231, + 0.2391217201948166, + -0.0311420951038599, + 0.001518488978035748, + 0.019769812002778053, + -0.056324463337659836, + -0.006009253207594156, + -0.21367721259593964, + -0.0481688529253006, + 0.22422266006469727, + 0.0402204655110836, + 0.1432792693376541, + 0.14159953594207764, + -0.0025862890761345625, + -0.028965365141630173, + 0.011978867463767529, + 0.161293163895607, + 0.028642605990171432, + -0.008417634293437004, + -0.10145614296197891, + 0.08381767570972443, + 0.05199432373046875, + 0.18680602312088013, + -0.023287687450647354, + 0.03601476550102234, + 0.03738229721784592, + 0.19291405379772186, + 0.03553088754415512, + 0.05483124405145645, + 0.09577616304159164, + -0.004635817836970091, + 0.052481625229120255, + -0.042084019631147385, + -0.2629147469997406, + -0.006157668773084879, + -0.0401761569082737, + 0.02154349908232689, + -0.056558139622211456, + -0.003753019031137228, + 0.01922912523150444, + 0.1291409730911255, + -0.21358416974544525, + 0.004696246236562729, + 0.13787509500980377, + -0.07022479176521301, + -0.06828727573156357, + 0.09193858504295349, + -0.06863763928413391, + -0.05677935853600502, + -0.030970478430390358, + -0.10181070864200592, + -0.1247706487774849, + 0.014181962236762047, + -0.09259836375713348, + -0.03174220770597458, + -0.014812505804002285, + -0.024658311158418655, + -0.04815720021724701, + -0.01683010160923004, + 0.015726473182439804, + 0.002938281511887908, + -0.1586887538433075, + -0.29276973009109497, + -0.029981529340147972, + -0.046828676015138626, + -0.04909103736281395, + 0.06043976545333862, + 0.03698069602251053, + -0.04807118698954582, + 0.0943484902381897, + 0.01930702105164528, + 0.06498143821954727, + 0.0381690077483654, + -0.19611406326293945, + 0.006944946013391018, + 0.06454038619995117, + -0.19779883325099945, + 0.04966692253947258, + 0.046355295926332474, + 0.0590626522898674, + -0.24392037093639374, + -0.0018132536206394434, + 0.010944955050945282, + -0.014556891284883022, + 0.051466893404722214, + -0.0059846509248018265, + -0.06719732284545898, + 0.030604040250182152, + 0.051190104335546494, + -0.053196243941783905, + -0.06912374496459961, + -0.06263922154903412, + 0.05626852437853813, + 0.013047950342297554, + -0.005828890949487686, + 0.056055404245853424, + 0.007044378202408552, + 0.030499491840600967, + -0.035373322665691376, + 0.030934391543269157, + 0.04358363524079323, + 0.001537138712592423, + 0.005963161122053862, + -0.005889860913157463, + 0.053225863724946976, + 0.052091702818870544, + -0.02871675044298172, + 0.05662619322538376, + -0.4585985839366913, + 0.06490323692560196, + 0.02542230300605297, + 0.017592567950487137, + 0.05066920816898346, + -0.20954127609729767, + -0.06689731031656265, + -0.3632309138774872, + -0.03407476842403412, + 0.04976007342338562, + 0.03856723755598068, + 0.009329214692115784, + -0.10107281804084778, + 0.007077769376337528, + -0.005482642911374569, + 0.04388934373855591, + 0.03984231874346733, + 0.005358297843486071, + 0.05032944679260254, + 0.007170544005930424, + 0.017318176105618477, + -0.03577208146452904, + -0.02195456624031067, + 0.014414021745324135, + -0.008203372359275818, + 0.04585091397166252, + -0.012298643589019775, + 0.03959968313574791, + -0.06015963852405548, + -0.1360240876674652, + -0.07704123109579086, + -0.0842466950416565, + -0.11261942237615585, + 0.0433686338365078, + -0.1059969812631607, + 0.014813154004514217, + 0.04216694459319115, + 0.10441470146179199, + 0.04579426348209381, + 0.026033954694867134, + 0.08725529909133911, + -0.14662955701351166, + -0.0726592168211937, + 0.1293957382440567, + 0.013497715815901756, + -0.01318936888128519, + -0.05188713222742081, + 0.08793413639068604, + 0.1094818189740181, + 0.07991892844438553, + 0.03549068048596382, + -0.04469897970557213, + -0.10442564636468887, + 0.13456915318965912, + 0.01154977548867464, + -0.05959299951791763, + 0.01768219843506813, + 0.0179652888327837, + -0.010112428106367588, + 0.020603090524673462, + -0.7144030928611755, + 0.20126283168792725, + 0.058172807097435, + -0.10543914139270782, + 0.07461538910865784, + -0.1744592934846878, + 0.055722273886203766, + -0.046595826745033264, + 0.06237049773335457, + 0.05800141766667366, + 0.04118870943784714, + 0.002582935383543372, + 0.010623090900480747, + -0.0439014658331871, + 0.044685740023851395, + -0.017063472419977188, + -0.0173367727547884, + -0.04761765897274017, + 0.06136244907975197, + 0.08495236933231354, + 0.24923592805862427, + -0.061080869287252426, + 0.15922360122203827, + -0.09322690963745117, + -0.09617402404546738, + 0.0029533954802900553, + 0.12630371749401093, + 0.0011397749185562134, + 0.0005059551913291216, + -0.060922350734472275, + -0.16446451842784882, + 0.057099178433418274, + 0.03073902614414692, + -0.031064951792359352, + 0.012277435511350632, + 0.020447896793484688, + 0.06010727211833, + 0.07065457105636597, + 0.026963504031300545, + 0.010798406787216663, + -0.02631279267370701, + 0.02046871930360794, + -0.004800989292562008, + -0.03282550349831581, + 0.053904879838228226, + -0.03294985368847847, + -0.4204113185405731, + 0.028552187606692314, + 0.023685462772846222, + 0.0017703581834211946, + 0.02868991158902645, + -0.3585520088672638, + -0.011516556143760681, + -0.00248165475204587, + 0.011379038915038109, + 0.0459531806409359, + 0.015357235446572304, + 0.05573337897658348, + 0.06516549736261368, + 0.02981666848063469, + 0.05498211458325386, + 0.028714550659060478, + -0.005899528972804546, + 0.008476868271827698, + 0.11328839510679245, + 0.020578190684318542, + -0.15382742881774902, + 0.015724696218967438, + -0.08402770012617111, + 0.060314107686281204, + 0.032343748956918716, + 0.014438764192163944, + -0.13614842295646667, + -0.0017508765449747443, + 0.09998518973588943, + -0.06364594399929047, + 0.049632295966148376, + -0.11922458559274673, + -0.08834195137023926, + 0.019541991874575615, + 0.06320779770612717, + 0.017419861629605293, + -0.0028468866366893053, + -0.14753428101539612, + 0.02623703144490719, + -0.011462770402431488, + 0.06676206737756729, + -0.014891563914716244, + -0.002118025440722704, + 0.02519390918314457, + -0.29581141471862793, + 0.0264339130371809, + 0.04027356952428818, + 0.00412194337695837, + 0.03778498247265816, + -0.012331741861999035, + 0.15336745977401733, + -0.034510836005210876, + 0.0319819413125515, + 0.01916184462606907, + 0.04952343553304672, + -0.026733938604593277, + -0.014996573328971863, + 0.0010714810341596603, + 0.01959756202995777, + -0.0392388179898262, + -0.0052064210176467896, + -0.05015777423977852, + -0.0002977418771479279, + -0.04029487073421478, + -0.012846150435507298, + -0.09198840707540512, + 0.0118671590462327, + -0.06176264211535454, + 0.006427878048270941, + 0.04043034091591835, + -0.017270859330892563, + -0.012422707863152027, + 0.01713552325963974, + -0.026697810739278793, + 0.2446632832288742, + -0.020500628277659416, + -0.0012782106641680002, + -0.13429665565490723, + 0.07528743892908096, + -0.002225265372544527, + 0.06695574522018433, + 0.0017388156848028302, + -0.0629071593284607, + -0.05081196129322052, + 0.042025983333587646, + 0.029097404330968857, + 0.07048555463552475, + -0.11881273239850998, + 0.012633765116333961, + -0.06181430071592331, + 0.038810230791568756, + 0.05186169967055321, + 0.03248963877558708, + 0.07868267595767975, + 0.024977494031190872, + 0.023991582915186882, + 0.0023529180325567722, + 0.07197123020887375, + 0.02653665468096733, + 0.058702051639556885, + 0.015001803636550903, + 0.043739400804042816, + -0.07251746207475662, + 0.045659150928258896, + -0.02111324854195118, + 0.26666632294654846, + 0.1975221484899521, + -0.031074335798621178, + 0.029075143858790398, + 0.013020229525864124, + 0.015244663693010807, + 0.01387549377977848, + -0.025354426354169846, + 0.06151636317372322, + -0.034430794417858124, + 0.00752665288746357, + 0.1678706705570221, + -0.016560610383749008, + 0.0421285480260849, + -0.02527586743235588, + -0.02166694961488247, + -0.034658536314964294, + 0.036866605281829834, + -0.036233626306056976, + 0.02042747661471367, + 0.028099242597818375, + 0.020503878593444824, + 0.022789381444454193, + 0.08666791766881943, + -0.06426636874675751, + -0.043599683791399, + 0.1136128157377243, + 0.020200412720441818, + -0.003839759388938546, + -0.06010120362043381, + -0.02218424715101719, + 0.09008956700563431, + 0.008711264468729496, + -0.04874516651034355, + -0.011533043347299099, + -0.036206502467393875, + -0.006006627343595028, + -0.0350450798869133, + 0.005623341538012028, + 0.09562186151742935, + -0.03952183946967125, + -0.013931595720350742, + -0.020029470324516296, + 0.0022144403774291277, + -0.020198611542582512, + 0.012238736264407635, + 0.054415784776210785, + -0.024457741528749466, + -0.01174110360443592, + 0.031656913459300995, + 0.060322560369968414, + 0.01573050767183304, + 0.03361794352531433, + 0.022875478491187096, + 0.036340806633234024, + -0.02932620421051979, + 0.0224352665245533, + -0.013475337065756321, + -0.030774995684623718, + 0.013921404257416725, + -0.01229875348508358, + -0.07986237108707428, + -0.007543445099145174, + 0.05208213999867439, + -0.04440496116876602, + -0.029659371823072433, + -0.029070377349853516, + 0.07376870512962341, + -0.07208643853664398, + -0.05429431423544884, + -0.007887271232903004, + 0.011400371789932251, + 0.014227204024791718, + 0.01763899251818657, + -0.0426466204226017, + 0.0024213625583797693, + 0.02564665488898754, + 0.0020850151777267456, + 0.027386819943785667, + 0.12722602486610413, + -0.060991525650024414, + -0.009061425924301147, + 0.014208497479557991, + -0.006956137716770172, + 0.09096626192331314, + 0.0037735258229076862, + -0.8347064852714539, + -0.2857951521873474, + 0.0011818337952718139, + 0.0341162234544754, + -0.04230167716741562, + 0.05230262130498886, + 0.08486262708902359, + -0.34235459566116333, + -0.02393503487110138, + 0.02718495950102806, + 0.050966840237379074, + 0.024611525237560272, + -0.004936584271490574, + -0.036420952528715134, + -0.009803534485399723, + 0.05421328917145729, + 0.008357672952115536, + 0.020987343043088913, + -0.007292840629816055, + 0.018060531467199326, + 0.06739793717861176, + 0.06161382421851158, + 0.000842935056425631, + -0.007857701741158962, + 0.023870037868618965, + -0.009690430946648121, + -0.04231289029121399, + -0.22531479597091675, + 0.034284885972738266, + 0.07360551506280899, + 0.0421777106821537, + 0.000788167177233845, + -0.3953339457511902, + -0.042627450078725815, + -0.02774403616786003, + 0.02647743932902813, + -0.01561375055462122, + 0.04745408892631531, + 0.021774733439087868, + 0.006606150884181261, + 0.03879173845052719, + 0.06500626355409622, + 0.044954728335142136, + 0.01523532159626484, + 0.04741065576672554, + -0.13645507395267487, + 0.0038059696089476347, + -0.012993253767490387, + -0.004529603291302919, + 0.03268986567854881, + -0.025349941104650497, + -0.02268051542341709, + -0.0001516443444415927, + -0.010289257392287254, + -0.0010476588504388928, + -0.0690254345536232, + 0.04298266023397446, + -0.05470968782901764, + 0.04369102790951729, + -0.007372597698122263, + 0.027607066556811333, + 0.0009343988494947553, + -0.09573916345834732, + 0.04389296472072601, + -0.01522558368742466, + -0.03138086944818497, + 0.04511113464832306, + -0.0342172235250473, + -0.00033129166695289314, + -0.037289440631866455, + 0.055575959384441376, + 0.01849759928882122, + 0.03041103295981884, + -0.01965116336941719, + 0.07604960352182388, + -0.0399625338613987, + -0.008190250024199486, + -0.015386211685836315, + -0.04315667226910591, + 0.0023679479490965605, + 0.018971435725688934, + -0.005599244497716427, + -0.029607947915792465, + 0.07574024051427841, + -0.013816094025969505, + 0.04464992880821228, + 0.00032806122908368707, + 0.06071484833955765, + 0.04261377081274986, + 0.012208743952214718, + 0.0801805928349495, + 0.02875029854476452, + -0.0662921741604805, + 0.015754999592900276, + 0.05831082537770271, + 0.03810921683907509, + 0.05483977496623993, + -0.019509335979819298, + 0.0032034649048000574, + 0.011807492934167385, + -0.01916244812309742, + 0.022101666778326035, + -0.0366031751036644, + 0.10915965586900711, + 0.030322788283228874, + -0.028386037796735764, + -0.05443429946899414, + -0.02489445172250271, + 0.0892239362001419, + -0.05427740886807442, + -0.034238025546073914, + -0.04136161506175995, + -0.041148390620946884, + 0.06879492849111557, + -0.37424594163894653, + 0.028803903609514236, + 0.05349116027355194, + 0.0359492301940918, + -0.3629145622253418, + -0.17875684797763824, + -0.012246759608387947, + 0.2744927704334259, + -0.010421697050333023, + -0.19415415823459625, + 0.005668101832270622, + 0.018326066434383392, + 0.28319111466407776, + -0.008164885453879833, + -0.07401272654533386, + -0.04154321923851967, + 0.030028337612748146, + -0.008959534578025341, + -0.03160349279642105, + -0.0191870778799057, + 0.044875819236040115, + 0.052173007279634476, + 0.012135458178818226, + 0.008775291964411736, + 0.005302258301526308, + 0.009224606677889824, + -0.07574712485074997, + 0.06096252053976059, + 0.02645082212984562, + 0.05135556682944298, + 0.021985528990626335, + 0.0076704383827745914, + 0.02961125783622265, + -0.07608609646558762, + -0.17564956843852997, + 0.03679918497800827, + -0.2696506083011627, + 0.0627906322479248, + 0.031165480613708496, + 0.01799822598695755, + 0.02351829782128334, + 0.015595306642353535, + -0.25137314200401306, + -0.011266927234828472, + 0.04895596578717232, + 0.01718883402645588, + 0.0009224268142133951, + 0.021923478692770004, + 0.044791676104068756, + 0.079147569835186, + 0.02014082670211792, + -0.0003547854721546173, + -0.02535748854279518, + -0.029639363288879395, + -0.01965961419045925, + -0.37630724906921387, + 0.01674639992415905, + 0.01316642016172409, + -0.025120021775364876, + -0.12474260479211807, + 0.059980470687150955, + 0.036066047847270966, + -0.15973420441150665, + -0.010871605016291142, + 0.014708316884934902, + -0.2174367904663086, + 0.012985467910766602, + -0.03782057762145996, + -0.003427069401368499, + -0.011010636575520039, + 0.02433733455836773, + 0.08641276508569717, + -0.004630533047020435, + 0.019430357962846756, + -0.02088969387114048, + -0.06182911619544029, + 0.02577812969684601, + 0.015741532668471336, + 0.04723552614450455, + -0.003783567575737834, + 0.11646346747875214, + 0.01827184483408928, + -0.0999741181731224, + -0.0031216999050229788, + -0.002268272452056408, + -0.019456079229712486, + -0.003156653605401516, + 0.0067732855677604675, + 0.027299508452415466, + 0.06979037076234818, + 0.013329057022929192, + -0.016705401241779327, + 0.33774301409721375, + 0.007617524825036526, + 0.044453222304582596, + 0.0016282782889902592, + 0.0010982973035424948, + 0.04183036834001541, + 0.016857653856277466, + 0.006673034280538559, + -0.0187662523239851, + 0.0037163379602134228, + -0.04568779841065407, + -0.007807960733771324, + 0.016653010621666908, + 0.0033014933578670025, + 0.015063234604895115, + 0.012843966484069824, + -0.012042546644806862, + 0.016909126192331314, + 0.022089935839176178, + -0.002550398698076606, + 0.04166745766997337, + -0.0014742743223905563, + -0.010846617631614208, + -0.12333541363477707, + 0.0018612967105582356, + 0.04913188889622688, + -0.029431112110614777, + 0.01824735291302204, + 0.10425490140914917, + -0.08880072832107544, + 0.03029320202767849, + 0.018876856192946434, + 0.016104502603411674, + 0.00882721971720457, + 0.0029782119672745466, + 0.007922517135739326, + -0.02030068263411522, + -0.029835309833288193, + 0.006661414168775082, + -0.04313879832625389, + -0.001850730157457292, + -0.0035070034209638834, + -0.0070700813084840775, + 0.009637435898184776, + -0.016844747588038445, + -0.026075454428792, + 0.0030682040378451347, + 0.004208600614219904, + -0.005515689495950937, + -0.018976539373397827, + -0.019196776673197746, + -0.008948019705712795, + 0.016215825453400612, + 0.00296461652033031, + 0.14222395420074463, + -0.029066482558846474, + -0.011013337410986423, + -0.01267730537801981, + -0.004976287949830294, + -0.016607511788606644, + -0.0005681798211298883, + -0.012520174495875835, + -0.0015903630992397666, + -0.0013642794219776988, + -0.21956196427345276, + -0.0011431180173531175, + -0.0008808697457425296, + -0.022889399901032448, + 0.024718068540096283, + -0.054929111152887344, + -0.015585094690322876, + -0.018188318237662315, + -0.0008287815726362169, + -0.01957552134990692, + 0.10818513482809067, + -0.0034382494632154703, + -0.02667389065027237, + -0.01304248720407486, + -0.0034645304549485445, + -0.008519704453647137, + -0.015123830176889896, + -0.008219013921916485, + -0.009952309541404247, + -2.3375787350232713e-05, + -0.012512428686022758, + -0.001955948770046234, + -0.0029842876829206944, + -0.004291659686714411, + 0.006655955221503973, + 0.007771315053105354, + 0.014132227748632431, + -0.007390063256025314, + -0.024650415405631065, + -0.022503213956952095, + 0.0032607221510261297, + -0.008497492410242558, + 0.00860870536416769, + 0.002819088753312826, + -0.01841069757938385, + -0.010009711608290672, + -0.2912862300872803, + 0.017160022631287575, + 0.11349690705537796, + -0.027656083926558495, + -0.04482223838567734, + -0.019336597993969917, + 0.07413014769554138, + 0.014554106630384922, + 0.020965611562132835, + -0.028231356292963028, + -0.0582813061773777, + 0.05617539584636688, + -0.05042734369635582, + 0.025630727410316467, + -0.0956532284617424, + -0.14554104208946228, + -0.020851148292422295, + 0.006990485824644566, + 0.08457829803228378, + -0.11314752697944641, + 0.004020951222628355, + -0.03477870300412178, + 0.005594289395958185, + 0.011181964538991451, + 0.010988114401698112, + 0.019416088238358498, + 0.026451971381902695, + -0.00452260859310627, + 0.0004952011513523757, + 0.012377702631056309, + -0.0063480171374976635, + 0.0256175734102726, + -0.020753338932991028, + 0.03223377838730812, + -0.1147943064570427, + -0.009170151315629482, + 0.015267477370798588, + -0.0009072314132936299, + -0.1621374636888504, + 0.022807778790593147, + 0.007394107989966869, + 0.01378557924181223, + -0.10719677805900574, + -0.000919080339372158, + -0.006567052565515041, + -0.007409179583191872, + -0.007469762582331896, + -0.004784661345183849, + -0.03967805579304695, + 0.015857066959142685, + -0.02015744335949421, + 0.056037548929452896, + 0.03962035849690437, + 0.08429893851280212, + 0.022117067128419876, + -0.2675061821937561, + 0.016738418489694595, + 0.0037785861641168594, + 0.004771686624735594, + -0.134505033493042, + -0.010618447326123714, + -0.004784524440765381, + 0.014044507406651974, + -0.03105556219816208, + 0.05049083009362221, + 0.012162688188254833, + 0.005920265335589647, + 0.008554516360163689, + 0.0025892227422446012, + 0.023483717814087868, + -0.20711173117160797, + 0.03360452130436897, + -0.24758699536323547, + -0.05136318504810333, + -0.015016172081232071, + 0.06466241925954819, + 0.023470288142561913, + 0.023495715111494064, + 0.004300899337977171, + 0.02461574412882328, + 0.025745516642928123, + -0.026187308132648468, + 0.08441776037216187, + -0.06955462694168091, + -0.11116205900907516, + -0.2169608771800995, + -0.004244703333824873, + -0.024184226989746094, + -0.10068271309137344, + -0.021129190921783447, + -0.021129680797457695, + -0.0054467362351715565, + 0.17416934669017792, + 0.015367642976343632, + -0.01237915363162756, + 0.024573752656579018, + 0.004588739015161991, + 0.05616860091686249, + -0.0018992060795426369, + -0.12394066900014877, + -0.03691404312849045, + -0.15878455340862274, + 0.10572423785924911, + 0.014409378170967102, + -0.008566108532249928, + -0.20319701731204987, + -0.018277373164892197, + -0.21615462005138397, + -0.11269525438547134, + -0.2767113745212555, + -0.25617966055870056, + -0.0036413148045539856, + -0.008058675564825535, + -0.051732294261455536, + -0.013052727095782757, + 0.05229722708463669, + -0.03535814583301544, + 0.3111231327056885, + -0.044130608439445496, + -0.02232682704925537, + -0.0040402463637292385, + 0.013798556290566921, + -0.07689940929412842, + -0.028940049931406975, + -0.00565366679802537, + -0.028972560539841652, + -0.007728889584541321, + 0.013665011152625084, + -0.014678380452096462, + -0.06747694313526154, + -0.06480871140956879, + -0.00028885426581837237, + -0.01525174267590046, + 0.027096102014183998, + -0.05200905352830887, + 0.0066903820261359215, + 0.0023834225721657276, + -0.002379713812842965, + -0.0208051148802042, + 0.335977703332901, + 0.03895771875977516, + -0.04814215749502182, + -0.037339694797992706, + -0.004409746266901493, + 0.07042848318815231, + -0.08318590372800827, + -0.04138712212443352, + 0.06309781968593597, + 0.007484383415430784, + 0.09696535021066666, + 0.024134323000907898, + -0.009859816171228886, + -0.06243982911109924, + 0.04630015045404434, + -0.06593744456768036, + 0.009306293912231922, + 0.5033899545669556, + 0.007804783061146736, + 0.024170484393835068, + -0.036085959523916245, + 0.016438491642475128, + 0.01678072288632393, + -0.006299734115600586, + -0.027441656216979027, + -0.014344800263643265, + 0.022293711081147194, + 0.011197407729923725, + -0.0026971842162311077, + 0.2685070335865021, + 0.01403988990932703, + -0.005100077483803034, + -0.026031343266367912, + -0.005419034510850906, + -0.014735087752342224, + -0.0283498577773571, + 0.002656748052686453, + -0.07137783616781235, + 0.02235356532037258, + -0.02970476634800434, + 0.20672672986984253, + 0.017398398369550705, + 0.02438206970691681, + 0.025746773928403854, + -0.03279582038521767, + 0.043908532708883286, + -0.003417646512389183, + 0.020200302824378014, + 0.007243862375617027, + -0.004560714587569237, + -0.01142876222729683, + -0.028091270476579666, + -0.2949703335762024, + 0.0729827880859375, + 0.004566277377307415, + 0.16689160466194153, + 0.034872010350227356, + -0.09590360522270203, + -0.13309867680072784, + 0.06429398059844971, + 0.04174232855439186, + -0.022723963484168053, + -0.04695400968194008, + 0.013115685433149338, + 0.013574879616498947, + 0.04794493317604065, + -0.015077140182256699, + 0.09493618458509445, + 0.008845972828567028, + 0.020302923396229744, + 0.02037016488611698, + 0.009083293378353119, + 0.0747746080160141, + -0.008078188635408878, + 0.024796344339847565, + -0.015212535858154297, + -0.005867444910109043, + 0.08309170603752136, + 0.03676094114780426, + 0.07232356816530228, + -0.3577176630496979, + 0.0013658110983669758, + -0.0009247250854969025, + 0.02284996211528778, + 0.012630275450646877, + 0.013745593838393688, + 0.003447894938290119, + 0.03563565015792847, + -0.031025355681777, + -0.07258180528879166, + -0.13482442498207092, + -0.029425248503684998, + -0.014927731826901436, + 0.045984312891960144, + -0.0176406130194664, + -0.22678181529045105, + -0.025248311460018158, + -0.11617762595415115, + -0.056157518178224564, + 0.009453062899410725, + -0.34616726636886597, + 0.05691010504961014, + -0.32302799820899963, + -0.026544231921434402, + -0.007374088745564222, + -0.07682909071445465, + -0.021214107051491737, + -0.07102422416210175, + 0.02693488635122776, + 0.014817211776971817, + 0.015572831965982914, + 0.04313618317246437, + -0.1277216374874115, + 0.02174532599747181, + -0.0226149819791317, + -0.00010956164624076337, + 0.023728065192699432, + 0.008212783373892307, + 0.010561724193394184, + -0.011036543175578117, + -0.022485855966806412, + 0.008243439719080925, + -0.03383245691657066, + -0.5630682110786438, + 0.0015974265988916159, + -0.28416821360588074, + 0.04123701527714729, + -0.0042976438999176025, + 0.03786511346697807, + 0.01862393692135811, + -0.04082413762807846, + -0.05792848393321037, + 0.0068894242867827415, + 0.0024085959885269403, + 0.001471342402510345, + 0.030681759119033813, + -0.026314062997698784, + 0.0555737242102623, + 0.03169534355401993, + 0.0031395808327943087, + 0.018701769411563873, + -0.5604594945907593, + 0.01526441890746355, + -0.00621993700042367, + 0.0009401043644174933, + 0.01587403193116188, + 0.030135583132505417, + -0.007350685074925423, + 0.006527469493448734, + 0.016000108793377876, + -0.042957425117492676, + 0.018247080966830254, + 0.0025622656103223562, + -0.03169511258602142, + 0.09235119074583054, + -0.013365034945309162, + 0.01607452519237995, + 0.017734844237565994, + 0.05609896034002304, + 0.04819876700639725, + -0.0871855691075325, + 0.05157865956425667, + 0.009171447716653347, + 0.022200705483555794, + -0.005507844965904951, + -0.024452703073620796, + 0.010224574245512486, + -0.006914906669408083, + 0.004650818649679422, + 0.02167516015470028, + 0.10456826537847519, + -0.07652094960212708, + -6.050072988728061e-05, + 0.012855490669608116, + 0.022669879719614983, + 0.022655120119452477, + 0.033012885600328445, + 0.025709744542837143, + 0.00481270719319582, + 0.005920717027038336, + -0.08545156568288803, + -0.004363589454442263, + -0.01531639602035284, + 0.030760569497942924, + 0.02796284481883049, + -0.03690989315509796, + 0.044959694147109985, + -0.14276015758514404, + -0.0002254673163406551, + -0.15694372355937958, + 0.012381293810904026, + -0.021977441385388374, + 0.005496624857187271, + -0.035593707114458084, + -0.0950438603758812, + 0.03825876861810684, + 0.05915532633662224, + -0.023323312401771545, + 0.017213119193911552, + -0.03807183355093002, + 0.02619507722556591, + 0.02741156332194805, + 0.005847832188010216, + 0.0020307491067796946, + 0.025714349001646042, + -0.04780200496315956, + 0.010206928476691246, + -0.01345440000295639, + 0.029133174568414688, + -0.0014764482621103525, + 0.004046705551445484, + -0.007725241594016552, + 0.013041527941823006, + 0.0018969239899888635, + 0.002417983952909708, + -0.010975837707519531, + 0.0015862436266615987, + 0.00597577728331089, + 0.002882696921005845, + 0.02855525352060795, + -0.005954153370112181, + 0.04090835899114609, + -0.39500924944877625, + 0.03586621209979057, + -0.5250031352043152, + -0.05697731301188469, + -0.09568691998720169, + -0.07179264724254608, + 0.04683076590299606, + 0.009320023469626904, + -0.11629963666200638, + -0.0016945215174928308, + 0.01624997705221176, + -0.0063682254403829575, + 0.15033549070358276, + -0.5171176791191101, + -0.01525783073157072, + 0.016417231410741806, + -0.00303818890824914, + 0.2500321865081787, + 0.022074062377214432, + 0.01191191840916872, + 0.012274803593754768, + 0.016534989699721336, + -0.028437916189432144, + 0.04241323843598366, + -0.01824999786913395, + -0.34815871715545654, + 0.04734490439295769, + -0.06419701874256134, + -0.022288290783762932, + -0.0004865761147812009, + 0.05369419604539871, + -0.058212973177433014, + -0.2196469008922577, + 0.010950890369713306, + 0.029042819514870644, + -0.07349151372909546, + -0.0422789566218853, + 0.062069639563560486, + 0.05589267984032631, + 0.014877256006002426, + 0.04236084595322609, + 0.03975239768624306, + 0.16930873692035675, + 0.03981085494160652, + 0.11499395221471786, + 0.0271450225263834, + 0.013969083316624165, + -0.0002660648606251925, + 0.010936664417386055, + -0.18389767408370972, + -0.10237602889537811, + 0.03041323646903038, + -0.013864071108400822, + -0.015729930251836777, + 0.037400804460048676, + -0.009598327800631523, + -0.09533312171697617, + -0.014712700620293617, + 0.08537333458662033, + -0.007200485561043024, + -0.31139102578163147, + -0.06366845220327377, + 0.02039063163101673, + -0.023356139659881592, + -0.0029549277387559414, + -0.12494662404060364, + 0.011755092069506645, + -0.26468148827552795, + -0.11541861295700073, + 0.010529865510761738, + -0.05965733155608177, + -0.05945499241352081, + -0.08796169608831406, + -0.014683439396321774, + 0.008732054382562637, + 0.010073489509522915, + 0.09553763270378113, + 0.034884922206401825, + 0.018675342202186584, + -0.009549405425786972, + -0.0007051719003356993, + -0.16936513781547546, + -0.0030460187699645758, + -0.022060535848140717, + -0.06689190864562988, + 0.013926704414188862, + 0.012043816037476063, + -0.0587068572640419, + -0.03814113140106201, + 0.06235629320144653, + 0.013228330761194229, + 0.04154474660754204, + -0.08039120584726334, + 0.028436705470085144, + -0.042226389050483704, + -0.019135186448693275, + 0.03747033327817917, + -0.14261123538017273, + 0.02827540971338749, + 0.0455685593187809, + -0.031124960631132126, + -0.007588588632643223, + 0.0034326373133808374, + -0.07682976871728897, + 0.24654042720794678, + -0.014518304727971554, + -0.07052458822727203, + -0.08241941034793854, + -0.04116151109337807, + -0.048463717103004456, + -0.038745298981666565, + 0.036902472376823425, + 0.0442035011947155, + 0.05572585016489029, + -0.014312628656625748, + 0.010794793255627155, + -0.3440641760826111, + -0.5161325335502625, + 0.0005156552069820464, + -0.010257269255816936, + -0.02412656880915165, + -0.023385023698210716, + 0.05533458665013313, + -0.012186119332909584, + -0.029286568984389305, + 0.04116401448845863, + -0.044610101729631424, + -0.019175484776496887, + 0.06835268437862396, + 0.06366674602031708, + 0.0373748242855072, + 0.03804386034607887, + 0.05369521677494049, + -0.04451881721615791, + 0.0018838117830455303, + 0.34775662422180176, + 0.010958605445921421, + -0.047990139573812485, + 0.04386777803301811, + -0.10427688807249069, + 0.04417382925748825, + 4.402965714689344e-05, + 0.01935163326561451, + -0.06753949075937271, + 0.02735923044383526, + 0.01465953141450882, + 0.06198301538825035, + -0.015980403870344162, + -0.2108263075351715, + 0.008177559822797775, + 0.006046924740076065, + 0.002665479900315404, + 0.20868580043315887, + -0.013740362599492073, + 0.008203004486858845, + -0.005066391546279192, + 0.026405498385429382, + 0.01383009273558855, + 0.012581533752381802, + 0.009014940820634365, + 0.022820021957159042, + -0.008534795604646206, + 0.2603924572467804, + 0.02297227643430233, + -0.000749691273085773, + 0.044753506779670715, + 0.018596511334180832, + 0.006852792575955391, + -0.008686172775924206, + -0.10452616959810257, + 0.017021872103214264, + 0.003722329391166568, + -0.025453045964241028, + -0.011473417282104492, + -0.017907623201608658, + 0.01400628499686718, + -0.1670989990234375, + 0.004298652987927198, + -0.0022204748820513487, + 0.16521315276622772, + -0.008831127546727657, + 0.026490870863199234, + 0.006190746556967497, + -0.0177209060639143, + 0.08967147767543793, + 0.0033069502096623182, + -0.005021366756409407, + 0.0004906906979158521, + 0.0169216375797987, + -0.06124846637248993, + -0.005200678016990423, + 0.08404737710952759, + -0.010559299029409885, + -0.006309974938631058, + 0.023113396018743515, + -0.010227260179817677, + 0.001256447983905673, + 0.019783375784754753, + -0.006308461539447308, + -0.04529590904712677, + -0.00908862054347992, + -0.043217338621616364, + -0.32200074195861816, + 0.02592635713517666, + 0.030795685946941376, + -0.001814531977288425, + 0.0092842485755682, + 0.07088880985975266, + -0.0867588147521019, + 0.024099843576550484, + -0.0034031609538942575, + 0.007234686985611916, + -0.02505563199520111, + 0.0030480287969112396, + -0.019158190116286278, + 0.26473408937454224, + -0.011918547563254833, + -0.023240016773343086, + -0.06084466353058815, + -0.021916134282946587, + -0.010251260362565517, + -0.0009625791572034359, + 0.082605741918087, + -0.013018425554037094, + 0.007627277635037899, + -0.0010813736589625478, + 0.007952406071126461, + 0.06551267951726913, + -0.026020025834441185, + 0.050048135221004486, + -0.010610008612275124, + -0.02429312653839588, + -0.025263017043471336, + -0.04611891135573387, + 0.04451768472790718, + -0.08045025914907455, + -0.048037610948085785, + 0.008019295521080494, + 0.0160224549472332, + 0.002078550634905696, + -0.0202508345246315, + -0.5446130633354187, + 0.012585492804646492, + -0.0331973135471344, + 0.08371605724096298, + -0.00590998912230134, + -0.013058983720839024, + 0.027742384001612663, + 0.1042199358344078, + -0.3072803318500519, + 0.06284149736166, + -0.28551968932151794, + 0.026768438518047333, + 0.022245990112423897, + 0.018242113292217255, + -0.035077981650829315, + 0.03546127676963806, + 0.10165776312351227, + -0.025475669652223587, + -0.014933750964701176, + 0.040547240525484085, + -0.033055808395147324, + 0.011755919083952904, + -0.014459444209933281, + -0.03455093130469322, + 0.020743343979120255, + 0.02720930427312851, + -0.287664532661438, + 0.008260028436779976, + -0.009877690114080906, + 0.16657423973083496, + -0.010943812318146229, + -0.012381386943161488, + 0.030678801238536835, + 0.1559792459011078, + 0.038967035710811615, + -0.023399239405989647, + 0.015019542537629604, + -0.014201333746314049, + -0.014202176593244076, + -0.006699408870190382, + -0.13175444304943085, + 0.004643211141228676, + 0.012747463770210743, + -0.04086190089583397, + 0.06581410765647888, + -0.12192045897245407, + -0.03126347437500954, + 0.011175516061484814, + -0.00914736744016409, + -0.02883930690586567, + -0.11305265873670578, + -0.04405384883284569, + -0.009120048955082893, + -0.008926079608500004, + -0.03169447183609009, + 0.05464877560734749, + 0.25674498081207275, + 0.08497058600187302, + -0.023222925141453743, + 0.35592252016067505, + -0.006929511670023203, + 0.025255810469388962, + -0.05150032415986061, + 0.039239466190338135, + -0.07082924991846085, + -0.017321549355983734, + 0.17293211817741394, + -0.02155853807926178, + -0.014333213679492474, + 0.0031305316369980574, + -0.013490653596818447, + -0.1376512199640274, + -0.021713266149163246, + -0.029826253652572632, + -0.0011473714839667082, + -0.012434332631528378, + -0.04860873892903328, + 0.013857590034604073, + 0.0703854188323021, + 0.034528713673353195, + -0.014423011802136898, + 0.0882454589009285, + -0.091700978577137, + 0.038885727524757385, + 0.012043441645801067, + -0.03183690831065178, + -0.014495689421892166, + -0.019726552069187164, + -0.010094117373228073, + -0.004218627233058214, + -0.04413086175918579, + -0.1344134360551834, + -0.0004976870259270072, + -0.0008357573533430696, + 0.04518067091703415, + 0.046797975897789, + 0.24766182899475098, + 0.01065139751881361, + -0.0034267394803464413, + -0.016103556379675865, + -0.05139121413230896, + 0.012563390657305717, + -0.03310413286089897, + -0.030157553032040596, + 0.046670909970998764, + 0.012565785087645054, + -0.040275491774082184, + 0.023816417902708054, + -0.38536572456359863, + 0.04508889466524124, + 0.13637560606002808, + -0.010654824785888195, + 0.0459851399064064, + -0.0046302699483931065, + -0.020852191373705864, + 0.10662271827459335, + 0.06486576050519943, + 0.05727925896644592, + 0.09816201776266098, + 0.04878557100892067, + -0.16256237030029297, + 0.014547038823366165, + 0.018567964434623718, + -0.07284612208604813, + 0.017150163650512695, + 0.0246741883456707, + -0.38470372557640076, + -0.07465949654579163, + 0.03010236658155918, + -0.004397575277835131, + -0.06618984788656235, + -0.02908281609416008, + 0.060166433453559875, + -0.0020949048921465874, + 0.007689109072089195, + -0.0047390698455274105, + -0.014199030585587025, + -0.01794746331870556, + -0.02528063952922821, + 0.002218312583863735, + 0.10169881582260132, + 0.010602130554616451, + -0.06605861335992813, + -0.0008762837387621403, + -0.035027723759412766, + -0.011684391647577286, + 0.02247578091919422, + 0.17245104908943176, + 0.22525252401828766, + -0.010771296918392181, + 0.05595310404896736, + 0.06338834017515182, + -0.0038216698449105024, + -0.0032836494501680136, + 0.005779017228633165, + -0.18020786345005035, + -0.05066698044538498, + -0.0035458216443657875, + -0.10578767210245132, + -0.041712939739227295, + 0.2104150652885437, + -0.03753345459699631, + 0.013989892788231373, + 0.01988149993121624, + 0.05108603090047836, + 0.04496738687157631, + -0.3034508526325226, + 0.0226743221282959, + -0.0431472510099411, + -0.025635428726673126, + -0.18961989879608154, + -0.17218825221061707, + 0.03576141223311424, + 0.060613714158535004, + -0.011970550753176212, + -0.21435107290744781, + 0.01422552578151226, + 0.02974064089357853, + -0.061079952865839005, + 0.031064646318554878, + 0.009629320353269577, + -0.13762925565242767, + 0.01928475871682167, + 0.007310172542929649, + 0.06103459745645523, + -0.16216528415679932, + 0.03330384939908981, + 0.09578404575586319, + -0.0037327276077121496, + 0.029233848676085472, + -0.0015759399393573403, + 0.005511409603059292, + -0.4195749759674072, + 0.024169376119971275, + 0.13220365345478058, + 0.007961929775774479, + 0.008045470342040062, + 0.01919495314359665, + -0.023188553750514984, + 0.07084394991397858, + -0.24922333657741547, + 0.02011212892830372, + -0.18514998257160187, + 0.03114209696650505, + 0.09826567023992538, + 0.00592303741723299, + -0.010020115412771702, + 0.027117054909467697, + -0.214133620262146, + -0.01214816514402628, + 0.06564164906740189, + 0.02513044886291027, + 0.02132420241832733, + -0.02127540111541748, + -0.041606876999139786, + 0.04196378216147423, + -0.02060609683394432, + 0.01730814389884472, + -0.17418994009494781, + 0.03462710976600647, + -0.017470642924308777, + -0.3992193639278412, + 0.02652592957019806, + 0.025042008608579636, + 0.026447610929608345, + -0.19199316203594208, + 3.27593952533789e-05, + 0.002988220192492008, + -0.21171888709068298, + 0.03300239518284798, + 0.015727035701274872, + -0.008947308175265789, + 0.03924538940191269, + -0.08990193158388138, + 0.023726975545287132, + 0.03463870286941528, + -0.05018220469355583, + 0.13170146942138672, + 0.054000236093997955, + 0.01158218178898096, + 0.062349993735551834, + -0.014724616892635822, + 0.039657603949308395, + 0.04436490684747696, + 0.014076294377446175, + 0.07666806876659393, + 0.09630247205495834, + -0.04152659326791763, + -0.1860806941986084, + -0.07671733945608139, + 0.031573690474033356, + -0.44617798924446106, + -0.004897239152342081, + -0.03991628438234329, + 0.01880800537765026, + -0.04769768565893173, + 0.02198435738682747, + 0.01341161783784628, + -0.12239313870668411, + 0.019765935838222504, + 0.005221452098339796, + -0.025201082229614258, + 0.005132562946528196, + 0.08668412268161774, + 0.0035341952461749315, + 0.008583099581301212, + 0.032979920506477356, + 0.03324040770530701, + 0.04411708936095238, + -0.008390798233449459, + 0.040486790239810944, + -0.059673551470041275, + 0.02003314346075058, + -0.0990666076540947, + 0.03971675783395767, + 0.012021057307720184, + 0.0017271327087655663, + 0.01818535290658474, + 0.0025106174871325493, + 0.043714240193367004, + 0.019146842882037163, + -0.0041794623248279095, + 0.033447377383708954, + 0.06863203644752502, + -0.004350902978330851, + 0.0113364327698946, + -0.05825724080204964, + -0.04649435728788376, + -0.10618306696414948, + 0.02653644233942032, + 0.012514552101492882, + 0.019399365410208702, + -0.0022177041973918676, + 0.017741208896040916, + 0.04115311801433563, + 0.05122101679444313, + 0.055051617324352264, + 0.01687677949666977, + -0.03698579967021942, + 0.10053858160972595, + -0.007528421934694052, + 0.003968802746385336, + 0.02458524890244007, + -0.02144794538617134, + 0.026791265234351158, + -0.016701897606253624, + 0.014119372703135014, + -0.03460531681776047, + -0.02320348098874092, + 0.056146953254938126, + 0.028700685128569603, + -0.14820916950702667, + -0.016996873542666435, + 0.025667931884527206, + 0.08408629894256592, + 0.00034475952270440757, + 0.007573155220597982, + 0.06784884631633759, + 0.025982951745390892, + -0.08363039791584015, + -0.015748541802167892, + -0.0029514851048588753, + -0.01523523684591055, + 0.10500328987836838, + 0.3070858418941498, + -0.024624783545732498, + 0.0058471946977078915, + -0.039751242846250534, + 0.0012745993444696069, + -0.0796508714556694, + 0.024727927520871162, + 0.056764136999845505, + -0.013338261283934116, + -0.04794292524456978, + -0.02609768509864807, + -0.010784422047436237, + -0.048712026327848434, + 0.020345501601696014, + 0.0021618579048663378, + -0.0021724768448621035, + 0.03056410700082779, + -0.01633712649345398, + -0.47168225049972534, + -0.014639903791248798, + -0.012550815008580685, + 0.03358187526464462, + 0.07889427989721298, + -0.03615899011492729, + -0.002809660043567419, + -0.006953644100576639, + 0.02024337276816368, + -0.0738825723528862, + -0.006984011270105839, + -0.04472561925649643, + -0.027498915791511536, + 0.07207506150007248, + -0.09166522324085236, + -0.008861960843205452, + 0.05264359340071678, + 0.01889069564640522, + -0.1380404680967331, + -0.010141258127987385, + 0.015403619967401028, + -0.16416165232658386, + -0.03529815003275871, + 0.042106859385967255, + 0.11173021793365479, + -0.3143587112426758, + 0.011045016348361969, + 0.0012351945042610168, + 0.03840603306889534, + 0.0685538575053215, + -0.000746160454582423, + -0.028142500668764114, + 0.027154160663485527, + 0.005731801502406597, + 0.04433267563581467, + -0.8158469796180725, + 0.02226361259818077, + -0.07650655508041382, + 0.026958195492625237, + -0.005810025613754988, + -0.020102059468626976, + -0.0019310436910018325, + 0.07697021961212158, + -0.057701658457517624, + 0.05954534560441971, + 0.0027106746565550566, + -0.06311310827732086, + 0.011713752523064613, + -0.0034454476553946733, + -0.0006881420267745852, + 0.08937360346317291, + -0.0008253820124082267, + -0.031066063791513443, + -0.14708301424980164, + -0.04438449814915657, + 0.004772413522005081, + 0.05992274731397629, + 0.07473544776439667, + -0.1784757375717163, + -0.19057415425777435, + -0.014637955464422703, + -0.24898527562618256, + 0.13606221973896027, + -0.018039124086499214, + -0.047193415462970734, + -0.06526428461074829, + 0.04075757786631584, + 0.049901530146598816, + -0.008585861884057522, + 0.01616351678967476, + -3.091737016802654e-05, + 0.024283329024910927, + 0.008861682377755642, + -0.0005823548417538404, + 0.0997646301984787, + 0.051001910120248795, + 0.009473294951021671, + -0.0032046104315668344, + 0.018362928181886673, + 0.008627718314528465, + -0.4148157835006714, + -0.016077928245067596, + 0.0745391696691513, + 0.00724065862596035, + 0.08948155492544174, + 0.11626332253217697, + -0.052439428865909576, + 0.005599102005362511, + 0.002622961765155196, + 0.07586965709924698, + 0.03274847939610481, + -0.02099076844751835, + -0.04666733741760254, + -0.0013019372709095478, + 0.04945925995707512, + 0.11393380910158157, + 0.006346395239233971, + 0.04721064493060112, + 0.010331138968467712, + 0.08918803185224533, + 0.04288423806428909, + -0.09234773367643356, + 0.020141584798693657, + -3.256054696976207e-05, + -0.02799108810722828, + 0.018966441974043846, + -0.4136410355567932, + -0.07217283546924591, + 0.01840362884104252, + -0.055327851325273514, + 0.003275467548519373, + -0.017174070701003075, + -0.032178670167922974, + 0.09021560847759247, + -0.524413526058197, + 0.01994725503027439, + 0.10380692034959793, + -0.01043684035539627, + -0.00011200909648323432, + 0.01331041194498539, + 0.020127851516008377, + -0.025159789249300957, + 0.05252581834793091, + 0.04759140685200691, + 0.0032084162812680006, + -0.03579062595963478, + 0.054719552397727966, + -0.04674411937594414, + 0.028389262035489082, + 0.001127603929489851, + -0.0006243048119358718, + -0.00550495833158493, + -0.022523507475852966, + -0.024282312020659447, + 0.009519628249108791, + -0.39908328652381897, + -0.009265545755624771, + -0.00037090369733050466, + 0.06425131112337112, + -0.05998316407203674, + -0.015221518464386463, + -0.004825026262551546, + 0.11847284436225891, + -0.011302731931209564, + -0.006884834263473749, + -0.04678218811750412, + -0.012078279629349709, + 0.021638741716742516, + -0.016819776967167854, + -0.009127719327807426, + -0.002491263672709465, + 0.0016752213705331087, + -0.016600262373685837, + 0.011772023513913155, + -0.013447183184325695, + -0.020662957802414894, + -0.011593316681683064, + 0.008270744234323502, + -0.0026990456972271204, + -0.004406482446938753, + -0.023110052570700645, + -0.00208942755125463, + -0.1711198389530182, + 0.012432538904249668, + -0.0045453268103301525, + 0.024807902052998543, + -0.0035043740645051003, + -0.004001997876912355, + -0.013488625176250935, + -0.02020987868309021, + -0.01216109935194254, + -0.004432092886418104, + 0.09323672950267792, + -0.015641510486602783, + -0.019307948648929596, + 0.01117538008838892, + -0.01422040443867445, + 0.01705607771873474, + -0.0029596879612654448, + -0.0021530911326408386, + -0.006551788654178381, + 0.00429268553853035, + -0.1620807945728302, + -0.014128226786851883, + -0.005428737495094538, + -0.006771362852305174, + 0.005730633158236742, + 0.0007243106956593692, + 0.0024031582288444042, + -0.00199915561825037, + 0.006133859045803547, + -0.013380909338593483, + 0.00733462069183588, + -0.001863821060396731, + -0.0020169683266431093, + -0.014070986770093441, + -0.006501683499664068, + -0.029421553015708923, + 0.0009377509704791009, + -0.01718256250023842, + -0.05819401144981384, + -0.018859732896089554, + 0.0010356366401538253, + 0.006394123658537865, + -0.021985618397593498, + -0.01204769592732191, + -0.002014884725213051, + -0.019398409873247147, + -0.013122898526489735, + -0.017277296632528305, + -0.002270353492349386, + -0.05294327810406685, + -0.020317314192652702, + -0.018196573480963707, + -0.010375416837632656, + -0.019704729318618774, + -0.016109557822346687, + -0.0167380403727293, + -0.0285252146422863, + -0.02665277197957039, + -0.03554505482316017, + -0.00741522666066885, + -0.013580105267465115, + -0.026335405185818672, + -0.011694515123963356, + -0.004639182705432177, + -0.03996071219444275, + -0.022463932633399963, + -0.007204636000096798, + -0.021065134555101395, + -0.014410646632313728, + 0.0035447971895337105, + -0.0013098351191729307, + -0.024171002209186554, + 0.00047751085367053747, + -0.01870289072394371, + -0.06016797944903374, + -0.025703946128487587, + -0.009730588644742966, + -0.021792838349938393, + -0.024519823491573334, + -0.01843440905213356, + -0.0016325484029948711, + -0.008116388693451881, + -0.017774557694792747, + -0.04375867918133736, + -0.03893980756402016, + -0.018188582733273506, + -0.007122726645320654, + -0.028115490451455116, + -0.01821342669427395, + -0.01011319737881422, + -0.02616124413907528, + -0.013797983527183533, + -0.03202736750245094, + -0.030110370367765427, + -0.01883666031062603, + -0.01185502391308546, + -0.006012012716382742, + -0.017311619594693184, + -0.022577986121177673, + -0.02101938985288143, + 0.0025952248834073544, + -0.005058783106505871, + -0.004162575118243694, + -0.01559755764901638, + -0.017923563718795776, + -0.04231095686554909, + -0.017630560323596, + -0.011938830837607384, + -0.01587115228176117, + 0.004972478374838829, + -0.016601158306002617, + 0.15419845283031464, + 0.0009241115767508745, + 0.051028184592723846, + 0.008128340356051922, + -0.019917558878660202, + -0.0010339801665395498, + 0.022349294275045395, + -0.0072520882822573185, + 0.0017750378465279937, + -0.10526080429553986, + 0.03420695662498474, + 0.019183926284313202, + -0.0006544998032040894, + -0.0032203509472310543, + -0.01216941885650158, + -0.03561796247959137, + 0.024905826896429062, + -0.026948239654302597, + -0.01913355104625225, + -0.014459407888352871, + 0.006972283590584993, + -0.033184293657541275, + 0.04884861409664154, + -0.002296984428539872, + -0.19194477796554565, + 0.00392142403870821, + 0.009490449912846088, + -0.02687196619808674, + -0.06327224522829056, + -0.03684951737523079, + -0.0002613202668726444, + -0.012086644768714905, + 0.03630973398685455, + 0.007296048104763031, + 0.011186012998223305, + 0.0074085514061152935, + -0.020394617691636086, + -0.010585476644337177, + -0.030289918184280396, + 0.0773506686091423, + 0.008841303177177906, + 0.019423579797148705, + 0.001184571417979896, + 0.005553434602916241, + 0.015373414382338524, + -0.0027953842654824257, + 0.013204757124185562, + 0.029097743332386017, + 0.012627501040697098, + 0.02102004364132881, + -0.09469914436340332, + -0.023324014618992805, + 0.029243655502796173, + 0.002979277865961194, + -0.004492263309657574, + 0.20549021661281586, + -0.3244459927082062, + 0.025892559438943863, + 0.009620796889066696, + -0.05520407855510712, + -0.02271144650876522, + 0.008378816768527031, + -0.0671214610338211, + -0.016056722030043602, + -0.02355658821761608, + 0.0005429868469946086, + -0.007960098795592785, + 0.02513299137353897, + -0.13005328178405762, + -0.0025323680602014065, + -0.02197088487446308, + -0.02404806576669216, + 0.08261960744857788, + 0.17078880965709686, + 0.02880753017961979, + -0.03642067685723305, + 0.021994341164827347, + -0.012368184514343739, + -0.10681373625993729, + 0.16371481120586395, + 0.17881983518600464, + -0.10202010720968246, + -0.08641688525676727, + -0.1259487271308899, + 0.06907707452774048, + 0.023792706429958344, + -0.02534419298171997, + 0.016984017565846443, + -0.06743635982275009, + 0.08445960283279419, + -0.08037827908992767, + -0.11935994029045105, + -0.31716489791870117, + -0.01860150322318077, + 0.060669515281915665, + -0.06137414649128914, + 0.09878886491060257, + 0.01794014871120453, + 0.12382296472787857, + -0.016424886882305145, + 0.09045679122209549, + -0.02998783066868782, + -0.00972777884453535, + -0.024124544113874435, + 0.09879253059625626, + 0.05500243604183197, + -0.06635259836912155, + 0.11268552392721176, + 0.011751363053917885, + -0.04690232127904892, + -0.025168607011437416, + 0.088335320353508, + -0.1140628531575203, + 0.04129032790660858, + -0.04258979484438896, + -0.0903872698545456, + 0.008473021909594536, + -0.026690304279327393, + -0.051559556275606155, + -0.05481572076678276, + -0.05251916125416756, + -0.0018165932269766927, + 0.09836867451667786, + 0.0054859439842402935, + 0.06432581692934036, + 0.10621821135282516, + -0.019325286149978638, + -0.028727786615490913, + 0.014013150706887245, + -0.008022608235478401, + -0.006281842477619648, + -0.0297000203281641, + 0.01525485422462225, + -0.4346403479576111, + 0.07787995040416718, + -0.25380268692970276, + 0.05261845141649246, + 0.010875157080590725, + 0.0014149334747344255, + 0.05021188035607338, + -0.24382442235946655, + 0.0807114690542221, + 0.022907381877303123, + 0.006440790370106697, + -0.017028095200657845, + 0.001552293193526566, + 0.05961666256189346, + -0.14113056659698486, + 0.03398876264691353, + -0.005411976482719183, + -0.014025667682290077, + -0.5433799624443054, + 0.019015472382307053, + 0.04091138765215874, + 0.05059061944484711, + 0.0274446289986372, + -0.010288042947649956, + -0.001335533568635583, + -0.013533512130379677, + 0.018798377364873886, + -0.04099345579743385, + 0.0031264263670891523, + -0.21071769297122955, + -0.014384736306965351, + -0.1045387014746666, + -0.014340974390506744, + 0.001986369490623474, + -0.04118456318974495, + -0.10952988266944885, + 0.049147430807352066, + -0.08382093161344528, + -0.1741400957107544, + -0.0885215476155281, + -0.10934099555015564, + 0.05553343519568443, + 0.02434251271188259, + 0.006634524557739496, + -0.0017163373995572329, + 0.0185443926602602, + 0.06250902265310287, + -0.17145656049251556, + -0.07543934881687164, + 0.026583310216665268, + 0.01634727604687214, + 0.003603539662435651, + -0.2817271649837494, + 0.03882112354040146, + 0.011341865174472332, + 0.00826666783541441, + 0.050427842885255814, + -0.22358834743499756, + 0.06419781595468521, + 0.03245265409350395, + -0.04503164440393448, + -0.023194484412670135, + -0.027968740090727806, + 0.08563586324453354, + 0.07954753190279007, + -0.08513130992650986, + 0.02850884199142456, + 0.008976672776043415, + 0.07886530458927155, + 0.0022273347713053226, + -0.09540755301713943, + 0.032016951590776443, + -0.05196075513958931, + 0.10555616766214371, + 0.07629868388175964, + 0.039732079952955246, + -0.0029798501636832952, + 0.014692343771457672, + 0.09200941026210785, + -0.04299614951014519, + -0.023488566279411316, + -0.01851060427725315, + 0.09257487207651138, + 0.055612049996852875, + 0.06423109769821167, + -0.28587806224823, + -0.09950444847345352, + 0.10397437959909439, + 0.025166453793644905, + -0.03235514089465141, + -0.033381711691617966, + 0.1513858139514923, + 0.06468874961137772, + 0.01928441785275936, + 0.0032701045274734497, + -0.0579083226621151, + -0.022929169237613678, + 0.012971373274922371, + -0.018524186685681343, + -0.06484643369913101, + 0.012233717367053032, + 0.06590451300144196, + -0.04558677598834038, + 0.05253027006983757, + 0.048656731843948364, + -0.2288871705532074, + 0.037114787846803665, + -0.20519588887691498, + 0.0058607361279428005, + -0.002009372925385833, + -0.006671734619885683, + -0.07107856124639511, + -0.07407436519861221, + 0.03941629081964493, + 0.0447598397731781, + 0.03509354963898659, + -0.061107732355594635, + -0.09305761009454727, + -0.012180411256849766, + 0.04902016744017601, + 0.07974442094564438, + -0.016854410991072655, + 0.005089411046355963, + -0.08127597719430923, + 0.03258403390645981, + 0.039813362061977386, + -0.01668727956712246, + 0.027226485311985016, + -0.029213925823569298, + -0.008598217740654945, + 0.00931101106107235, + 0.026936721056699753, + -0.03083401545882225, + -0.05799110606312752, + -0.008277476765215397, + -0.014854338951408863, + -0.20012643933296204, + 0.012290815822780132, + 0.007194168865680695, + 0.06858328729867935, + -0.3296163082122803, + -0.11424986273050308, + 0.009912200272083282, + -0.06211454048752785, + 0.0007546336855739355, + 0.03507614880800247, + 0.10649498552083969, + -0.03036407195031643, + 0.0646015852689743, + -0.01595110446214676, + -0.16919563710689545, + 0.0013865949586033821, + -0.08339446783065796, + 0.06962471455335617, + 0.016058098524808884, + -0.04729780554771423, + 0.010602935217320919, + 0.01470863912254572, + 0.06903938204050064, + 0.014901719056069851, + -0.15120048820972443, + 0.016727851703763008, + 0.05003673583269119, + 0.04370126873254776, + 0.029703885316848755, + 0.021875420585274696, + 0.026293285191059113, + -0.01048936415463686, + 0.00040810942300595343, + -0.015616541728377342, + -0.062451593577861786, + 0.010016348212957382, + -0.06790193170309067, + -0.02077331207692623, + 0.007985175587236881, + -0.04435744881629944, + 0.06920231133699417, + 0.018344474956393242, + 0.028591370210051537, + 0.021957838907837868, + 0.0017570338677614927, + 0.036665257066488266, + 0.015438515692949295, + -0.0006347382441163063, + 0.04621066153049469, + -0.001942177303135395, + 0.010664877481758595, + -0.016754357144236565, + 0.006541184149682522, + -0.027716301381587982, + -0.0058586387895047665, + -0.005346015095710754, + 0.020482052117586136, + 0.06882552057504654, + 0.0026622572913765907, + 0.016321638599038124, + 0.017728103324770927, + -0.13356441259384155, + 0.030281176790595055, + 1.0354949154134374e-05, + 0.050639618188142776, + 0.0013030078262090683, + -0.11136802285909653, + -0.006832807790488005, + -0.09628921747207642, + 0.046699415892362595, + 0.002175685251131654, + 0.008100612089037895, + 0.012449901551008224, + -0.01713990420103073, + -0.000769267207942903, + 0.022544430568814278, + -0.0018787183798849583, + -0.014189678244292736, + 0.37042510509490967, + -0.030317893251776695, + 0.012663356028497219, + -0.04071582853794098, + 0.01653047651052475, + 0.06578584760427475, + 0.005606585182249546, + 0.0029362838249653578, + -0.02035594917833805, + 0.016131827607750893, + -0.06512665003538132, + 0.020292088389396667, + 0.12818951904773712, + -0.00017647731874603778, + 0.0004811069811694324, + 0.013025660999119282, + -0.006004344671964645, + 0.011330580338835716, + 0.0021733916364610195, + -0.0026290342211723328, + 0.008579215034842491, + -0.017107143998146057, + 0.0032798980828374624, + 0.21415431797504425, + -0.011049880646169186, + 0.04915957152843475, + -0.01152863260358572, + 0.01988764852285385, + -0.30189022421836853, + 0.1491061896085739, + 0.022540517151355743, + 0.02323656715452671, + -0.0028044115751981735, + -0.02501249685883522, + 0.0016759912250563502, + 0.023405946791172028, + 0.0865691602230072, + 0.0056661744602024555, + 0.2334042638540268, + -0.05771901085972786, + 0.03428330272436142, + -0.05191519856452942, + 0.025708407163619995, + -0.11474912613630295, + 0.05345827341079712, + 0.050046734511852264, + -0.03785427287220955, + 0.02726786397397518, + 0.008640051819384098, + -0.05810163915157318, + 0.19147679209709167, + 0.12065602838993073, + -0.08667072653770447, + -0.12831886112689972, + 0.027053257450461388, + -0.1771622896194458, + -0.2615586817264557, + 0.112942636013031, + 0.002398239215835929, + 0.00907410029321909, + 0.059947770088911057, + 0.040937639772892, + 0.003431124845519662, + 0.012721046805381775, + -0.10228776186704636, + 0.04169567674398422, + -0.04826785624027252, + -0.021415220573544502, + 0.027615519240498543, + 0.16087181866168976, + 0.03552674129605293, + -0.36409878730773926, + 0.0015418739058077335, + 0.03940089792013168, + -0.12929502129554749, + 0.017082052305340767, + -0.07193783670663834, + 0.10395099222660065, + -0.2240910828113556, + -0.003303584409877658, + -0.0074868109077215195, + -0.13708709180355072, + 0.2098008245229721, + 0.013808795250952244, + -0.03606148064136505, + 0.001965852687135339, + 0.04186573252081871, + 0.02105732634663582, + -0.11873909085988998, + -0.08529136329889297, + 0.0060731275007128716, + 0.04803553968667984, + 0.07665349543094635, + 0.026997262611985207, + 0.05191565304994583, + 0.09013131260871887, + 0.013081093318760395, + 0.04667182266712189, + -0.19899451732635498, + 0.004642056301236153, + 0.0025570227298885584, + -0.2640555500984192, + 0.008254006505012512, + 0.05971720814704895, + -0.002980671590194106, + 0.0011313167633488774, + -0.004445134196430445, + 0.01951296627521515, + -0.006634386721998453, + -0.008033698424696922, + 0.012400158680975437, + -0.15906694531440735, + 0.007047838997095823, + 0.0003521084145177156, + -0.00517050176858902, + -0.0003226286207791418, + -0.01226231548935175, + -0.06750697642564774, + -0.03061128593981266, + -0.0027100055012851954, + 0.004726986400783062, + 0.010185977444052696, + 0.021205933764576912, + -0.05105980113148689, + -0.006725164130330086, + 0.26042309403419495, + 0.003935054875910282, + 0.009450466372072697, + -0.009512278251349926, + 0.036205559968948364, + 0.0066987741738557816, + 0.05687355250120163, + -0.0070350514724850655, + 0.021287698298692703, + 0.004246287513524294, + -0.004053668584674597, + 0.0030501342844218016, + -0.003596516093239188, + 0.00571554945781827, + 0.039099883288145065, + 0.06648323684930801, + 0.011140268296003342, + 0.002779693342745304, + 0.0004113377653993666, + 0.0019621821120381355, + 0.002047213725745678, + -9.034215327119455e-05, + 0.006674906238913536, + -0.024464793503284454, + 4.372629337012768e-05, + 0.04560312256217003, + 0.029951298609375954, + 0.0053787860088050365, + 0.010052027180790901, + 0.0018156497972086072, + 0.001613074098713696, + -0.3710610568523407, + 0.18385423719882965, + 0.0197732076048851, + -2.409513217571657e-05, + 0.043657880276441574, + 0.029824273660779, + -0.0015272254822775722, + -0.0009817760437726974, + 0.030571524053812027, + 0.05133187025785446, + 0.021092001348733902, + -0.022430723533034325, + -0.011050102300941944, + -0.01653454266488552, + 0.00856624636799097, + 0.007617316208779812, + 0.023697074502706528, + -0.00541776092723012, + -0.06940567493438721, + -0.024501511827111244, + 0.0029131292831152678, + 0.005110545549541712, + 0.02394089475274086, + 0.009317552670836449, + -0.05198051407933235, + -0.14872707426548004, + -0.03553030639886856, + 0.05354774370789528, + 0.053996339440345764, + 0.016679847612977028, + -0.4505158066749573, + 0.006403166800737381, + -0.014287465251982212, + 0.010499212890863419, + 0.00510875741019845, + 0.0230255089700222, + -0.04791099205613136, + -0.08405473828315735, + -0.00807158276438713, + -0.016310568898916245, + -0.018034789711236954, + -0.03381670266389847, + 0.038599055260419846, + 0.01189411524683237, + 0.0038598189130425453, + 0.0077203805558383465, + -0.0006835742969997227, + 0.3038807809352875, + 0.00930703990161419, + -0.017654214054346085, + -0.029550395905971527, + 0.0014829621650278568, + -0.010562432929873466, + -0.011867706663906574, + -0.008104459382593632, + 0.008003979921340942, + -0.028282882645726204, + 0.00898829661309719, + -0.04963170364499092, + 0.014971665106713772, + 0.028662119060754776, + 0.055792808532714844, + 0.018142173066735268, + 0.029526766389608383, + 0.04726170003414154, + 0.020290115848183632, + -0.01347910612821579, + -0.027794860303401947, + -0.033374592661857605, + 0.05699307844042778, + -0.005888971965759993, + 0.009723466821014881, + 0.011825029738247395, + 0.0005665962235070765, + -0.22433574497699738, + 0.04777664318680763, + 0.054696254432201385, + 0.06447272002696991, + 0.006656138692051172, + -0.2656468152999878, + -0.006602808367460966, + -0.04309352487325668, + 0.024392882362008095, + -0.046948980540037155, + 0.17317010462284088, + -0.014694501645863056, + 0.09150391072034836, + 0.05414793640375137, + -0.0034523033536970615, + -0.029682809486985207, + -0.11646991223096848, + 0.036394182592630386, + -0.008510537445545197, + -0.09555189311504364, + 0.012331446632742882, + 0.022554755210876465, + 0.037040166556835175, + 0.011939534917473793, + -0.035405583679676056, + -0.008284371346235275, + 0.008629710413515568, + -0.0017152110813185573, + -0.01656493730843067, + 0.02205522358417511, + -0.008015291765332222, + -0.02198217809200287, + -0.08165504783391953, + 0.018647879362106323, + 0.010489191859960556, + 0.0009643095545470715, + 0.08301698416471481, + 0.00795030314475298, + -0.08973152935504913, + 0.05324552580714226, + 0.0187348835170269, + 0.00770497927442193, + 0.016434336081147194, + 0.0031714467331767082, + 0.031489044427871704, + -0.01682765781879425, + -0.0006042059976607561, + 0.006229344755411148, + 0.0031935630831867456, + -0.03694210946559906, + -0.027148112654685974, + 0.03319454565644264, + 0.013541879132390022, + 0.04362545907497406, + 0.010766182094812393, + 0.01287879142910242, + 0.02723391354084015, + 0.01831277459859848, + -0.0028144901152700186, + 0.0317537821829319, + -0.05053209140896797, + 0.03341667726635933, + 0.009338690899312496, + 0.030376508831977844, + 0.028512636199593544, + 0.002190604107454419, + 0.031132254749536514, + 0.04174429178237915, + 0.025147251784801483, + 0.02602408640086651, + 0.022863827645778656, + 0.024160150438547134, + 0.04043813422322273, + 0.011693909764289856, + 0.008020071312785149, + 0.010814648121595383, + 0.014862221665680408, + 0.043966785073280334, + 0.04133215174078941, + 0.03920775279402733, + 0.02128027193248272, + -0.0024078795686364174, + 0.03185494989156723, + 0.030951442196965218, + 0.008766901679337025, + -0.0013500713976100087, + 0.012680909596383572, + 0.01911563239991665, + 0.02226334996521473, + 0.03873631730675697, + 0.005242412444204092, + 0.02335301972925663, + 0.00577192846685648, + 0.0019918885082006454, + 0.019501060247421265, + 0.048295676708221436, + 0.027288099750876427, + 0.03500128164887428, + 0.032504353672266006, + 0.03619033470749855, + 0.022762063890695572, + 0.014124974608421326, + 0.04055529460310936, + 0.040181197226047516, + 0.04843837395310402, + 0.019578352570533752, + 0.04370861127972603, + 0.024640914052724838, + 0.027013463899493217, + 0.04700532928109169, + 0.018523193895816803, + 0.03569294884800911, + 0.031140455976128578, + 0.010298499837517738, + 0.03979840502142906, + 0.015059049241244793, + 0.020604899153113365, + 0.010335667058825493, + 0.02557498589158058, + 0.015946611762046814, + 0.018900645896792412, + 0.05494159087538719, + 0.015756357461214066, + 0.0452926866710186, + 0.04820817708969116, + -0.0183499027043581, + 0.04002442955970764, + -0.08226092159748077, + -0.034417178481817245, + 0.059122342616319656, + 0.028960591182112694, + -0.020427608862519264, + -0.043222296983003616, + 0.023134637624025345, + -0.014232538640499115, + -0.06970997899770737, + -0.0035826240200549364, + -0.015384080819785595, + -0.0695020854473114, + 0.03645527362823486, + 0.013986784033477306, + -0.027729706838726997, + -0.05711805075407028, + -0.0763891413807869, + -0.16338491439819336, + -0.02358265034854412, + -0.004730133805423975, + 0.022057903930544853, + -0.011578230187296867, + 0.040772147476673126, + -0.059327173978090286, + -0.03819728270173073, + -0.050089117139577866, + -0.005152902565896511, + -0.3071111738681793, + -0.010683669708669186, + 0.030922774225473404, + 0.08924981951713562, + 0.005679265595972538, + 0.06334424018859863, + 0.016136568039655685, + -0.02575727365911007, + -0.012562219053506851, + 0.007206748705357313, + -0.1373208612203598, + -0.010450832545757294, + -0.05991309881210327, + -0.006700845435261726, + -0.006468744482845068, + -0.02040017955005169, + -0.010068708099424839, + 0.008442427963018417, + 0.012259873561561108, + -0.002103718463331461, + -0.019605906680226326, + -0.010690353810787201, + 0.0005222380859777331, + -0.015031278133392334, + -0.012983204796910286, + -0.03552224859595299, + -0.007792052812874317, + -0.035602111369371414, + -0.03479204699397087, + -0.02480080910027027, + -0.05733964219689369, + 4.38804054283537e-05, + -0.021825626492500305, + -0.03287259489297867, + -0.05437042564153671, + -0.007981077767908573, + 0.023045696318149567, + 0.05785335600376129, + 0.03685669228434563, + 0.04314129799604416, + -0.005843586288392544, + -0.024806369096040726, + -0.02562016434967518, + 0.0015172295970842242, + -0.01568800024688244, + -0.005925294477492571, + 0.010173594579100609, + 0.06834683567285538, + 0.024159085005521774, + -0.009547322988510132, + 0.014080812223255634, + 0.013578452169895172, + 0.035671167075634, + 0.01240566186606884, + -0.021352441981434822, + 0.05245270952582359, + -0.008943279273808002, + -0.010131126269698143, + 0.02976749651134014, + 0.0600045844912529, + 0.0014893191400915384, + 0.03796907886862755, + 0.01258794590830803, + -0.025344882160425186, + 0.14140591025352478, + 0.028354406356811523, + 0.0035325682256370783, + 0.05017172172665596, + 0.01994139887392521, + 0.03679897263646126, + -0.009579945355653763, + -0.012607194483280182, + -0.00034231581958010793, + 0.00046832446241751313, + 0.057916246354579926, + 0.02351403795182705, + 0.06157909706234932, + 0.00789523497223854, + -0.018361341208219528, + 0.0018971840618178248, + -0.007180131506174803, + -0.0010631990153342485, + -0.03140748664736748, + -0.028505641967058182, + 0.010669395327568054, + -0.036474280059337616, + 0.01703447848558426, + 0.04667484760284424, + -0.007303370162844658, + 0.01768752932548523, + 0.012412219308316708, + 0.013702306896448135, + 0.07651616632938385, + 0.05469715967774391, + 0.013292597606778145, + -0.006288900971412659, + 0.0215559434145689, + 0.010094149969518185, + -0.024216346442699432, + -0.15225785970687866, + 0.05467289313673973, + 0.019871067255735397, + 0.04662928730249405, + 0.05072600021958351, + -0.011824453249573708, + -0.028083933517336845, + 0.013322187587618828, + -0.044827401638031006, + 0.05955006927251816, + -0.006152187939733267, + 0.013426700606942177, + -0.014220507815480232, + 0.022510837763547897, + 0.019426455721259117, + -0.05546477064490318, + -0.49202534556388855, + 0.026985207572579384, + -0.08852843940258026, + 0.07166163623332977, + 0.05509938299655914, + -0.42284780740737915, + -0.05131356418132782, + 0.0196990966796875, + -0.008681846782565117, + 0.02739996463060379, + 0.0010900507913902402, + 0.04289104416966438, + -0.06694932281970978, + 0.05930810049176216, + -0.02174360118806362, + 0.03464379161596298, + 0.018284866586327553, + 0.018807150423526764, + 0.019874336197972298, + -0.03665176033973694, + -0.2980017066001892, + 0.050937239080667496, + -0.013874954544007778, + -0.0229057464748621, + 0.016420641914010048, + 0.024160616099834442, + -0.10750921070575714, + -0.010134756565093994, + 0.026874780654907227, + 0.007151094265282154, + 0.06304068863391876, + -0.11811652034521103, + -0.12590888142585754, + 0.031846947968006134, + -0.06898463517427444, + 0.03395693376660347, + -0.00010166154243052006, + -0.19019480049610138, + 0.06616076827049255, + -0.035927142947912216, + 0.08526375889778137, + 0.0015017242403700948, + -0.009137739427387714, + 0.04529058188199997, + -0.23621641099452972, + 0.02148340456187725, + -0.02741178683936596, + -0.20779411494731903, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 64, 1, 1)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_504") + + initializers.append(tensor) + + list_value = [ + 5.195802688598633, + 0.940099835395813, + -7.016428470611572, + 5.185446739196777, + -4.134859085083008, + 2.0121846199035645, + 5.215719223022461, + 3.371406078338623, + 3.7616095542907715, + -3.6593239307403564, + 15.99945068359375, + 3.306276321411133, + 5.790191173553467, + 6.33050537109375, + 3.4512906074523926, + 2.5531861782073975, + 4.278702259063721, + 4.350361347198486, + 8.025779724121094, + -2.8830037117004395, + 2.915111541748047, + 3.592482805252075, + 5.810481071472168, + 3.4743332862854004, + 3.5245680809020996, + 1.8243598937988281, + 8.069726943969727, + 1.401036024093628, + 5.110081672668457, + -12.873579978942871, + 10.977816581726074, + 5.909627437591553, + -0.4007779359817505, + -20.147268295288086, + 6.649413585662842, + 3.325921058654785, + 5.84471321105957, + 4.47447395324707, + 3.754193067550659, + -5.167671203613281, + 3.2778055667877197, + -9.067073822021484, + 2.6243438720703125, + 1.7002031803131104, + 5.476454734802246, + 2.510835647583008, + 3.856968402862549, + 2.3172807693481445, + 12.462139129638672, + 7.355924129486084, + 4.140628814697266, + 4.807559967041016, + 5.7524309158325195, + 4.128836154937744, + 11.4532470703125, + -12.482564926147461, + 5.590144157409668, + 0.9172697067260742, + 4.356811046600342, + 0.9934853315353394, + -4.3548994064331055, + 15.853201866149902, + -5.241130828857422, + 5.9644365310668945, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_505") + + initializers.append(tensor) + + # inputs + + inputs.append(make_tensor_value_info("input", 1, ["batch_size", 3, 32, 32])) + + # outputs + + outputs.append(make_tensor_value_info("/layer1/layer1.0/relu/Relu_output_0", 1, ["batch_size", 64, 8, 8])) + + # nodes + + node = make_node( + "Conv", + ["input", "onnx::Conv_501", "onnx::Conv_502"], + ["/conv1/Conv_output_0"], + name="/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[7, 7], + pads=[3, 3, 3, 3], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node("Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"], name="/relu/Relu", domain="") + nodes.append(node) + + node = make_node( + "MaxPool", + ["/relu/Relu_output_0"], + ["/maxpool/MaxPool_output_0"], + name="/maxpool/MaxPool", + ceil_mode=0, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node( + "Conv", + ["/maxpool/MaxPool_output_0", "onnx::Conv_504", "onnx::Conv_505"], + ["/layer1/layer1.0/conv1/Conv_output_0"], + name="/layer1/layer1.0/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + domain="", + ) + nodes.append(node) + + node = make_node( + "Relu", + ["/layer1/layer1.0/conv1/Conv_output_0"], + ["/layer1/layer1.0/relu/Relu_output_0"], + name="/layer1/layer1.0/relu/Relu", + domain="", + ) + nodes.append(node) + + # opsets + opset_imports = [make_opsetid(domain, 1 if version is None else version) for domain, version in opsets.items()] + + # graph + graph = make_graph(nodes, "torch_jit", inputs, outputs, initializers) + # '7' + + onnx_model = make_model(graph, opset_imports=opset_imports, functions=functions) + onnx_model.ir_version = 7 + onnx_model.producer_name = "pytorch" + onnx_model.producer_version = "" + onnx_model.domain = "" + onnx_model.model_version = 0 + onnx_model.doc_string = "" + set_model_props(onnx_model, {}) + + return onnx_model diff --git a/onnxruntime/test/python/quantization/test_quantize_static_resnet.py b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py new file mode 100644 index 0000000000000..1efa283af6881 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import random +import tempfile +import unittest + +import numpy as np +import onnx +from numpy.testing import assert_allclose +from onnx.numpy_helper import to_array +from resnet_code import create_model + +from onnxruntime import InferenceSession +from onnxruntime import __version__ as ort_version +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod + + +class FakeResnetCalibrationDataReader(CalibrationDataReader): + def __init__(self, batch_size: int = 16): + super().__init__() + self.dataset = [ + (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size) + ] + self.iterator = iter(self.dataset) + + def get_next(self) -> dict: + try: + return {"input": next(self.iterator)[0]} + except Exception: + return None + + +class TestStaticQuantizationResNet(unittest.TestCase): + def test_quantize_static_resnet(self): + kwargs = { + "activation_type": QuantType.QUInt8, + "weight_type": QuantType.QInt8, + "calibrate_method": CalibrationMethod.Percentile, + "extra_options": { + "ActivationSymmetric": False, + "EnableSubgraph": False, + "ForceQuantizeNoInputCheck": False, + "MatMulConstBOnly": False, + "WeightSymmetric": True, + "extra.Sigmoid.nnapi": False, + }, + "nodes_to_exclude": None, + "nodes_to_quantize": None, + "op_types_to_quantize": None, + "per_channel": True, + "quant_format": QuantFormat.QDQ, + "reduce_range": False, + } + + proto = create_model() + + with tempfile.TemporaryDirectory() as temp: + model = os.path.join(temp, "resnet_first_nodes.onnx") + with open(model, "wb") as f: + f.write(proto.SerializeToString()) + + for per_channel in [True, False]: + kwargs["per_channel"] = per_channel + dataloader = FakeResnetCalibrationDataReader(16) + with self.subTest(per_channel=per_channel): + qdq_file = os.path.join( + temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx" + ) + quantize_static( + model_input=model, + model_output=qdq_file, + calibration_data_reader=dataloader, + use_external_data_format=False, + **kwargs, + ) + + # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is: + # * uint8(128) if per_channel is False + # * int8([0, 0, ....]) if per_channel is True + # With onnxruntime>1.16.0 + # * uint8(128) if per_channel is False + # * uint8([128, 128, ..., 127, ...]) if per_channel is True + # QLinearConv : zero point of per-channel filter must be same. + # That's why the quantization forces a symmetric quantization into INT8. + # zero_point is guaranted to be zero whatever the channel is. + + with open(qdq_file, "rb") as f: + onx = onnx.load(f) + for init in onx.graph.initializer: + arr = to_array(init) + if ( + arr.dtype == np.int8 + and "zero_point" not in init.name + and not init.name.endswith("quantized") + ): + raise AssertionError( + f"Initializer {init.name!r} has type {arr.dtype} and " + f"shape {arr.shape} but should be {np.uint8}." + ) + + sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"]) + shape = (1, 3, 32, 32) + size = np.prod(shape) + dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape) + got = sess.run(None, {"input": dummy}) + self.assertEqual(got[0].shape, (1, 64, 8, 8)) + self.assertEqual(got[0].dtype, np.float32) + if per_channel: + expected = np.array( + [ + [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]], + [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + else: + expected = np.array( + [ + [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]], + [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.py b/onnxruntime/test/python/transformers/benchmark_mha.py new file mode 100644 index 0000000000000..1e75268ea6c5d --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.py @@ -0,0 +1,343 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Benchmark performance of MultiHeadAttention with Nvidia GPU of Compute Capability 8.0, 8.6 or 8.9 in Linux: +sh benchmark_mha.sh +""" + +import math +import os +import statistics +import time + +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + + +class InputFormats: + Q_K_V_BSNH = 0 + QKV_BSN3H = 1 + Q_KV_BSNH_BSN2H = 2 + + @staticmethod + def input_format_str(format: int) -> str: + return "QKV" if format == 1 else "Q,KV" if format == 2 else "Q,K,V" + + +class Config: + batch_size: int = 0 + sequence_length: int = 0 + kv_sequence_length: int = 0 + num_heads: int = 0 + head_size: int = 0 + causal: bool = False + input_format: int = InputFormats.Q_K_V_BSNH + + def __init__(self, b, s, s2, n, h, causal, input_format): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.num_heads = n + self.head_size = h + self.causal = causal + self.input_format = input_format + + +def create_multihead_attention_graph(config: Config): + query = helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ) + + key = helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ) + + value = helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ) + + packed_qkv = helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads, + 3, + config.head_size, + ], + ) + + packed_kv = helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads, + 2, + config.head_size, + ], + ) + + if config.input_format == InputFormats.QKV_BSN3H: + input_names = ["query"] + inputs = [packed_qkv] + elif config.input_format == InputFormats.Q_KV_BSNH_BSN2H: + input_names = ["query", "key"] + inputs = [query, packed_kv] + else: # input_format==InputFormats.Q_K_V_BSNH + input_names = ["query", "key", "value"] + inputs = [query, key, value] + + nodes = [ + helper.make_node( + "MultiHeadAttention", + input_names, + ["output"], + "MultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + outputs = [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ] + + graph = helper.make_graph( + nodes, + "MultiHeadAttention_Graph", + inputs, + outputs, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def input_output_shapes(config: Config): + if config.input_format == InputFormats.QKV_BSN3H: + return { + "query": (config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + if config.input_format == InputFormats.Q_KV_BSNH_BSN2H: + return { + "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + "key": (config.batch_size, config.kv_sequence_length, config.num_heads, 2, config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + return { + "query": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + "key": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), + "value": (config.batch_size, config.kv_sequence_length, config.num_heads * config.head_size), + "output": (config.batch_size, config.sequence_length, config.num_heads * config.head_size), + } + + +def create_session( + device_id: int, config: Config, provider: str = "CUDAExecutionProvider", enable_cuda_graph: bool = False +) -> CudaSession: + onnx_model_str = create_multihead_attention_graph(config) + provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) + ort_session = InferenceSession(onnx_model_str, providers=[(provider, provider_options), "CPUExecutionProvider"]) + device = torch.device("cuda", device_id) + cuda_session = CudaSession(ort_session, device, enable_cuda_graph) + shape_dict = input_output_shapes(config) + cuda_session.allocate_buffers(shape_dict) + return cuda_session + + +def measure_latency(cuda_session: CudaSession, input_dict): + start = time.time() + _ = cuda_session.infer(input_dict) + end = time.time() + return end - start + + +def flops(batch, sequence_length, head_size, num_heads, causal): + return 4 * batch * sequence_length**2 * num_heads * head_size // (2 if causal else 1) + + +def tflops_per_second(flop, time): + return (flop / time / 10**12) if not math.isnan(time) else 0.0 + + +def get_sm8x_kernel_name(config: Config) -> str: + # This classification is for Nvidia GPU of Compute Capability 8.* like A100. + # Note that some kernel might not exist in older or newer GPUs. + if os.getenv("ORT_DISABLE_FLASH_ATTENTION") != "1": + if config.input_format == InputFormats.QKV_BSN3H: + min_seq_len = os.getenv("ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV") + min_length = int(min_seq_len) if min_seq_len is not None else 513 + if config.sequence_length >= min_length: + return "Flash" + else: + return "Flash" + + if (os.getenv("ORT_DISABLE_FUSED_CROSS_ATTENTION") != "1" and config.kv_sequence_length <= 128) or ( + os.getenv("ORT_DISABLE_FUSED_ATTENTION") != "1" + and (config.sequence_length <= 384 or os.getenv("ORT_DISABLE_TRT_FLASH_ATTENTION") != "1") + ): + return "TRT" + + if os.getenv("ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION") != "1": + return "MemEff" + + return "Unfused" + + +def run_tflops_test(dtype=torch.float16, enable_cuda_graph: bool = False, repeats: int = 100): + device_id = torch.cuda.current_device() + device = torch.device("cuda", device_id) + + # (batch_size, sequence_length, num_heads, head_size) + configs = [ + (32, 512, 64, 32), + (32, 512, 128, 16), + (16, 1024, 64, 32), + (16, 1024, 128, 16), + (8, 2048, 64, 32), + (8, 2048, 128, 16), + (4, 4096, 64, 32), + (4, 4096, 128, 16), + (2, 8192, 64, 32), + (2, 8192, 128, 16), + (1, 16384, 64, 32), + (1, 16384, 128, 16), + # stable diffusion + (1, 4096, 8, 40), + (1, 4096, 8, 80), + (1, 4096, 8, 160), + (4, 4096, 8, 40), + (4, 4096, 8, 80), + (4, 4096, 8, 160), + (1, 16384, 8, 40), + (1, 16384, 8, 80), + (1, 16384, 8, 160), + # bert-base + (128, 128, 12, 64), + (64, 128, 12, 64), + (128, 384, 12, 64), + (64, 384, 12, 64), + (128, 512, 12, 64), + (64, 512, 12, 64), + # TNLGv4 + (4, 2048, 32, 128), + (4, 4096, 32, 128), + (8, 2048, 32, 128), + (8, 4096, 32, 128), + ] + + print(f"enable_cuda_graph={enable_cuda_graph}") + + # List of environment variables to enable/disable attention kernels + print("Environment Variables:") + env_names = [ + "ORT_DISABLE_FLASH_ATTENTION", + "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV", + "ORT_DISABLE_FUSED_ATTENTION", + "ORT_DISABLE_TRT_FLASH_ATTENTION", + "ORT_ENABLE_FUSED_CAUSAL_ATTENTION", + "ORT_DISABLE_FUSED_CROSS_ATTENTION", + "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION", + ] + for name in env_names: + value = os.getenv(name) + if value is not None: + print(f"{name}={value}") + print() + + print("format\tcausal\tbatch\tseqlen\theads\th_dim\tms\tTFLOPS\tkernel") + causal = False + for input_format in [InputFormats.Q_K_V_BSNH, InputFormats.Q_KV_BSNH_BSN2H, InputFormats.QKV_BSN3H]: + for batch_size, sequence_length, num_heads, head_size in configs: + config = Config(batch_size, sequence_length, sequence_length, num_heads, head_size, causal, input_format) + + session = create_session(device_id, config, enable_cuda_graph=enable_cuda_graph) + + qkv = torch.randn(batch_size, sequence_length, 3, num_heads, head_size, device=device, dtype=dtype) + q, k, v = qkv.unbind(dim=2) + + if input_format == InputFormats.QKV_BSN3H: + if config.sequence_length != config.kv_sequence_length: + continue + q = torch.reshape(q, (-1, config.num_heads, config.head_size)) + k = torch.reshape(k, (-1, config.num_heads, config.head_size)) + v = torch.reshape(v, (-1, config.num_heads, config.head_size)) + packed_qkv = torch.dstack((q, k, v)).reshape( + config.batch_size, config.sequence_length, config.num_heads, 3, config.head_size + ) + input_dict = {"query": packed_qkv.contiguous()} + elif input_format == InputFormats.Q_KV_BSNH_BSN2H: + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (-1, config.num_heads, config.head_size)) + v = torch.reshape(v, (-1, config.num_heads, config.head_size)) + packed_kv = torch.dstack((k, v)).reshape( + config.batch_size, config.sequence_length, config.num_heads, 2, config.head_size + ) + input_dict = {"query": q.contiguous(), "key": packed_kv.contiguous()} + else: # input_format == InputFormats.Q_K_V_BSNH + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + input_dict = { + "query": q.contiguous(), + "key": k.contiguous(), + "value": v.contiguous(), + } + + # warm up session + _ = measure_latency(session, input_dict) + + latency_list = [] + for _ in range(repeats): + latency = measure_latency(session, input_dict) + latency_list.append(latency) + average_latency = statistics.mean(latency_list) + + del session + + # compute TFLOPS per second + speed = tflops_per_second(flops(batch_size, sequence_length, head_size, num_heads, causal), average_latency) + + kernel = get_sm8x_kernel_name(config) + format = InputFormats.input_format_str(input_format) + print( + f"{format}\t{causal}\t{batch_size}\t{sequence_length}\t{num_heads}\t{head_size}\t{average_latency * 1000:.2f}\t{speed:.2f}\t{kernel}" + ) + + +if __name__ == "__main__": + run_tflops_test(enable_cuda_graph=False) diff --git a/onnxruntime/test/python/transformers/benchmark_mha.sh b/onnxruntime/test/python/transformers/benchmark_mha.sh new file mode 100644 index 0000000000000..7b21cf1cc1e08 --- /dev/null +++ b/onnxruntime/test/python/transformers/benchmark_mha.sh @@ -0,0 +1,14 @@ +echo "flash attention v2" +ORT_DISABLE_FLASH_ATTENTION=0 ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV=0 python benchmark_mha.py | tee result.txt + +echo "===" +echo "TensorRT attention kernels - cross attention (when kv_seq_len <= 128) or fused attention (when seq_len <= 384) or flash attention (seq_len > 384)" +ORT_DISABLE_FLASH_ATTENTION=1 python benchmark_mha.py | tee -a result.txt + +echo "===" +echo "Memory Efficient attention" +ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 python benchmark_mha.py | tee -a result.txt + +echo "===" +echo "Unfused Attention (some configurations might fail)" +ORT_DISABLE_FLASH_ATTENTION=1 ORT_DISABLE_TRT_FLASH_ATTENTION=1 ORT_DISABLE_FUSED_ATTENTION=1 ORT_DISABLE_FUSED_CROSS_ATTENTION=1 ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION=1 python benchmark_mha.py | tee -a result.txt diff --git a/onnxruntime/test/python/transformers/bert_padding.py b/onnxruntime/test/python/transformers/bert_padding.py new file mode 100644 index 0000000000000..a4ef7652643ab --- /dev/null +++ b/onnxruntime/test/python/transformers/bert_padding.py @@ -0,0 +1,131 @@ +# From https://github.com/Dao-AILab/flash-attention/blob/2286d7cea7ca8264165c16b2442b6436c43140de/flash_attn/bert_padding.py + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather(rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)).reshape( + -1, *other_shape + ) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... -> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], device=grad_output.device, dtype=grad_output.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros(first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, _ = input.shape[0], input.shape[1:] + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. + # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz) + Return: + hidden_states: (batch, seqlen, ...) + """ + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... -> b s ...", b=batch) diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx index 76d808538e0e4..216f5444d88ce 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_mha.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx index ececb8701a283..cc712c61afb77 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx index da048bbe5cd34..25dc71ff51215 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/attention_with_varied_qkv_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx index fe5384bd4e234..53a2809038aac 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/bert_3d_attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx index 177c29f607ddb..b4ed7169df447 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx index 036d6c16010a3..0ef3b083191c1 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_add_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx index 7f1174d966011..62b51b9dd2dfa 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx index ee11024900e39..0ef3b083191c1 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_attention_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx index debd5244abce5..8f545aa8b3bb9 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_no_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx index 856d76947a4da..8d0f1697b5386 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx and b/onnxruntime/test/python/transformers/test_data/models/gpt2_megatron_opt_skiplayernorm.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx b/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx index 51bf9f08ff0f5..337ede0e4ec39 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx and b/onnxruntime/test/python/transformers/test_data/models/pruned_attention_opt.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx index 2fc6a8959d9da..25265839c82a9 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_attention_with_sln_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx index 0c5035f7dcc6b..5f21da7e591d9 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx index 8759d958d3aef..1da242e19e711 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx index 7b3368d8245f8..e7a201658bb4d 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx index 3c7b613f427d2..bc72c9b350087 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_cross_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx index 1119e4c51a699..969f20b2860c3 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx index 6a4ee4761a94c..ca7f33a3f1d8d 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/decoder_with_past_self_mha_split_bias_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx index 190b70741fc4c..15a178863b9e5 100644 Binary files a/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx and b/onnxruntime/test/python/transformers/test_data/models/whisper/encoder_attention_with_sln_fused.onnx differ diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py new file mode 100644 index 0000000000000..f90a9475b4588 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -0,0 +1,528 @@ +# -------------------------------------------------------------------------- +# Copyright 2020 The HuggingFace Inc. team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import math + +import numpy +import torch +from bert_padding import pad_input, unpad_input +from einops import rearrange, repeat +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession, SessionOptions + +torch.manual_seed(0) + + +class Config: + batch_size = 0 + sequence_length = 0 + kv_sequence_length = 0 + num_heads = 0 + head_size = 0 + + def __init__(self, b, s, s2, n, h): + self.batch_size = b + self.sequence_length = s + self.kv_sequence_length = s2 + self.num_heads = n + self.head_size = h + + +def create_packed_multihead_attention_graph(config): + nodes = [ + helper.make_node( + "PackedMultiHeadAttention", + [ + "query", + "", + "", + "", + "token_offset", + "cumulative_sequence_length", + ], + ["output"], + "PackedMultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "PackedMultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + -1, + config.num_heads, + 3, + config.head_size, + ], + ), + helper.make_tensor_value_info( + "token_offset", TensorProto.INT32, [config.batch_size, config.sequence_length] + ), + helper.make_tensor_value_info("cumulative_sequence_length", TensorProto.INT32, [config.batch_size + 1]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [-1, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def create_multihead_attention_graph(config): + nodes = [ + helper.make_node( + "MultiHeadAttention", + [ + "query", + "key", + "value", + ], + ["output"], + "MultiHeadAttention_0", + num_heads=config.num_heads, + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes, + "MultiHeadAttention_Graph", + [ + helper.make_tensor_value_info( + "query", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.num_heads * config.head_size, + ], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT16, + [config.batch_size, config.sequence_length, config.num_heads * config.head_size], + ), + ], + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def generate_random_padding_mask(max_seqlen, batch_size, device, mode="random"): + assert mode in ["full", "random", "third"] + if mode == "full": + lengths = torch.full((batch_size, 1), max_seqlen, device=device, dtype=torch.int32) + elif mode == "random": + lengths = torch.randint(max(1, max_seqlen - 20), max_seqlen, (batch_size, 1), device=device) + else: + lengths = torch.randint(max_seqlen // 3, max_seqlen, (batch_size, 1), device=device) + padding_mask = repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths + return padding_mask + + +def generate_qkv(q, k, v, query_padding_mask=None, key_padding_mask=None, kvpacked=False, qkvpacked=False): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, d) + k: (batch_size, seqlen_k, nheads_k, d) + v: (batch_size, seqlen_k, nheads_k, d) + query_padding_mask: (batch_size, seqlen), bool + key_padding_mask: (batch_size, seqlen), bool + """ + assert not (kvpacked and qkvpacked) + batch_size, seqlen_q, nheads, d = q.shape + _, seqlen_k, nheads_k, _ = k.shape + assert k.shape == (batch_size, seqlen_k, nheads_k, d) + assert v.shape == (batch_size, seqlen_k, nheads_k, d) + + if query_padding_mask is not None: + q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, query_padding_mask) + + def output_pad_fn(output_unpad): + return pad_input(output_unpad, indices_q, batch_size, seqlen_q) + + else: + q_unpad = rearrange(q, "b s h d -> (b s) h d") + cu_seqlens_q = torch.arange( + 0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q_unpad.device + ) + max_seqlen_q = seqlen_q + + def output_pad_fn(output_unpad): + return rearrange(output_unpad, "(b s) h d -> b s h d", b=batch_size) + + if key_padding_mask is not None: + k_unpad, indices_k, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask) + v_unpad, _, _, _ = unpad_input(v, key_padding_mask) + else: + k_unpad = rearrange(k, "b s h d -> (b s) h d") + v_unpad = rearrange(v, "b s h d -> (b s) h d") + cu_seqlens_k = torch.arange( + 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=k_unpad.device + ) + max_seqlen_k = seqlen_k + + if qkvpacked: + assert (query_padding_mask == key_padding_mask).all() + assert nheads == nheads_k + qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1) + qkv = torch.stack([q, k, v], dim=2) + if query_padding_mask is not None: + + def dqkv_pad_fn(dqkv_unpad): + return pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q) + + else: + + def dqkv_pad_fn(dqkv_unpad): + return rearrange(dqkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + qkv_unpad.detach().requires_grad_(), + cu_seqlens_q, + max_seqlen_q, + qkv.detach().requires_grad_(), + output_pad_fn, + dqkv_pad_fn, + ) + elif kvpacked: + kv_unpad = torch.stack([k_unpad, v_unpad], dim=1) + kv = torch.stack([k, v], dim=2) + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dkv_pad_fn(dkv_unpad): + return pad_input(dkv_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dkv_pad_fn(dkv_unpad): + return rearrange(dkv_unpad, "(b s) t h d -> b s t h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + kv_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + kv.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dkv_pad_fn, + ) + else: + dq_pad_fn = output_pad_fn + if key_padding_mask is not None: + + def dk_pad_fn(dk_unpad): + return pad_input(dk_unpad, indices_k, batch_size, seqlen_k) + + else: + + def dk_pad_fn(dk_unpad): + return rearrange(dk_unpad, "(b s) h d -> b s h d", b=batch_size) + + return ( + q_unpad.detach().requires_grad_(), + k_unpad.detach().requires_grad_(), + v_unpad.detach().requires_grad_(), + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q.detach().requires_grad_(), + k.detach().requires_grad_(), + v.detach().requires_grad_(), + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) + + +def create_inputs(config: Config, kv_packed=False, qkv_packed=True): + qkv = torch.randn( + config.batch_size, + config.sequence_length, + 3, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + key_padding_mask = generate_random_padding_mask( + config.sequence_length, config.batch_size, device="cuda", mode="random" + ) + qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn = generate_qkv( + *qkv.unbind(dim=2), key_padding_mask, key_padding_mask, kv_packed, qkv_packed + ) + return qkv_unpad, cu_seqlens, max_seqlen, qkv, output_pad_fn, dqkv_pad_fn, key_padding_mask + + +def generate_token_offset(cu_seqlens, max_seqlen): + token_offset = [] + token_padset = [] # These are the indices that contain padding tokens + for i in range(1, len(cu_seqlens)): + start = i - 1 + pre_seqlen = cu_seqlens[i - 1] + seqlen = cu_seqlens[i] + token_offset += range(start * max_seqlen, (start * max_seqlen) + (seqlen - pre_seqlen)) + token_padset += range((start * max_seqlen) + (seqlen - pre_seqlen), i * max_seqlen) + return numpy.asarray(token_offset + token_padset, dtype=numpy.int32) + + +def flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False): + onnx_model_str = create_packed_multihead_attention_graph(config) + qkv_unpad = torch.swapdims(qkv_unpad, 1, 2) + ort_inputs = { + "query": qkv_unpad.detach().cpu().numpy(), + "token_offset": token_offset, + "cumulative_sequence_length": cu_seqlens.cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + ort_output = ort_session.run(None, ort_inputs) + output = torch.tensor(ort_output) + return output + + +def flash_attn_func(q, k, v, config, causal=False): + onnx_model_str = create_multihead_attention_graph(config) + q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) + k = torch.reshape(k, (config.batch_size, config.kv_sequence_length, -1)) + v = torch.reshape(v, (config.batch_size, config.kv_sequence_length, -1)) + ort_inputs = { + "query": q.detach().cpu().numpy(), + "key": k.detach().cpu().numpy(), + "value": v.detach().cpu().numpy(), + } + sess_options = SessionOptions() + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) + ort_output = ort_session.run(None, ort_inputs) + output = torch.tensor(ort_output) + return output + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads_k, head_dim) + v: (batch_size, seqlen_k, nheads_k, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. + Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if causal: + causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=q.device), 1) + scores.masked_fill_(causal_mask, float("-inf")) + attention = torch.softmax(scores, dim=-1) + dropout_scaling = 1.0 / (1 - dropout_p) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + +def attention_qkvpacked_ref( + qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False +): + return attention_ref( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + key_padding_mask, + key_padding_mask, + dropout_p, + dropout_mask, + upcast=upcast, + causal=causal, + reorder_ops=reorder_ops, + ) + + +def parity_check( + config, + packed, + rtol=1e-3, + atol=1e-3, +): + if packed: + qkv_unpad, cu_seqlens, _, qkv, output_pad_fn, _, key_padding_mask = create_inputs(config) + token_offset = generate_token_offset(cu_seqlens, config.sequence_length).reshape( + (config.batch_size, config.sequence_length) + ) + # ORT Flash + out_unpad = flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, token_offset, config, causal=False) + out_unpad = torch.squeeze(out_unpad, 0) + out = torch.reshape( + output_pad_fn(out_unpad), (config.batch_size, config.sequence_length, config.num_heads, config.head_size) + ) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_qkvpacked_ref(qkv, key_padding_mask, 0.0, None, causal=False) + out_ref = out_ref.detach().cpu().numpy() + else: + q = torch.randn( + config.batch_size, + config.sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + k = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + v = torch.randn( + config.batch_size, + config.kv_sequence_length, + config.num_heads, + config.head_size, + device="cuda", + dtype=torch.float16, + requires_grad=False, + ) + out = flash_attn_func(q, k, v, config) + out = torch.squeeze(out, 0) + out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) + out = out.detach().cpu().numpy() + # Pytorch to compare + out_ref, _ = attention_ref(q, k, v, None, None, 0.0, None) + out_ref = out_ref.detach().cpu().numpy() + # Compare results + print( + " B:", + config.batch_size, + " S:", + config.sequence_length, + " N:", + config.num_heads, + " h:", + config.head_size, + " Mean Error:", + numpy.mean(numpy.abs(out - out_ref)), + numpy.allclose( + out, + out_ref, + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +if __name__ == "__main__": + print("-------- TEST PACKED MHA ---------") + for b in [5]: + for s in [97, 128, 200, 256, 257, 384, 512, 768, 1024, 1025, 2048]: + for n in [6]: + for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: + config = Config(b, s, s, n, h) + parity_check(config, True) + print("-------- TEST MHA ---------") + for b in [5]: + for s, s2 in [ + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ]: + for n in [6]: + for h in [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]: + config = Config(b, s, s2, n, h) + parity_check(config, False) diff --git a/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py new file mode 100644 index 0000000000000..5b3a3f18cd744 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_skip_layer_norm_fusion.py @@ -0,0 +1,276 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import Dict, List + +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestFusion(unittest.TestCase): + def verify_skip_layer_norm_fusion( + self, + model_path: str, + expected_counter: Dict[str, int], + expected_inputs: List[str], + expected_outputs: List[str], + ): + options = FusionOptions("bert") + optimized_model = optimize_model(model_path, optimization_options=options, opt_level=0) + + ops = ["Add", "LayerNormalization", "SkipLayerNormalization", "Cast"] + for op in ops: + nodes = optimized_model.get_nodes_by_op_type(op) + print(op, len(nodes), expected_counter[op]) + self.assertEqual(len(nodes), expected_counter[op]) + + if op == "SkipLayerNormalization" and expected_counter[op] == 1: + print(nodes[0].input) + print(nodes[0].output) + self.assertEqual(nodes[0].input, expected_inputs) + self.assertEqual(nodes[0].output, expected_outputs) + + def create_test_model( + self, + batch_size: int = 1, + sequence_length: int = 2, + hidden_size: int = 3, + add_graph_output: bool = True, + bias: int = 0, # 0 - no bias, 1 - bias in input_1, 2 - bias in input_2 + cast_before_add_bias=False, + ): + matmul = helper.make_node("MatMul", ["input_0", "matmul_weight"], ["matmul_output"], "matmul") + cast_node = helper.make_node("Cast", ["matmul_output"], ["matmul_output_cast"], to=1) + add_bias = helper.make_node( + "Add", + ["matmul_output_cast" if cast_before_add_bias else "matmul_output", "bias"], + ["input_1" if bias == 1 else "input_2"], + "add_bias", + ) + + add_before_layer_norm = helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm") + layer_norm = helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ) + + initializers = [ # initializers + float_tensor("layer_norm_weight", [hidden_size]), + float_tensor("layer_norm_bias", [hidden_size]), + ] + + if bias > 0: + weight_tensor = float_tensor("matmul_weight", [hidden_size, hidden_size]) + # MatMul weights is float16 when there is Cast node + if cast_before_add_bias: + weight_tensor.CopyFrom( + numpy_helper.from_array(numpy_helper.to_array(weight_tensor).astype(np.float16), weight_tensor.name) + ) + initializers.append(weight_tensor) + + bias_tensor = float_tensor("bias", [hidden_size]) + initializers.append(bias_tensor) + + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT16 if cast_before_add_bias else TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + input_1 = helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + input_2 = helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + output = helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + layernorm_input = helper.make_tensor_value_info( + "layernorm_input", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ) + + nodes = [add_before_layer_norm, layer_norm] + if bias > 0: + nodes.insert(0, add_bias) + if cast_before_add_bias: + nodes.insert(0, cast_node) + nodes.insert(0, matmul) + + node_name = "SkipLayerNormFusionModel" + if bias == 0: + graph = helper.make_graph( + nodes, + node_name, + [input_1, input_2], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + elif bias == 1: + graph = helper.make_graph( + nodes, + node_name, + [input_0, input_2], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + else: + graph = helper.make_graph( + nodes, + node_name, + [input_0, input_1], # inputs + [output, layernorm_input] if add_graph_output else [output], # outputs + initializers, + ) + + onnx_opset = helper.make_opsetid("ai.onnx", min(onnx.defs.onnx_opset_version(), 16)) + return helper.make_model(graph, opset_imports=(onnx_opset,)) + + def test_skip_layer_norm_no_graph_output(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=False) + model_name = "skip_layer_norm_add_no_graph_output.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True) + model_name = "skip_layer_norm_add_has_graph_output.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["input_1", "input_2", "layer_norm_weight", "layer_norm_bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_bias1(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=1) + model_name = "skip_layer_norm_add_has_graph_output_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["matmul_output", "input_2", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_bias2(self): + model = self.create_test_model(batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=2) + model_name = "skip_layer_norm_add_has_graph_output_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 0, + }, + ["matmul_output", "input_1", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_cast_bias1(self): + model = self.create_test_model( + batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=1, cast_before_add_bias=True + ) + model_name = "skip_layer_norm_add_has_graph_output_cast_bias1.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 1, + }, + ["matmul_output_cast", "input_2", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + def test_skip_layer_norm_graph_output_cast_bias2(self): + model = self.create_test_model( + batch_size=1, sequence_length=2, hidden_size=3, add_graph_output=True, bias=2, cast_before_add_bias=True + ) + model_name = "skip_layer_norm_add_has_graph_output_cast_bias2.onnx" + onnx.save(model, model_name) + self.verify_skip_layer_norm_fusion( + model_name, + { + "Add": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 1, + "Cast": 1, + }, + ["matmul_output_cast", "input_1", "layer_norm_weight", "layer_norm_bias", "bias"], + ["output", "", "", "layernorm_input"], + ) + os.remove(model_name) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/testdata/attention_no_mask_fp16.onnx b/onnxruntime/test/testdata/attention_no_mask_fp16.onnx new file mode 100644 index 0000000000000..fe8aa0038d4fd Binary files /dev/null and b/onnxruntime/test/testdata/attention_no_mask_fp16.onnx differ diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx new file mode 100644 index 0000000000000..1dec9910087fc Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32.onnx differ diff --git a/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy new file mode 100644 index 0000000000000..706f508836888 Binary files /dev/null and b/onnxruntime/test/testdata/float8/te.cast_fp8_1_fp32_input.npy differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.onnx b/onnxruntime/test/testdata/ort_github_issue_17000.onnx new file mode 100644 index 0000000000000..8320c19cb6de4 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.onnx differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.ort b/onnxruntime/test/testdata/ort_github_issue_17000.ort new file mode 100644 index 0000000000000..08d9826dd5346 Binary files /dev/null and b/onnxruntime/test/testdata/ort_github_issue_17000.ort differ diff --git a/onnxruntime/test/testdata/ort_github_issue_17000.py b/onnxruntime/test/testdata/ort_github_issue_17000.py new file mode 100644 index 0000000000000..43c10f5590212 --- /dev/null +++ b/onnxruntime/test/testdata/ort_github_issue_17000.py @@ -0,0 +1,77 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper, numpy_helper + + +def order_repeated_field(repeated_proto, key_name, order): + order = list(order) + repeated_proto.sort(key=lambda x: order.index(getattr(x, key_name))) + + +def make_node(op_type, inputs, outputs, name=None, doc_string=None, domain=None, **kwargs): + node = helper.make_node(op_type, inputs, outputs, name, doc_string, domain, **kwargs) + if doc_string == "": + node.doc_string = "" + order_repeated_field(node.attribute, "name", kwargs.keys()) + return node + + +def make_graph(*args, doc_string=None, **kwargs): + graph = helper.make_graph(*args, doc_string=doc_string, **kwargs) + if doc_string == "": + graph.doc_string = "" + return graph + + +test_graph = make_graph( + name="test_graph", + # model input of a sequence type to test IsSparseTensor issue + inputs=[ + helper.make_tensor_sequence_value_info("seq_in", TensorProto.FLOAT, shape=None), + ], + outputs=[ + helper.make_tensor_value_info("still_has_elements", TensorProto.BOOL, shape=[]), + ], + initializer=[ + numpy_helper.from_array(np.array(0, dtype="int64"), name="i0"), + ], + nodes=[ + make_node("SequenceLength", inputs=["seq_in"], outputs=["seq_len"], name="get_seq_len"), + make_node("Greater", inputs=["seq_len", "i0"], outputs=["has_elements"], name="get_has_elements"), + # If node with one branch that has no nodes to test the allocation planner issue + # if sequence has elements: + # remove one + # output bool of whether it still has elements + # else: + # output false (gives us branch with no nodes) + make_node( + "If", + name="test_if", + inputs=["has_elements"], + outputs=["still_has_elements"], + then_branch=make_graph( + name="then", + inputs=[], + outputs=[helper.make_tensor_value_info("then_bool_out", TensorProto.BOOL, shape=[])], + nodes=[ + make_node("SequenceErase", inputs=["seq_in", "i0"], outputs=["seq_less_one"]), + make_node("SequenceLength", inputs=["seq_less_one"], outputs=["new_seq_len"]), + make_node("Greater", inputs=["new_seq_len", "i0"], outputs=["then_bool_out"]), + ], + ), + else_branch=make_graph( + name="else", + initializer=[numpy_helper.from_array(np.array(False, dtype="bool"), name="else_bool_out")], + inputs=[], + outputs=[helper.make_tensor_value_info("else_bool_out", TensorProto.BOOL, shape=[])], + nodes=[], + ), + ), + ], +) + +# Graph with Sequence operations and an If node that has a subgraph with no nodes +model = helper.make_model(opset_imports=[helper.make_operatorsetid("ai.onnx", 14)], ir_version=7, graph=test_graph) + +onnx.shape_inference.infer_shapes(model, strict_mode=True) +onnx.save(model, "ort_github_issue_17000.onnx") diff --git a/onnxruntime/test/testdata/required_ops.config b/onnxruntime/test/testdata/required_ops.config index ac9d46666e1b6..e70362bab4017 100644 --- a/onnxruntime/test/testdata/required_ops.config +++ b/onnxruntime/test/testdata/required_ops.config @@ -3,9 +3,9 @@ ai.onnx;7;Abs,Add,And,BatchNormalization,Concat,Conv,Dropout,Flatten,Foo,Gather, ai.onnx;8;Add,Conv,Flatten,Gemm,MatMul,MaxPool,Mul,Relu,Reshape ai.onnx;9;Abs,Add,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Gather,Gemm,Identity,If,LayerNormalization,LeakyRelu,Loop,MatMul,Mul,Pow,ReduceMean,Relu,Reshape,Scan,Shape,Sigmoid,Slice,Softmax,Softsign,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze ai.onnx;10;Add,Cast,Concat,ConstantOfShape,Div,Dropout,Erf,Expand,Gather,Greater,Identity,If,LayerNormalization,Loop,MatMul,Mul,Neg,NonZero,Pow,ReduceMean,ReduceSum,Shape,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze -ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where +ai.onnx;11;Abs,Add,ArgMax,BatchNormalization,Cast,Clip,Concat,Constant,ConstantOfShape,Conv,Div,Equal,Exp,Expand,Flatten,Gather,Gemm,Identity,If,LayerNormalization,Log,Loop,MatMul,MatMulInteger,Max,Min,Mul,Neg,Pow,RandomUniform,Range,ReduceMean,ReduceSum,ReduceSumSquare,Relu,Reshape,Scan,SequenceConstruct,SequenceErase,SequenceInsert,SequenceLength,Shape,Sigmoid,Slice,Softmax,Split,Sqrt,Squeeze,Sub,Sum,Tanh,Transpose,Unsqueeze,Where ai.onnx;12;Add,And,Cast,Concat,Constant,ConstantOfShape,Conv,CumSum,Div,Dropout,DynamicQuantizeLinear,Equal,Erf,Expand,Flatten,Gather,GatherND,Gemm,GlobalAveragePool,Greater,Identity,If,IsInf,LayerNormalization,Less,Loop,MatMul,MatMulInteger,Min,Mul,Not,Pad,Pow,RandomNormalLike,RandomUniform,ReduceMean,ReduceSum,Relu,Reshape,Shape,Slice,Softmax,SoftmaxCrossEntropyLoss,SparseSoftmaxCrossEntropy,Split,Sqrt,Squeeze,Sub,Tanh,Transpose,Unsqueeze,Where -ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Identity,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where +ai.onnx;13;Abs,Add,Cast,Concat,ConstantOfShape,Conv,DequantizeLinear,DynamicQuantizeLinear,Equal,Expand,FooBar,FooBar_Attr,Gather,Greater,Identity,If,LayerNormalization,MatMul,MatMulInteger,Mul,Pad,Pow,QuantizeLinear,Range,ReduceSum,Reshape,Shape,Tanh,Transpose,Unsqueeze,Where ai.onnx;14;Add,ArgMax,Cast,Conv,Identity,Relu,Sigmoid,Sub ai.onnx;314159;Add ai.onnx.contrib;1;StringLower diff --git a/onnxruntime/test/testdata/required_ops_and_types.config b/onnxruntime/test/testdata/required_ops_and_types.config index 17687906d7250..41f374214747b 100644 --- a/onnxruntime/test/testdata/required_ops_and_types.config +++ b/onnxruntime/test/testdata/required_ops_and_types.config @@ -1,9 +1,12 @@ # required ops and types for ORT format models in testdata -ai.onnx;1;Conv{"inputs": {"0": ["float"]}},Foo,Identity +ai.onnx;1;Conv{"inputs": {"0": ["float"]}} ai.onnx;5;Reshape ai.onnx;6;Relu{"inputs": {"0": ["float"]}} ai.onnx;7;Add{"inputs": {"0": ["float"]}},Gemm{"inputs": {"0": ["float"]}},Mul{"inputs": {"0": ["float"]}} ai.onnx;8;MaxPool{"inputs": {"0": ["float"]}},Sum{"inputs": {"0": ["float"]}} ai.onnx;9;Cast{"inputs": {"0": ["float"]}, "outputs": {"0": ["bool"]}} -ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},If,Loop +ai.onnx;10;QLinearConv{"inputs": {"0": ["uint8_t"]}} +ai.onnx;11;ArgMax{"inputs": {"0": ["float"]}},Clip{"inputs": {"0": ["float"]}},Conv{"inputs": {"0": ["float"]}},If,Loop,SequenceErase,SequenceLength +ai.onnx;13;DequantizeLinear{"inputs": {"0": ["int32_t", "uint8_t"]}},Greater{"inputs": {"0": ["int64_t"]}},If,QuantizeLinear{"outputs": {"0": ["uint8_t"]}} ai.onnx.ml;1;ArrayFeatureExtractor,LinearClassifier,Normalizer,ZipMap +test;1;Foo diff --git a/onnxruntime/test/testdata/training_api/ort_format/checkpoint b/onnxruntime/test/testdata/training_api/ort_format/checkpoint new file mode 100644 index 0000000000000..ab35c9ad5acde Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/checkpoint differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort new file mode 100644 index 0000000000000..69b2c7e029de0 Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/eval_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort new file mode 100644 index 0000000000000..88f192462362d Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/optimizer_model.ort differ diff --git a/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py new file mode 100644 index 0000000000000..70e8c4ac011a9 --- /dev/null +++ b/onnxruntime/test/testdata/training_api/ort_format/prepare_artifacts.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +"""This file is used to generate test data for ort format model tests in + orttraining/orttraining/test/training_api/core/training_capi_tests.cc.""" + +import onnx +import torch +import torch.nn as nn + +from onnxruntime.training import artifacts + + +class SimpleNet(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super().__init__() + self.fc1 = nn.Linear(input_size, hidden_size) + self.relu = nn.ReLU() + self.fc2 = nn.Linear(hidden_size, output_size) + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + out = self.fc2(out) + return out + + +def model_export(pt_model, model_path, input_size): + # Generate random input data + input_data = torch.randn(32, input_size) + torch.onnx.export( + pt_model, + input_data, + model_path, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + + +def main(): + # Set the dimensions for input, hidden, and output layers + input_size = 10 + hidden_size = 20 + output_size = 5 + + # Create an instance of the neural network + pt_model = SimpleNet(input_size, hidden_size, output_size) + + train_model_path = "simplenet_training.onnx" + model_export(pt_model, train_model_path, input_size) + + onnx_model = onnx.load(train_model_path) + + requires_grad = ["fc2.weight", "fc2.bias"] + frozen_params = [param.name for param in onnx_model.graph.initializer if param.name not in requires_grad] + + # Generate the training artifacts. + artifacts.generate_artifacts( + onnx_model, + requires_grad=requires_grad, + frozen_params=frozen_params, + loss=artifacts.LossType.CrossEntropyLoss, + optimizer=artifacts.OptimType.AdamW, + ort_format=True, + ) + + +if __name__ == "__main__": + main() diff --git a/onnxruntime/test/testdata/training_api/ort_format/training_model.ort b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort new file mode 100644 index 0000000000000..94bda328a9f9f Binary files /dev/null and b/onnxruntime/test/testdata/training_api/ort_format/training_model.ort differ diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index f8e05455746e9..429ce6d9680ba 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2070,5 +2070,22 @@ IMPLEMENT_GRADIENT_BUILDER(GetLeakyReluGradient) { {GO(0), O(0)}, {GI(0)}, SrcNodeAttributes())}; } +IMPLEMENT_GRADIENT_BUILDER(GetConvTransposeGradient) { + std::vector outputs; + for (int i = 0; i < GetSrcNodeInputSize(); i++) { + if (IsGradientRequiredForSrcNodeInput(i)) { + outputs.push_back(GI(i)); + } else { + outputs.push_back(ArgDef("", nullptr)); + } + } + + return std::vector{ + NodeDef(OpDef{"ConvTransposeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1)}, + outputs, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index ca86777d36bfb..84880b88506e1 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -88,6 +88,7 @@ DECLARE_GRADIENT_BUILDER(GetLSTMGradient) DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) +DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index cc9a762ff8fa5..c84fc0d360114 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -120,6 +120,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("GRUTraining", GetGRUGradient); REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); + REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 60867accb8ed9..eb84865fd707c 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4908,6 +4908,21 @@ Return true if all elements are true and false otherwise. } } }); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ConvTransposeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y", "T") + .Input(1, "X", "Input tensor", "T") + .Input(2, "W", "Weight tensor", "T") + .Output(0, "dX", "Gradient of X", "T", OpSchema::Optional) + .Output(1, "dW", "Gradient of W", "T", OpSchema::Optional) + .Output(2, "dB", "Gradient of B", "T", OpSchema::Optional) + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors."); } } // namespace training diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index eac17f3d4d2e8..3f3aa396e6ca0 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -174,10 +174,11 @@ struct PyOptimizer { PyOptimizer(const std::string optimizer_model_uri, onnxruntime::training::api::CheckpointState* state, std::vector> providers, PySessionOptions* session_options) : optimizer_() { + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers("", std::nullopt, optimizer_model_uri); auto env = GetTrainingEnv().GetORTEnv(); // XXX: We hope that env will be around when optimizer needs it. optimizer_ = std::make_shared( - optimizer_model_uri, state, session_options->value, *env, providers, session_options->custom_op_domains_); + model_identifiers, state, session_options->value, *env, providers, session_options->custom_op_domains_); } std::shared_ptr optimizer_; @@ -941,9 +942,10 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn OrtDevice device, PySessionOptions* session_options) { std::vector> provider = GetExecutionProvidersForTrainingApis(device); auto env = GetTrainingEnv().GetORTEnv(); - return std::make_unique( - model_uri, state, session_options->value, *env, provider, eval_model_uri, - session_options->custom_op_domains_); + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers(model_uri, eval_model_uri, std::nullopt); + return std::make_unique(model_identifiers, + state, session_options->value, *env, provider, + session_options->custom_op_domains_); })) .def("train_step", [](onnxruntime::training::api::Module* model, diff --git a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py index 7b24bb400b162..1213342004d48 100644 --- a/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py +++ b/orttraining/orttraining/python/training/onnxblock/_training_graph_utils.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import copy +import logging import os from typing import List, Optional, Set, Tuple, Union @@ -70,13 +71,16 @@ def _move_initializers_to_inputs(model: onnx.ModelProto, initializer_names: Opti def _gradient_model_for( model: onnx.ModelProto, requires_grad: Set[str], - output_names: List[str], loss_name: str, options: Optional[SessionOptions] = None, ) -> onnx.ModelProto: """Builds the gradient graph on top of the given input forward only graph.""" - builder = GradientGraphBuilder(model.SerializeToString(), set(output_names), requires_grad, loss_name, options) + logging.debug( + "The loss output is %s. The gradient graph will be built starting from %s_grad.", loss_name, loss_name + ) + + builder = GradientGraphBuilder(model.SerializeToString(), {loss_name}, requires_grad, loss_name, options) builder.build() return onnx.load_from_string(builder.get_model()) @@ -123,7 +127,7 @@ def build_gradient_graph( optimized_model = onnx.load_from_string(get_optimized_model(model.SerializeToString(), requires_grad, options)) # Assumption is that the first graph output is the loss output - gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names, output_names[0], options) + gradient_model = _gradient_model_for(optimized_model, requires_grad, output_names[0], options) _reorder_outputs(gradient_model, output_names, requires_grad) diff --git a/orttraining/orttraining/python/training/onnxblock/onnxblock.py b/orttraining/orttraining/python/training/onnxblock/onnxblock.py index 9f90a5a0c30cd..a2922353ac70e 100644 --- a/orttraining/orttraining/python/training/onnxblock/onnxblock.py +++ b/orttraining/orttraining/python/training/onnxblock/onnxblock.py @@ -205,6 +205,8 @@ def __call__(self, *args, **kwargs): model, self._requires_grad, self._frozen_params, output, accessor._GLOBAL_CUSTOM_OP_LIBRARY ) + logging.debug("Adding gradient accumulation nodes for training block %s", self.__class__.__name__) + _training_graph_utils.build_gradient_accumulation_graph(self._training_model, self._requires_grad) accessor._GLOBAL_ACCESSOR.model.CopyFrom(self._training_model) diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 150f41eaeccb1..59cf05bb082fc 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -18,7 +18,7 @@ from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed if not is_ortmodule_available(): - raise RuntimeError("ORTModule is not supported on this platform.") + raise ImportError("ORTModule is not supported on this platform.") def _defined_from_envvar(name, default_value, warn=True): diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 39cc6bdd11443..178d5db627888 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3039,6 +3039,204 @@ TEST(GradientCheckerTest, LeakyReluGrad) { UnaryOpGradientTest("LeakyRelu", kOnnxDomain, 16, nullptr, &transformer); } +#ifdef USE_CUDA +void ConvTransposeGradientCheckerTest(std::vector>* execution_providers) { + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"ConvTranspose"}; + + float error_tolerance = 3e-1f; + + // 1D convolution + { + TensorShape x_shape({2, 2, 5}); + TensorShape w_shape({2, 2, 3}); + TensorShape b_shape({2}); + TensorShape y_shape({2, 2, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D strided convolution + { + TensorShape x_shape({2, 1, 7}); + TensorShape w_shape({1, 1, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3}), MakeAttribute("pads", std::vector{1, 1}), + MakeAttribute("strides", std::vector{2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (with padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 3}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 1D pointwise convolution (no padding) + { + TensorShape x_shape({2, 1, 5}); + TensorShape w_shape({1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1}), MakeAttribute("pads", std::vector{0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D convolution + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (with padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 1, 1}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D pointwise convolution (no padding) + { + TensorShape x_shape({1, 1, 3, 3}); + TensorShape w_shape({1, 1, 1, 1}); + TensorShape b_shape({1}); + TensorShape y_shape({1, 1, 3, 3}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{1, 1}), + MakeAttribute("pads", std::vector{0, 0, 0, 0})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D strided convolution + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9}); + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError( + op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), MakeAttribute("strides", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (no padding) + { + TensorShape x_shape({2, 1, 5, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{0, 0, 0, 0}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 2D dilated convolution (with padding) + { + TensorShape x_shape({2, 1, 7, 5}); + TensorShape w_shape({1, 1, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 9, 7}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1}), + MakeAttribute("dilations", std::vector{2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D convolution + { + TensorShape x_shape({2, 1, 5, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 5, 5, 5}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } + + // 3D strided convolution + { + TensorShape x_shape({2, 1, 7, 5, 5}); + TensorShape w_shape({1, 1, 3, 3, 3}); + TensorShape b_shape({1}); + TensorShape y_shape({2, 1, 13, 9, 9}); + ASSERT_STATUS_OK( + gradient_checker.ComputeGradientError(op_def, {x_shape, w_shape, b_shape}, {y_shape}, &max_error, + {MakeAttribute("kernel_shape", std::vector{3, 3, 3}), + MakeAttribute("pads", std::vector{1, 1, 1, 1, 1, 1}), + MakeAttribute("strides", std::vector{2, 2, 2})}, + false, false, execution_providers)); + EXPECT_IS_TINIER_THAN(max_error, error_tolerance); + } +} + +TEST(GradientCheckerTest, ConvTransposeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + ConvTransposeGradientCheckerTest(&execution_providers); +} +#endif // USE_CUDA + } // namespace test } // namespace onnxruntime diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index b62e959556184..dc4c8ba75336b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -4002,6 +4002,7 @@ def forward(self, bool_argument, input1): ], ) def test_unused_parameters(model, none_pt_params): + torch.manual_seed(2333) device = "cuda" N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 # noqa: F841, N806 @@ -6096,7 +6097,6 @@ def run_step(model, input, positions): found_missing_inference_log = False for record in caplog.records: msg = record.getMessage() - print(msg) if "The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing" in msg: found_missing_inference_log = True break @@ -6205,3 +6205,167 @@ def run_step(model, x): _test_helpers.assert_values_are_close(pt_prediction, ort_prediction) _test_helpers.assert_values_are_close(pt_loss, ort_loss) _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad) + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("use_fp16", [False, True]) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient(use_fp16, conv_algo_search): + class ChainedTransposedConv(nn.Module): + def __init__(self): + super().__init__() + + # Transposed Convolution 1D + self.conv1d_transpose = nn.ConvTranspose1d( + in_channels=4, out_channels=2, kernel_size=3, stride=2, padding=1 + ) + self.relu1 = nn.ReLU() + + # Transposed Convolution 2D + self.conv2d_transpose = nn.ConvTranspose2d( + in_channels=2, out_channels=3, kernel_size=3, stride=2, padding=1 + ) + self.relu2 = nn.ReLU() + + # Transposed Convolution 3D + self.conv3d_transpose = nn.ConvTranspose3d( + in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1 + ) + self.relu3 = nn.ReLU() + + def forward(self, x): + out1d = self.relu1(self.conv1d_transpose(x)) + out2d = self.relu2(self.conv2d_transpose(out1d.unsqueeze(2))) + out3d = self.relu3(self.conv3d_transpose(out2d.unsqueeze(2))) + return out3d.squeeze(2) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + with amp.autocast(use_fp16): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv1d_transpose.weight.grad, + model.conv1d_transpose.bias.grad, + model.conv2d_transpose.weight.grad, + model.conv2d_transpose.bias.grad, + model.conv3d_transpose.weight.grad, + model.conv3d_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ChainedTransposedConv().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 4, 8, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + if use_fp16: + assert torch.allclose(pt_grad, ort_grad, atol=1e-3, rtol=1e-3) + else: + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_groups(conv_algo_search): + class TransposedConv3DWithGroups(nn.Module): + def __init__(self): + super().__init__() + # in_channels, out_channels, kernel_size, stride, padding + self.conv_transpose = nn.ConvTranspose3d( + in_channels=6, out_channels=4, kernel_size=3, stride=2, padding=1, groups=2 + ) + + def forward(self, x): + return self.conv_transpose(x) + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = TransposedConv3DWithGroups().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + pt_x = torch.randn(1, 6, 8, 16, 16, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] + + +@pytest.mark.skipif( + os.getenv("ORTMODULE_ROCM_TEST", "0") == "1", reason="Skip for ROCm because the kernel is not implemented for ROCm" +) +@pytest.mark.parametrize("conv_algo_search", [None, "EXHAUSTIVE", "HEURISTIC"]) +def test_conv_transpose_gradient_with_strides_padding_and_dilation(conv_algo_search): + class ConvTransposeComplexModel(nn.Module): + def __init__(self): + super().__init__() + self.conv_transpose = nn.ConvTranspose3d( + 16, 33, (3, 5, 2), stride=(2, 1, 1), padding=(0, 4, 2), dilation=(1, 2, 1) + ) + self.param = nn.Parameter(torch.randn(20, 33, 21, 50, 97)) + + def forward(self, x): + return self.conv_transpose(x) * self.param + + if conv_algo_search is not None: + os.environ["ORTMODULE_CONV_ALGO_SEARCH"] = conv_algo_search + + def run_step(model, x): + loss = model(x).sum() + loss.backward() + + return ( + x.grad, + model.conv_transpose.weight.grad, + model.conv_transpose.bias.grad, + ) + + device = "cuda" + pt_model = ConvTransposeComplexModel().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)).to(device) + + pt_x = torch.randn(20, 16, 10, 50, 100, requires_grad=True, device=device) + ort_x = copy.deepcopy(pt_x) + + pt_grads = run_step(pt_model, pt_x) + ort_grads = run_step(ort_model, ort_x) + + for pt_grad, ort_grad in zip(pt_grads, ort_grads): + assert torch.allclose(pt_grad, ort_grad, atol=1e-2, rtol=1e-2) + + if conv_algo_search is not None: + del os.environ["ORTMODULE_CONV_ALGO_SEARCH"] diff --git a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc index 4fa3844717ef9..1369c9c69865a 100644 --- a/orttraining/orttraining/test/training_api/core/checkpoint_test.cc +++ b/orttraining/orttraining/test/training_api/core/checkpoint_test.cc @@ -331,9 +331,12 @@ TEST(CheckpointApiTest, SaveOptimizerStateAsCheckpoint_ThenLoad) { #if defined(USE_CUDA) providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif - auto model = std::make_unique(model_uri, &state, session_option, + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + auto model = std::make_unique(model_identifier, &state, session_option, *env, providers); - auto optimizer = std::make_unique(optim_uri, &state, session_option, + auto optimizer = std::make_unique(model_identifier, &state, session_option, *env, providers); // Remove the temporary directory if it already exists. diff --git a/orttraining/orttraining/test/training_api/core/training_api_tests.cc b/orttraining/orttraining/test/training_api/core/training_api_tests.cc index ec0c7a1968ba4..2170f7957e6a6 100644 --- a/orttraining/orttraining/test/training_api/core/training_api_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_api_tests.cc @@ -76,9 +76,12 @@ void TestModuleExport(const std::vector>& pr std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(training_model_uri), + std::optional(onnxruntime::ToUTF8String(eval_model_uri)), + std::nullopt); auto model = std::make_unique( - ToUTF8String(training_model_uri), &state, onnxruntime::SessionOptions(), - *env, providers, ToUTF8String(eval_model_uri)); + model_identifier, &state, onnxruntime::SessionOptions(), + *env, providers); auto test_dir = ORT_TSTR("export_model_for_inferencing_test_dir"); if (Env::Default().FolderExists(test_dir)) { @@ -141,7 +144,9 @@ TEST(TrainingApiTest, ModuleParametersSize) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifiers = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, std::nullopt); + auto model = std::make_unique(model_identifiers, &state, session_option, *env, std::vector>()); size_t params_size = 0; @@ -164,7 +169,10 @@ TEST(TrainingApiTest, ModuleCopyBufferToParameters) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); int64_t params_size = static_cast(model->GetParametersSize()); @@ -202,7 +210,10 @@ TEST(TrainingApiTest, ModuleTrainStep) { onnxruntime::SessionOptions session_option; std::unique_ptr env; ASSERT_STATUS_OK(Environment::Create(nullptr, env)); - auto model = std::make_unique(ToUTF8String(model_uri), + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::nullopt); + auto model = std::make_unique(model_identifier, &state, session_option, *env, std::vector>()); ASSERT_EQ(model->GetTrainingModelOutputCount(), 1); @@ -274,8 +285,12 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // Load state dict from faked optimizer checkpoint state. @@ -285,7 +300,7 @@ TEST(TrainingApiTest, OptimizerCreatedWithOptimizerCheckpointState) { {"momentum0", "momentum1"}, external_optimizer_checkpoint_state)); std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &new_state, session_option, *env, providers); + model_identifier, &new_state, session_option, *env, providers); ASSERT_TRUE(optim.get() != nullptr); } @@ -320,8 +335,12 @@ void TestLRSchduler(const std::basic_string& test_file_name, ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); + std::shared_ptr model = std::make_shared( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; @@ -351,7 +370,7 @@ void TestLRSchduler(const std::basic_string& test_file_name, } std::shared_ptr optim = std::make_shared( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); // KNOWN ISSUE: LinearLRScheduler by default use optim's states to calculate the first step's learning rate. @@ -445,11 +464,15 @@ TEST(TrainingApiTest, OptimStep) { providers.push_back(onnxruntime::test::DefaultCudaExecutionProvider()); #endif ASSERT_STATUS_OK(Environment::Create(nullptr, env)); + + auto model_identifier = ModelIdentifiers(onnxruntime::ToUTF8String(model_uri), + std::nullopt, + std::optional(onnxruntime::ToUTF8String(optim_uri))); auto model = std::make_unique( - ToUTF8String(model_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); auto optim = std::make_unique( - ToUTF8String(optim_uri), &state, session_option, + model_identifier, &state, session_option, *env, providers); OrtValue input, target; diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index e864f3b8632de..d734be8e3474b 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "gtest/gtest.h" +#include "gmock/gmock.h" #include "onnxruntime_c_api.h" #include "onnxruntime_training_c_api.h" @@ -16,6 +17,7 @@ namespace onnxruntime::training::test { #define MODEL_FOLDER ORT_TSTR("testdata/training_api/") +#define ORT_FORMAT_MODEL_FOLDER ORT_TSTR("testdata/training_api/ort_format/") TEST(TrainingCApiTest, SaveCheckpoint) { auto model_uri = MODEL_FOLDER "training_model.onnx"; @@ -220,4 +222,100 @@ TEST(TrainingCApiTest, RegisterCustomOps) { ASSERT_TRUE(loss.front().IsTensor()); } +TEST(TrainingCApiTest, LoadModelsAndCreateSession) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + model_path); +} + +TEST(TrainingCApiTest, LoadModelsAndCreateSession_ORTFormat) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_train_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_path, + eval_train_model_path, + optimizer_model_path); +} + +TEST(TrainingCApiTest, LoadONNXModelsFromBuffer) { + auto model_path = MODEL_FOLDER "training_model.onnx"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(model_path, model_data_len)); + std::vector train_model_data(model_data_len); + std::ifstream bytes_stream(model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); +} + +TEST(TrainingCApiTest, LoadORTFormatModelsFromBuffer) { + auto train_model_path = ORT_FORMAT_MODEL_FOLDER "training_model.ort"; + auto eval_model_path = ORT_FORMAT_MODEL_FOLDER "eval_model.ort"; + auto optimizer_model_path = ORT_FORMAT_MODEL_FOLDER "optimizer_model.ort"; + size_t model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(train_model_path, model_data_len)); + std::vector train_model_data(model_data_len); + { + std::ifstream bytes_stream(train_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(train_model_data.data()), model_data_len); + ASSERT_TRUE(train_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(eval_model_path, model_data_len)); + std::vector eval_model_data(model_data_len); + { + std::ifstream bytes_stream(eval_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(eval_model_data.data()), model_data_len); + ASSERT_TRUE(eval_model_data.size() == model_data_len); + } + + model_data_len = 0; + ASSERT_STATUS_OK(Env::Default().GetFileLength(optimizer_model_path, model_data_len)); + std::vector optimizer_model_data(model_data_len); + { + std::ifstream bytes_stream(optimizer_model_path, std::ifstream::in | std::ifstream::binary); + bytes_stream.read(reinterpret_cast(optimizer_model_data.data()), model_data_len); + ASSERT_TRUE(optimizer_model_data.size() == model_data_len); + } + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(ORT_FORMAT_MODEL_FOLDER "checkpoint"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), + checkpoint_state, train_model_data, + eval_model_data, optimizer_model_data); +} + +TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + + try { + std::vector train_model_data; + Ort::TrainingSession training_session = Ort::TrainingSession(env, + Ort::SessionOptions(), + checkpoint_state, + train_model_data); + } catch (const std::exception& ex) { + ASSERT_THAT(ex.what(), + testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); + } +} } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc new file mode 100644 index 0000000000000..18c5ff94373dc --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/conv_transpose_grad_test.cc @@ -0,0 +1,360 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime::contrib::test { + +using namespace onnxruntime::test; + +#if USE_CUDA +namespace { + +struct ConvTransposeGradOpAttributes { + std::vector dilations; + int64_t group; + std::vector kernel_shape; + std::vector pads; + std::vector strides; +}; + +void TestConvTransposeGradOp(const ConvTransposeGradOpAttributes& attributes, + const std::vector>& inputs, + const std::vector>& input_shapes, + const std::vector>& outputs, + const std::vector>& output_shapes, + bool is_half = false) { + OpTester test("ConvTransposeGrad", 1, kMSDomain); + test.AddAttribute("group", attributes.group); + test.AddAttribute("kernel_shape", attributes.kernel_shape); + test.AddAttribute("pads", attributes.pads); + + if (!attributes.dilations.empty()) { + test.AddAttribute("dilations", attributes.dilations); + } + + if (!attributes.strides.empty()) { + test.AddAttribute("strides", attributes.strides); + } + + if (is_half) { + std::vector dY_half(inputs[0].size()); + ConvertFloatToMLFloat16(inputs[0].data(), dY_half.data(), static_cast(inputs[0].size())); + test.AddInput("dY", input_shapes[0], dY_half); + + std::vector X_half(inputs[1].size()); + ConvertFloatToMLFloat16(inputs[1].data(), X_half.data(), static_cast(inputs[1].size())); + test.AddInput("X", input_shapes[1], X_half); + + std::vector W_half(inputs[2].size()); + ConvertFloatToMLFloat16(inputs[2].data(), W_half.data(), static_cast(inputs[2].size())); + test.AddInput("W", input_shapes[2], W_half); + + std::vector dX_half(outputs[0].size()); + ConvertFloatToMLFloat16(outputs[0].data(), dX_half.data(), static_cast(outputs[0].size())); + test.AddOutput("dX", output_shapes[0], dX_half); + + std::vector dW_half(outputs[1].size()); + ConvertFloatToMLFloat16(outputs[1].data(), dW_half.data(), static_cast(outputs[1].size())); + test.AddOutput("dW", output_shapes[1], dW_half); + + if (outputs.size() >= 3) { + std::vector dB_half(outputs[2].size()); + ConvertFloatToMLFloat16(outputs[2].data(), dB_half.data(), static_cast(outputs[2].size())); + test.AddOutput("dB", output_shapes[2], dB_half); + } + } else { + test.AddInput("dY", input_shapes[0], inputs[0]); + test.AddInput("X", input_shapes[1], inputs[1]); + test.AddInput("W", input_shapes[2], inputs[2]); + + test.AddOutput("dX", output_shapes[0], outputs[0]); + test.AddOutput("dW", output_shapes[1], outputs[1]); + + if (outputs.size() >= 3) { + test.AddOutput("dB", output_shapes[2], outputs[2]); + } + } + + test.Run(); +} + +} // namespace + +TEST(ConvTransposeGradTest, ConvTranspose1DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{0, 0}, // pads + std::vector{1}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {0.1868f, -0.1679f, 1.2677f, 2.1288f, -0.0331f, + 1.0454f, 0.7722f, 0.2963f, -0.8684f, -0.0547f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {0.0847f, -0.0066f, + 0.1212f, 0.2317f, + -0.4975f, 0.2762f, + -0.2644f, 0.3210f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.4309f, 0.4309f, 0.4309f, 0.4309f, 0.4309f, + -0.1647f, -0.1647f, -0.1647f, -0.1647f, -0.1647f}; + std::vector dX_shape = X_shape; + std::vector dW = {3.3823f, 3.3823f, + 3.3823f, 3.3823f, + 1.1908f, 1.1908f, + 1.1908f, 1.1908f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1DStrideAndPadding) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1}, // dilations + 1, // group + std::vector{2}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(12, 1.0f); + std::vector dY_shape = {1, 2, 6}; + std::vector X = {-0.0254f, -1.4303f, -0.1568f, 1.2318f, -0.8365f, + 2.0836f, -1.0181f, -0.7539f, 0.4484f, -0.5799f}; + std::vector X_shape = {1, 2, 5}; + std::vector W = {-0.1438f, 0.2386f, + -0.3085f, 0.1149f, + -0.1653f, -0.0707f, + -0.1479f, -0.0918f}; + std::vector W_shape = {2, 2, 2}; + std::vector dX = {0.0000f, -0.0988f, -0.0988f, -0.0988f, 0.0000f, + 0.0000f, -0.4757f, -0.4757f, -0.4757f, 0.0000f}; + std::vector dX_shape = X_shape; + std::vector dW = {-0.3553f, -0.3553f, + -0.3553f, -0.3553f, + -1.3236f, -1.3236f, + -1.3236f, -1.3236f}; + std::vector dW_shape = W_shape; + std::vector dB = {6.f, 6.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose1D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2}, // dilations + 2, // group + std::vector{3}, // kernel_shape + std::vector{2, 2}, // pads + std::vector{2}, // strides + }; + + std::vector dY(38, 1.0f); + std::vector dY_shape = {1, 2, 19}; + std::vector X = {0.2816f, 1.4660f, 0.1002f, -0.2460f, -0.1027f, 0.1228f, -0.8516f, -1.0246f, -0.6576f, -1.0280f, + 0.1093f, 0.1447f, 1.1279f, 0.1085f, -0.3438f, -0.6224f, -0.0902f, 2.2791f, -2.1910f, 1.9736f}; + std::vector X_shape = {1, 2, 10}; + std::vector W = {-0.1050f, -0.0622f, -0.3632f, + -0.3861f, -0.0134f, -0.0277f}; + std::vector W_shape = {2, 1, 3}; + std::vector dX = {-0.4254f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.5304f, -0.1672f, + -0.0411f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.4272f, -0.3995f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.2215f, -1.9400f, -0.9120f, + 2.3863f, 2.4956f, 0.5220f}; + std::vector dW_shape = W_shape; + std::vector dB = {19.f, 19.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2DDefaultAttributes) { + ConvTransposeGradOpAttributes attrs = { + std::vector{1, 1}, // dilations + 1, // group + std::vector{3, 3}, // kernel_shape + std::vector{0, 0, 0, 0}, // pads + std::vector{1, 1}, // strides + }; + + std::vector dY(98, 1.0f); + std::vector dY_shape = {1, 2, 7, 7}; + std::vector X = {1.1371f, -0.1498f, -1.7541f, -0.7585f, 1.6009f, -0.7496f, 0.1535f, -0.2533f, -1.0811f, 0.9760f, + -0.2528f, 0.1820f, -1.7450f, 0.1632f, -0.3469f, 1.1150f, -2.6888f, -0.1632f, -0.3269f, 0.6904f, + 1.3036f, 0.7883f, 0.4459f, 0.1223f, 0.1576f, -0.8187f, 0.2281f, 1.5320f, 1.2643f, -0.5163f, + 1.0677f, -0.2141f, 1.2992f, -2.1865f, -0.6346f, 0.8938f, 0.8346f, -2.7397f, 0.9223f, 0.8166f, + 1.1736f, -1.3644f, 0.0316f, -1.2904f, 0.7062f, 0.2470f, 0.4559f, 0.8493f, 1.0519f, 0.9915f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {0.0761f, 0.0270f, -0.1677f, 0.1803f, -0.0824f, -0.0285f, + 0.2098f, -0.0569f, -0.1514f, 0.0338f, -0.1962f, -0.2169f, + 0.0432f, -0.1977f, -0.0814f, -0.1866f, -0.1574f, -0.0198f, + 0.0097f, 0.0019f, -0.1204f, 0.2018f, -0.1750f, -0.0549f, + -0.0687f, -0.1269f, 0.1913f, 0.1331f, -0.0632f, 0.0821f, + 0.0127f, 0.1761f, -0.0883f, -0.1370f, 0.1472f, 0.0690f}; + std::vector W_shape = {2, 2, 3, 3}; + std::vector dX = {-0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, + -0.9725f, -0.9725f, -0.9725f, -0.9725f, -0.9725f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, + 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f, 0.1905f}; + std::vector dX_shape = X_shape; + std::vector dW = {-1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, -1.4343f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, + 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f, 4.6009f}; + std::vector dW_shape = W_shape; + std::vector dB = {49.f, 49.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose2D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2}, // dilations + 2, // group + std::vector{3, 3}, // kernel_shape + std::vector{2, 2, 2, 2}, // pads + std::vector{2, 2}, // strides + }; + + std::vector dY(162U, 1.0f); + std::vector dY_shape = {1, 2, 9, 9}; + std::vector X = {-1.0158f, 0.1709f, -0.1660f, 0.3881f, 0.4017f, 1.5497f, 1.1205f, 0.2553f, -0.4359f, -0.0467f, + 1.1374f, -0.0713f, 0.2248f, 0.8915f, -0.7239f, 0.1679f, -1.5604f, -0.8521f, 0.8966f, 3.3743f, + -0.5516f, 0.2516f, -0.4091f, -0.9868f, 0.3008f, 1.1066f, -0.7039f, -1.5273f, -0.3666f, 0.9392f, + 0.1264f, -1.6604f, -1.4810f, 0.6654f, -0.2007f, -1.0660f, -0.5420f, -0.7030f, 0.0411f, 2.1082f, + -0.7995f, 0.2422f, 1.2848f, -0.1747f, 1.7935f, -0.1123f, -0.6668f, -2.2383f, 1.5419f, -2.7614f}; + std::vector X_shape = {1, 2, 5, 5}; + std::vector W = {-0.2057f, -0.0411f, 0.0277f, 0.2221f, 0.1901f, 0.1435f, + -0.2249f, 0.3299f, -0.2203f, -0.1013f, -0.3326f, 0.1005f, + -0.0536f, 0.3067f, 0.3297f, 0.2728f, 0.1649f, -0.2548f}; + std::vector W_shape = {2, 1, 3, 3}; + std::vector dX = {0.4431f, 0.4403f, 0.4403f, 0.4403f, 0.5171f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, 0.4297f, 0.2212f, 0.2212f, 0.2212f, 0.2704f, + 0.3202f, 0.3366f, 0.3366f, 0.3366f, 0.1654f, 0.5465f, 0.7658f, 0.7658f, 0.7658f, 0.6908f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, + 0.3144f, 0.4323f, 0.4323f, 0.4323f, 0.2569f, 0.4043f, 0.2494f, 0.2494f, 0.2494f, -0.1808f}; + std::vector dX_shape = X_shape; + std::vector dW = {2.2293f, 4.5327f, 1.6281f, 3.0240f, 4.3115f, 1.0052f, + 3.8675f, 5.7067f, 2.7011f, -2.7512f, -4.6026f, -5.5423f, + -4.4098f, -5.1546f, -7.0335f, -0.2852f, -0.9177f, -5.5580f}; + std::vector dW_shape = W_shape; + std::vector dB = {81.f, 81.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} + +TEST(ConvTransposeGradTest, ConvTranspose3D) { + ConvTransposeGradOpAttributes attrs = { + std::vector{2, 2, 2}, // dilations + 2, // group + std::vector{2, 2, 2}, // kernel_shape + std::vector{2, 2, 2, 2, 2, 2}, // pads + std::vector{2, 2, 2}, // strides + }; + + std::vector dY(250U, 1.0f); + std::vector dY_shape = {1, 2, 5, 5, 5}; + std::vector X = {-0.2396f, 0.4280f, -1.3505f, -0.4366f, -1.3296f, 0.3531f, 0.0645f, -1.5480f, + -1.7464f, -0.9160f, 1.5065f, -0.0788f, 0.0487f, 2.4641f, 0.3855f, 2.0499f, + 0.7068f, -0.8076f, -0.4442f, 0.1003f, -0.5056f, -0.1430f, -0.3744f, -0.2637f, + -1.1012f, 1.0213f, 0.0503f, 0.0147f, -0.3664f, 0.8834f, -1.1478f, -0.8221f, + -0.5649f, -0.4224f, -0.6779f, -0.9363f, 1.1972f, 0.2094f, 0.5676f, -0.2718f, + -0.1678f, -0.4178f, -0.4672f, 0.2777f, -0.7953f, -0.5603f, -2.8694f, 1.5743f, + -0.5057f, -0.2529f, 0.5894f, -0.3980f, -0.6719f, -0.3425f, 0.0821f, 0.8672f, + 0.7218f, 1.5519f, 1.6513f, -1.1956f, 0.8471f, 0.4295f, -1.3917f, -1.2202f, + 0.1054f, -2.2191f, -0.9546f, 1.1750f, -2.3637f, 1.6297f, -0.5796f, 0.3850f, + 0.9287f, -0.3492f, -0.7284f, 0.2987f, -0.7534f, 0.7747f, -1.3198f, -0.3633f, + 1.8635f, -0.3187f, 0.9032f, -0.6083f, -0.4236f, -0.1929f, -1.1715f, -0.5591f, + -1.8290f, -1.1503f, 0.1430f, 0.6048f, -0.3148f, 1.0638f, -0.2946f, -0.4990f, + -1.4443f, -0.7757f, -1.5374f, -0.4567f, -0.2998f, 0.0521f, 1.6293f, -0.6720f, + -0.0102f, -0.6598f, 0.5005f, 0.4203f, 1.3911f, 1.5988f, 0.3991f, 1.4931f, + 0.9741f, 0.3557f, 0.1088f, -1.1806f, 1.1115f, -1.3283f, 1.7235f, 0.4177f, + 0.7992f, -1.7248f, -0.5339f, -0.3153f, 0.1379f, 0.7493f, 0.3028f, -0.9473f}; + std::vector X_shape = {1, 2, 4, 4, 4}; + std::vector W = {-0.1093f, -0.0511f, 0.1132f, 0.3369f, -0.3531f, -0.1766f, 0.0628f, 0.2118f, + 0.3068f, 0.3217f, -0.2903f, -0.1633f, -0.3261f, -0.0990f, 0.2497f, -0.1553f}; + std::vector W_shape = {2, 1, 2, 2, 2}; + std::vector dX = {0.2118f, 0.2746f, 0.2746f, 0.0628f, 0.0352f, -0.2550f, -0.2550f, -0.2902f, + 0.0352f, -0.2550f, -0.2550f, -0.2902f, -0.1766f, -0.5297f, -0.5297f, -0.3531f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.5487f, 0.7247f, 0.7247f, 0.1760f, 0.3210f, 0.0346f, 0.0346f, -0.2864f, + 0.3210f, 0.0346f, 0.0346f, -0.2864f, -0.2277f, -0.6901f, -0.6901f, -0.4624f, + 0.3369f, 0.4501f, 0.4501f, 0.1132f, 0.2858f, 0.2897f, 0.2897f, 0.0038f, + 0.2858f, 0.2897f, 0.2897f, 0.0038f, -0.0511f, -0.1604f, -0.1604f, -0.1093f, + -0.1553f, 0.0944f, 0.0944f, 0.2497f, -0.2542f, -0.3307f, -0.3307f, -0.0765f, + -0.2542f, -0.3307f, -0.3307f, -0.0765f, -0.0990f, -0.4251f, -0.4251f, -0.3261f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.3185f, -0.3592f, -0.3592f, -0.0407f, -0.0958f, -0.1557f, -0.1557f, -0.0600f, + -0.0958f, -0.1557f, -0.1557f, -0.0600f, 0.2227f, 0.2035f, 0.2035f, -0.0193f, + -0.1633f, -0.4536f, -0.4536f, -0.2903f, 0.1584f, 0.1749f, 0.1749f, 0.0165f, + 0.1584f, 0.1749f, 0.1749f, 0.0165f, 0.3217f, 0.6285f, 0.6285f, 0.3068f}; + std::vector dX_shape = X_shape; + std::vector dW = {-2.3068f, -2.1096f, -0.4322f, 0.4820f, 1.5420f, -4.1569f, -4.9628f, -5.5716f, + 1.0492f, 1.6683f, -6.3262f, -3.2359f, 2.4532f, -2.3299f, -5.1917f, -9.2525f}; + std::vector dW_shape = W_shape; + std::vector dB = {125.f, 125.f}; + std::vector dB_shape = {2}; + + for (const bool is_half : {false, true}) + TestConvTransposeGradOp( + attrs, // attributes + {dY, X, W}, // inputs + {dY_shape, X_shape, W_shape}, // input shapes + {dX, dW, dB}, // outputs + {dX_shape, dW_shape, dB_shape}, // output shapes + is_half); +} +#endif // USE_CUDA + +} // namespace onnxruntime::contrib::test diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index b3042c449a50b..0af737074964d 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -190,7 +190,29 @@ struct OrtTrainingApi { ORT_API2_STATUS(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + + /** \brief Create a training session that can be used to begin or resume training. + * This api provides a way to load all the training artifacts from buffers instead of files. + * + * \param[in] env Environment to be used for the training session. + * \param[in] options Session options that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing the model data to be used to perform training + * \param[in] train_data_length Length of the buffer containing train_model_data + * \param[in] eval_model_data Buffer containing the model data to be used to perform evaluation + * \param[in] eval_data_length Length of the buffer containing eval_model_data + * \param[in] optim_model_data Buffer containing the model data to be used to perform weight update + * \param[in] optim_data_length Length of the buffer containing optim_model_data + * \param[out] out Created training session. + * + */ + ORT_API2_STATUS(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); /// @} diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 5bfdfcc74e817..0edef20ba6da8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -176,6 +176,20 @@ class TrainingSession : public detail::Base { const std::optional>& eval_model_path = std::nullopt, const std::optional>& optimizer_model_path = std::nullopt); + /** \brief Create a training session that can be used to begin or resume training. + * This constructor allows the users to load the models from buffers instead of files. + * + * \param[in] env Env to be used for the training session. + * \param[in] session_options SessionOptions that the user can customize for this training session. + * \param[in] checkpoint_state Training states that the training session uses as a starting point for training. + * \param[in] train_model_data Buffer containing training model data. + * \param[in] eval_model_data Buffer containing evaluation model data. + * \param[in] optim_model_data Buffer containing optimizer model (used for performing weight/parameter update). + * + */ + TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state, + const std::vector& train_model_data, const std::vector& eval_model_data = {}, + const std::vector& optim_model_data = {}); /// @} /// \name Implementing The Training Loop diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index 393e5b01f7f85..066147708863f 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -24,6 +24,23 @@ inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& se ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); } +inline TrainingSession::TrainingSession(const Env& env, const SessionOptions& session_options, + CheckpointState& checkpoint_state, + const std::vector& train_model_data, + const std::vector& eval_model_data, + const std::vector& optim_model_data) { + ThrowOnError(GetTrainingApi().CreateTrainingSessionFromBuffer( + env, session_options, checkpoint_state, + train_model_data.data(), train_model_data.size(), + eval_model_data.data(), eval_model_data.size(), + optim_model_data.data(), optim_model_data.size(), + &p_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetTrainingModelOutputCount(p_, &training_model_output_count_)); + + ThrowOnError(GetTrainingApi().TrainingSessionGetEvalModelOutputCount(p_, &eval_model_output_count_)); +} + inline std::vector TrainingSession::TrainStep(const std::vector& input_values) { std::vector output_values; output_values.reserve(training_model_output_count_); diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index 29300bbb7e8ec..d1775e358163c 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -12,7 +12,6 @@ #include "core/graph/graph_utils.h" #include "orttraining/training_api/checkpoint.h" -#include "orttraining/training_api/utils.h" using namespace onnxruntime; @@ -150,12 +149,11 @@ Status Parameter::ResetGrad() { return Status::OK(); } -Module::Module(const std::string& train_model_path_or_bytes, +Module::Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes, [[maybe_unused]] gsl::span op_domains) : state_{state} { // Enforce weight prepacking is disabled @@ -176,7 +174,12 @@ Module::Module(const std::string& train_model_path_or_bytes, } #endif - ORT_THROW_IF_ERROR(train_sess_->Load(train_model_path_or_bytes)); + // Load the training model + ORT_THROW_IF_ERROR(std::holds_alternative(model_identifiers.train_model) + ? train_sess_->Load(std::get(model_identifiers.train_model)) + : train_sess_->Load(std::get>(model_identifiers.train_model).data(), + static_cast(std::get>(model_identifiers.train_model).size()))); + for (const auto& provider : providers) { ORT_THROW_IF_ERROR(train_sess_->RegisterExecutionProvider(provider)); } @@ -239,7 +242,6 @@ Module::Module(const std::string& train_model_path_or_bytes, // Copy ortvalue buffer from CPU to target_device for this "param_name" (based on graph partitioning) // Only copies data if the target device is not the same as the current device the buffer is placed on - OrtValue& param_data = params_iter->second->Data(); ORT_ENFORCE(param_data.IsTensor()); const Tensor& param_data_tensor = param_data.Get(); @@ -278,47 +280,57 @@ Module::Module(const std::string& train_model_path_or_bytes, } } - if (eval_model_path_or_bytes.has_value()) { + if (model_identifiers.IsEvalModelAvailable()) { eval_sess_ = std::make_unique(session_options, env); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) if (!op_domains.empty()) { ORT_THROW_IF_ERROR(eval_sess_->AddCustomOpDomains(op_domains)); } #endif - - ORT_THROW_IF_ERROR(eval_sess_->Load(eval_model_path_or_bytes.value())); - for (const auto& provider : providers) { - ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); - } - ORT_THROW_IF_ERROR(eval_sess_->Initialize()); - utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); - - // Eval model validation - // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval - // graphs, and all the weights present in both graphs match. - // TODO: Add the checks instead of making assumptions?? - InlinedVector eval_user_input_names, eval_param_input_names; - for (const auto& input_name : eval_input_names_) { - if (state_->module_checkpoint_state.named_parameters.find(input_name) != - state_->module_checkpoint_state.named_parameters.end()) { - // it is a parameter - eval_param_input_names.emplace_back(input_name); - continue; - } else { - // It is user input. We handle user inputs separately in the eval - // because the eval graph might have different user inputs. - // Eg if loss is not a part of the eval graph, it won't have - // certain inputs like targets - eval_user_input_names.emplace_back(input_name); - } + if (std::holds_alternative>(model_identifiers.eval_model)) { + ORT_THROW_IF_ERROR(eval_sess_->Load(std::get>(model_identifiers.eval_model).value())); + } else { + auto model_data = std::get>(model_identifiers.eval_model); + ORT_THROW_IF_ERROR(eval_sess_->Load(model_data.data(), static_cast(model_data.size()))); } - eval_input_names_ = eval_user_input_names; - eval_user_input_count_ = eval_user_input_names.size(); - eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + } else { + return; + } - // Keep a copy of the eval model path to be able to later export the model for inferencing. - // The inference model will be reconstructed from the eval model. - eval_model_path_ = eval_model_path_or_bytes.value(); + for (const auto& provider : providers) { + ORT_THROW_IF_ERROR(eval_sess_->RegisterExecutionProvider(provider)); + } + ORT_THROW_IF_ERROR(eval_sess_->Initialize()); + utils::GetGraphInputOutputNames(eval_sess_, eval_input_names_, eval_output_names_); + + // Eval model validation + // We are making certain assumptions: Like the order in which parameters occur will be same between train and eval + // graphs, and all the weights present in both graphs match. + // TODO(askhade): Add the checks instead of making assumptions?? + InlinedVector eval_user_input_names, eval_param_input_names; + for (const auto& input_name : eval_input_names_) { + if (state_->module_checkpoint_state.named_parameters.find(input_name) != + state_->module_checkpoint_state.named_parameters.end()) { + // it is a parameter + eval_param_input_names.emplace_back(input_name); + continue; + } else { + // It is user input. We handle user inputs separately in the eval + // because the eval graph might have different user inputs. + // Eg if loss is not a part of the eval graph, it won't have + // certain inputs like targets + eval_user_input_names.emplace_back(input_name); + } + } + eval_input_names_ = eval_user_input_names; + eval_user_input_count_ = eval_user_input_names.size(); + eval_input_names_.insert(eval_input_names_.end(), eval_param_input_names.begin(), eval_param_input_names.end()); + + // Keep a copy of the eval model path to be able to later export the model for inferencing. + // The inference model will be reconstructed from the eval model. + // TODO(askhade): Find a fix to export model for inference when the eval model is loaded from a buffer. + if (std::holds_alternative>(model_identifiers.eval_model)) { + eval_model_path_ = std::get>(model_identifiers.eval_model); } } @@ -486,14 +498,14 @@ Status Module::EvalStep(const std::vector& inputs, std::vector graph_output_names) const { - ORT_RETURN_IF(!eval_sess_ || eval_model_path_.empty(), + ORT_RETURN_IF(!eval_sess_ || !eval_model_path_.has_value(), "Eval model was not provided. Cannot export a model for inferencing."); ONNX_NAMESPACE::ModelProto eval_model; - ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_), eval_model)); + ORT_THROW_IF_ERROR(Model::Load(ToPathString(eval_model_path_.value()), eval_model)); // Clone the eval mode into an inference onnxruntime::Model. std::shared_ptr inference_model; diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index 9013ab22c124f..adb633343263e 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/session/inference_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { @@ -73,12 +75,12 @@ struct Module { public: // Initialize a module from an ORT inference session with loaded // training ONNX model and load parameters - Module(const std::string& train_model_path_or_bytes, + // The model and checkpoint state can be provided as a file path or a byte array + Module(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, - const std::optional& eval_model_path_or_bytes = std::nullopt, gsl::span op_domains = gsl::span()); // Return the trainable/nontrainable parameters @@ -159,7 +161,7 @@ struct Module { CheckpointState* state_; // Non owning pointer to the state. bool accumulate_gradient_ = false; - std::string eval_model_path_; + std::optional eval_model_path_; size_t train_user_input_count_{0U}; size_t eval_user_input_count_{0U}; }; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index b84009e7f3591..6693bba348648 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -13,6 +13,8 @@ #include "orttraining/training_api/ort_training_apis.h" #include "orttraining/training_api/training_session.h" +using namespace onnxruntime::training::api; + namespace { std::vector> CreateProviders( @@ -26,44 +28,85 @@ std::vector> CreateProviders( return execution_providers; } +static OrtStatus* CreateSessionAndLoadModel(_In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, + _Inout_ OrtCheckpointState* checkpoint_state, + const ModelIdentifiers& model_identifiers, + std::unique_ptr& train_sess) { + auto chkpt_state = reinterpret_cast(checkpoint_state); + + using ProvidersType = std::vector>; + train_sess = std::make_unique(env->GetEnvironment(), + options == nullptr ? onnxruntime::SessionOptions() : options->value, + options == nullptr + ? ProvidersType() + : CreateProviders(options->provider_factories), + chkpt_state, + model_identifiers, + options == nullptr + ? gsl::span() + : options->custom_op_domains_); + + return nullptr; +} + } // namespace ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, - _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_ OrtTrainingSession** out) { + _In_ const ORTCHAR_T* optimizer_model_path, _Outptr_result_maybenull_ OrtTrainingSession** out) { API_IMPL_BEGIN if (options != nullptr && options->value.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigUseEnvAllocators, "0") == "1") { return OrtApis::CreateStatus(ORT_NOT_IMPLEMENTED, "Use Env Allocators is not supported for on device training."); } std::unique_ptr train_sess; - auto chkpt_state = reinterpret_cast(checkpoint_state); OrtStatus* status = nullptr; *out = nullptr; - ORT_TRY { - using ProvidersType = std::vector>; - train_sess = std::make_unique( - env->GetEnvironment(), - options == nullptr ? onnxruntime::SessionOptions() : options->value, - options == nullptr ? ProvidersType() : CreateProviders(options->provider_factories), - chkpt_state, - onnxruntime::training::api::ModelIdentifiers( - onnxruntime::ToUTF8String(train_model_path), - eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) - : std::nullopt, - optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) - : std::nullopt), - options == nullptr ? gsl::span() : options->custom_op_domains_); - - *out = reinterpret_cast(train_sess.release()); - } - ORT_CATCH(const std::exception& e) { - ORT_HANDLE_EXCEPTION([&]() { - status = OrtApis::CreateStatus(ORT_FAIL, e.what()); - }); - } + ORT_ENFORCE(train_model_path != nullptr, + "Train model path is required to create TrainingSession, it cannot be empty."); + + auto model_identifiers = onnxruntime::training::api::ModelIdentifiers( + onnxruntime::ToUTF8String(train_model_path), + eval_model_path ? std::optional(onnxruntime::ToUTF8String(eval_model_path)) + : std::nullopt, + optimizer_model_path ? std::optional(onnxruntime::ToUTF8String(optimizer_model_path)) + : std::nullopt); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); + + return status; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out) { + API_IMPL_BEGIN + std::unique_ptr train_sess; + OrtStatus* status = nullptr; + *out = nullptr; + ORT_ENFORCE(train_model_data != nullptr && train_data_length != 0, + "Training Session Creation failed. Train model data cannot be NULL."); + + auto model_identifiers = ModelIdentifiers(gsl::make_span(reinterpret_cast(train_model_data), + train_data_length), + eval_data_length == 0 || eval_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(eval_model_data), + eval_data_length), + optim_data_length == 0 || optim_model_data == nullptr + ? gsl::span() + : gsl::make_span(reinterpret_cast(optim_model_data), + optim_data_length)); + + ORT_API_RETURN_IF_ERROR(CreateSessionAndLoadModel(env, options, checkpoint_state, model_identifiers, train_sess)); + *out = reinterpret_cast(train_sess.release()); return status; API_IMPL_END } @@ -523,6 +566,7 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::LoadCheckpoint, &OrtTrainingApis::SaveCheckpoint, &OrtTrainingApis::CreateTrainingSession, + &OrtTrainingApis::CreateTrainingSessionFromBuffer, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputCount, &OrtTrainingApis::TrainingSessionGetEvalModelOutputCount, &OrtTrainingApis::TrainingSessionGetTrainingModelOutputName, diff --git a/orttraining/orttraining/training_api/optimizer.cc b/orttraining/orttraining/training_api/optimizer.cc index a6b82f1d50fc0..7f583ce8f6e76 100644 --- a/orttraining/orttraining/training_api/optimizer.cc +++ b/orttraining/orttraining/training_api/optimizer.cc @@ -61,19 +61,10 @@ Status GraphInputsAreExpected(gsl::span actual_graph_inputs, } // namespace std::unique_ptr OptimizerAlorithmFactory::CreateInstance( - const std::string& optim_path, int32_t& group_count) { + std::shared_ptr model, int32_t& group_count) { std::map, int32_t> opt_type_to_freq_map; #if !defined(ORT_MINIMAL_BUILD) - if (const auto optim_path_str = ToPathString(optim_path); - fbs::utils::IsOrtFormatModel(optim_path_str)) { - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from an ort format model. - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; - } else { - std::shared_ptr model; - ORT_ENFORCE(Model::Load(optim_path_str, model, nullptr, - logging::LoggingManager::DefaultLogger()) - .IsOK()); + if (model != nullptr) { Graph& graph = model->MainGraph(); for (auto& node : graph.Nodes()) { if (node.Domain() == kMSDomain && (node.OpType() == "AdamWOptimizer" || node.OpType() == "SGDOptimizerV2")) { @@ -85,33 +76,71 @@ std::unique_ptr OptimizerAlorithmFactory::CreateInstance opt_type_to_freq_map[domain_type_pair] += 1; } } - } + } else { #else - // TODO (baijumeswani): Figure out the best way to extract the optimizer type - // from the model (either onnx model or ort format model) or from the checkpoint. - // For now, assume that the optimizer type is AdamWOptimizer in a minimal build. - ORT_UNUSED_PARAMETER(optim_path); - - opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; + ORT_UNUSED_PARAMETER(model); +#endif + // TODO(baijumeswani): Figure out the best way to extract the optimizer type + // from the model (either onnx model or ort format model) or from the checkpoint. + // For now, assume that the optimizer type is AdamWOptimizer when using ort format models. + opt_type_to_freq_map[std::make_pair(kMSDomain, "AdamWOptimizer")] = 1; +#if !defined(ORT_MINIMAL_BUILD) + } #endif ORT_ENFORCE(opt_type_to_freq_map.size() == 1U, "Only support one type of optimizer algorithm, but got: " + std::to_string(opt_type_to_freq_map.size())); auto opt_it = opt_type_to_freq_map.begin(); + auto& op_type = opt_it->first.second; group_count = opt_it->second; - auto& domain = opt_it->first.first; - auto& type = opt_it->first.second; + ORT_ENFORCE(group_count == 1, "Group count can only be 1, but got: " + std::to_string(group_count)); // TODO: to support multiple groups, need to create a mapping between each group to its parameter list. - if (domain == kMSDomain && type == "AdamWOptimizer") { + if (op_type == "AdamWOptimizer") { return std::make_unique(); - } else if (domain == kMSDomain && type == "SGDOptimizerV2") { + } else if (op_type == "SGDOptimizerV2") { return std::make_unique(); } else { ORT_NOT_IMPLEMENTED("Not implemented for optimizer algo: " + opt_it->first.second); } } +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const PathString& optim_path, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModel(optim_path)) { + ORT_ENFORCE(Model::Load(optim_path, model, nullptr, + logging::LoggingManager::DefaultLogger()) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_path); +#endif + return CreateInstance(model, group_count); +} + +std::unique_ptr OptimizerAlorithmFactory::CreateInstance( + const uint8_t* optim_model_data, size_t optim_model_data_len, int32_t& group_count) { + std::shared_ptr model = nullptr; +#if !defined(ORT_MINIMAL_BUILD) + if (!fbs::utils::IsOrtFormatModelBytes(optim_model_data, static_cast(optim_model_data_len))) { + ONNX_NAMESPACE::ModelProto model_proto; + ORT_ENFORCE(model_proto.ParseFromArray(optim_model_data, static_cast(optim_model_data_len)) == true, + "Failed to load model because protobuf parsing failed."); + + ORT_ENFORCE(Model::Load(std::move(model_proto), model, nullptr, + logging::LoggingManager::DefaultLogger(), ModelOptions(true, true)) + .IsOK()); + } +#else + ORT_UNUSED_PARAMETER(optim_model_data); + ORT_UNUSED_PARAMETER(optim_model_data_len); +#endif + + return CreateInstance(model, group_count); +} + Status Optimizer::GenerateMomentumNamedStates(OptimizerCheckpointState& optimizer_checkpoint_states) { auto group_optimizer_state_it = optimizer_checkpoint_states.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -200,14 +229,14 @@ Status Optimizer::ConstructInputs() { return Status::OK(); } // namespace api -Optimizer::Optimizer(const std::string& optim_path_or_bytes, +Optimizer::Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, const std::vector>& providers, gsl::span op_domains) : optim_sess_(std::make_unique(session_options, env)), state_(state) { - Initialize(optim_path_or_bytes, providers, op_domains); + Initialize(model_identifiers, providers, op_domains); ORT_ENFORCE(state != nullptr, "Checkpoint state cannot be null."); auto g_it = state_->optimizer_checkpoint_state.group_named_optimizer_states.find(GROUP_ZERO_NAME); @@ -223,7 +252,7 @@ Optimizer::Optimizer(const std::string& optim_path_or_bytes, } } -void Optimizer::Initialize(const std::string& optim_path_or_bytes, +void Optimizer::Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, [[maybe_unused]] gsl::span op_domains) { #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) @@ -236,7 +265,22 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, ORT_THROW_IF_ERROR(optim_sess_->RegisterExecutionProvider(execution_provider)); } - ORT_THROW_IF_ERROR(optim_sess_->Load(optim_path_or_bytes)); + ORT_ENFORCE(model_identifiers.IsOptimizerModelAvailable(), "Optimizer model is not available."); + + if (std::holds_alternative>(model_identifiers.optim_model)) { + auto optimizer_model = std::get>(model_identifiers.optim_model); + // The above call to IsOptimizerModelAvailable() ensures that optimizer_model is not nullopt + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.value())); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(ToWideString(optimizer_model.value()), group_count_); + } else { + auto optimizer_model = std::get>(model_identifiers.optim_model); + ORT_THROW_IF_ERROR(optim_sess_->Load(optimizer_model.data(), + static_cast(optimizer_model.size()))); + optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optimizer_model.data(), + optimizer_model.size(), + group_count_); + } + ORT_THROW_IF_ERROR(optim_sess_->Initialize()); // Make sure that the checkpoint state can copy tensors @@ -244,10 +288,6 @@ void Optimizer::Initialize(const std::string& optim_path_or_bytes, utils::GetGraphInputOutputNames(optim_sess_, input_names_, output_names_); - optimizer_algo_ptr_ = OptimizerAlorithmFactory::CreateInstance(optim_path_or_bytes, group_count_); - ORT_ENFORCE(group_count_ == 1, "Group count can only be 1, but got: " + std::to_string(group_count_)); - ORT_ENFORCE(optimizer_algo_ptr_, "optimizer_algo_ptr_ should not be nullptr."); - InlinedVector all_input_names; all_input_names.reserve(CommonOptimizerInputs.size() + optimizer_algo_ptr_->optimizer_states_inputs.size()); all_input_names.insert(all_input_names.end(), CommonOptimizerInputs.begin(), diff --git a/orttraining/orttraining/training_api/optimizer.h b/orttraining/orttraining/training_api/optimizer.h index 36ce3297fe3c4..d9bc4870bb7ed 100644 --- a/orttraining/orttraining/training_api/optimizer.h +++ b/orttraining/orttraining/training_api/optimizer.h @@ -64,8 +64,11 @@ struct SGDOptimizerV2Algorithm : public OptimizerAlgorithmBase { }; struct OptimizerAlorithmFactory { - static std::unique_ptr CreateInstance(const std::string& optim_path_or_bytes, + static std::unique_ptr CreateInstance(const PathString& optim_path, int32_t& group_count); + static std::unique_ptr CreateInstance(const uint8_t* optim_model_data, + size_t optim_model_data_len, int32_t& group_count); + static std::unique_ptr CreateInstance(std::shared_ptr model, int32_t& group_count); }; struct CheckpointState; @@ -96,7 +99,7 @@ struct Optimizer { // Initialize an optimizer module from an ORT inference session with loaded // training ONNX model For each parameter, initialize the OptimizerState based // on the graph input's ValueInfoProto if the parameter doesn't have it already. - Optimizer(const std::string& optim_path_or_bytes, + Optimizer(const ModelIdentifiers& model_identifiers, CheckpointState* state, const onnxruntime::SessionOptions& session_options, const Environment& env, @@ -121,7 +124,7 @@ struct Optimizer { } private: - void Initialize(const std::string& optim_path_or_bytes, + void Initialize(const ModelIdentifiers& model_identifiers, const std::vector>& providers, gsl::span op_domains); diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index 2b383f3b9782a..c87108957c975 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -8,7 +8,14 @@ ORT_API(const OrtTrainingApi*, GetTrainingApi, uint32_t version); ORT_API_STATUS_IMPL(CreateTrainingSession, _In_ const OrtEnv* env, _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, _In_ const ORTCHAR_T* train_model_path, _In_ const ORTCHAR_T* eval_model_path, _In_ const ORTCHAR_T* optimizer_model_path, - _Outptr_ OrtTrainingSession** out); + _Outptr_result_maybenull_ OrtTrainingSession** out); + +ORT_API_STATUS_IMPL(CreateTrainingSessionFromBuffer, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* options, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const void* train_model_data, size_t train_data_length, + _In_ const void* eval_model_data, size_t eval_data_length, + _In_ const void* optim_model_data, size_t optim_data_length, + _Outptr_result_maybenull_ OrtTrainingSession** out); ORT_API_STATUS_IMPL(TrainingSessionGetTrainingModelOutputCount, _In_ const OrtTrainingSession* sess, _Out_ size_t* out); diff --git a/orttraining/orttraining/training_api/training_session.cc b/orttraining/orttraining/training_api/training_session.cc index 6915193a8ff7c..45f0f0ddcf7f4 100644 --- a/orttraining/orttraining/training_api/training_session.cc +++ b/orttraining/orttraining/training_api/training_session.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "orttraining/training_api/training_session.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime::training::api { @@ -12,13 +13,12 @@ TrainingSession::TrainingSession(const Environment& session_env, const ModelIdentifiers& model_identifiers, gsl::span custom_op_domains) : state_{state}, - module_{std::make_unique(model_identifiers.train_model, state_, - session_options, session_env, providers, - model_identifiers.eval_model, custom_op_domains)}, - optimizer_{model_identifiers.optim_model.has_value() + module_{std::make_unique(model_identifiers, state_, + session_options, session_env, providers, custom_op_domains)}, + optimizer_{model_identifiers.IsOptimizerModelAvailable() ? std::make_unique( - model_identifiers.optim_model.value(), state_, - session_options, session_env, providers, custom_op_domains) + model_identifiers, state_, + session_options, session_env, providers) : std::unique_ptr()} {} Status TrainingSession::RegisterScheduler( diff --git a/orttraining/orttraining/training_api/training_session.h b/orttraining/orttraining/training_api/training_session.h index 1a16acd5115f0..13b0ae79093de 100644 --- a/orttraining/orttraining/training_api/training_session.h +++ b/orttraining/orttraining/training_api/training_session.h @@ -3,25 +3,17 @@ #pragma once #include "core/common/common.h" -#include "module.h" -#include "optimizer.h" -#include "lr_scheduler.h" -#include "checkpoint.h" +#include "orttraining/training_api/module.h" +#include "orttraining/training_api/optimizer.h" +#include "orttraining/training_api/lr_scheduler.h" +#include "orttraining/training_api/checkpoint.h" +#include "orttraining/training_api/utils.h" namespace onnxruntime { namespace training { namespace api { using namespace common; -struct ModelIdentifiers { - const std::string train_model; - const std::optional eval_model, optim_model; - ModelIdentifiers(const std::string& train_model_uri, - const std::optional& eval_model_uri, - const std::optional& optim_model_uri) - : train_model(train_model_uri), eval_model(eval_model_uri), optim_model(optim_model_uri) {} -}; - // Wrapper on top of module and optimizer classes and is the only class exposed via capis class TrainingSession { public: diff --git a/orttraining/orttraining/training_api/utils.h b/orttraining/orttraining/training_api/utils.h index e856554c971ec..f16f0f947fbd5 100644 --- a/orttraining/orttraining/training_api/utils.h +++ b/orttraining/orttraining/training_api/utils.h @@ -10,6 +10,40 @@ namespace onnxruntime { namespace training { namespace api { + +struct ModelIdentifiers { + // ModelIdentifiers struct enables an easy way to store and identify the models used for training, evaluation + // and model updates(optimizer model). + // The model can be specified by a path to the model file or by a span of bytes containing the model data. + // Training model is required, evaluation and optimizer models are optional. + std::variant> train_model; + std::variant, gsl::span> eval_model; + std::variant, gsl::span> optim_model; + + ModelIdentifiers(std::variant> training_model, + std::variant, gsl::span> evaluation_model, + std::variant, gsl::span> optimzer_model) + : train_model(training_model), eval_model(evaluation_model), optim_model(optimzer_model) {} + + bool IsModelAvailable(const std::variant, gsl::span>& model) const { + if ((std::holds_alternative>(model) && + std::get>(model).has_value()) || + (std::holds_alternative>(model) && + std::get>(model).size() > 0)) { + return true; + } + return false; + } + + bool IsEvalModelAvailable() const { + return IsModelAvailable(eval_model); + } + + bool IsOptimizerModelAvailable() const { + return IsModelAvailable(optim_model); + } +}; + namespace utils { // Get names of graph inputs and outputs diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 6aac9ad7ecbc3..8ec884382c916 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -85,6 +85,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvGrad); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ConvTransposeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ConvTransposeGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GatherGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DropoutGrad); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BitmaskDropoutGrad); @@ -346,6 +349,9 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc index f69da000be01c..f6c58445c0a5d 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc @@ -3,13 +3,6 @@ #include "orttraining/training_ops/cuda/nn/conv_grad.h" -#include "core/providers/common.h" -#include "core/providers/cuda/shared_inc/fpgeneric.h" -#include "core/platform/ort_mutex.h" - -// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad is referenced on PyTorch's implementation -// from aten/src/ATen/native/cudnn/Conv_v7.cpp. - namespace onnxruntime { namespace cuda { @@ -22,229 +15,6 @@ REGISTER_GRADIENT_KERNEL_TYPED(float) REGISTER_GRADIENT_KERNEL_TYPED(double) REGISTER_GRADIENT_KERNEL_TYPED(MLFloat16) -using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; -using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; -using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; -using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, - args.x_tensor, algo, workspace_size); -} - -cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { - return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, - args.w_desc, algo, workspace_size); -} - -template -size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { - // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. - size_t free, total; - CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); - // Assuming 10% of fragmentation. - free = static_cast(static_cast(free) * 0.9); - size_t max_workspace_size = 0; - for (int i = 0; i < n_algo; i++) { - cudnnStatus_t status; - size_t workspace_size; - status = GetWorkspaceSize(args, algo[i], &workspace_size); - if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || - workspace_size > free) - continue; - max_workspace_size = workspace_size; - } - - return max_workspace_size; -} - -template -std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { - std::vector result; - result.reserve(n_algo); - for (int i = 0; i < n_algo; i++) { - T_Perf perf = perf_results[i]; - if (perf.status == CUDNN_STATUS_SUCCESS) { - result.emplace_back(perf); - } - } - ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); - // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups - // when cuDNN version < 7.5. Need to add handling for such special case. - return result; -} - -struct ConvParamsHash { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - size_t operator()(const ConvParams& conv_params) const { - auto ptr = reinterpret_cast(&conv_params); - uint32_t value = 0x811C9DC5; - for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { - value ^= ptr[i]; - value *= 0x01000193; - } - return static_cast(value); - } -}; - -struct ConvParamsEqual { - // ConvParams must be a POD because we read out its memory constant as char* when hashing. - static_assert(std::is_pod::value, "ConvParams is not POD"); - bool operator()(const ConvParams& a, const ConvParams& b) const { - auto ptr1 = reinterpret_cast(&a); - auto ptr2 = reinterpret_cast(&b); - return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; - } -}; - -template -struct AlgoPerfCache { - mutable OrtMutex mutex; - std::unordered_map map; - - bool Find(const ConvParams& params, T_Perf* result) { - std::lock_guard guard(mutex); - auto it = map.find(params); - if (it == map.end()) { - return false; - } - *result = it->second; - return true; - } - - void Insert(const ConvParams& params, const T_Perf& algo_perf) { - std::lock_guard guard(mutex); - map[params] = algo_perf; - } -}; - -// TODO: Currently we use global AlgoPerfCache for ConvGrad only. Conv's perf cache is till per node. -// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. -AlgoPerfCache bwd_data_algos; -AlgoPerfCache bwd_filter_algos; - -template -struct AlgoSearch {}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_data_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdDataAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, - CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); - int perf_count; - std::unique_ptr candidates = std::make_unique(num_algos); - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, - args.conv_desc, args.x_tensor, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( - args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, - args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template <> -struct AlgoSearch { - static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; - static AlgoPerfCache& Cache() { return bwd_filter_algos; } - static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, - std::vector& perf_results) { - static const T_BwdFilterAlgo algos[] = { - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, - }; - - // NOTE: - 1 because ALGO_WINOGRAD is not implemented. - static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); - std::unique_ptr candidates = std::make_unique(num_algos); - int perf_count; - if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { - CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, - args.conv_desc, args.w_desc, num_algos, - &perf_count, candidates.get())); - } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { - size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) - : AlgoSearchWorkspaceSize; - // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. - // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. - IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); - CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( - args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, - args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); - } else { - ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); - } - perf_results = GetValidAlgorithms(candidates.get(), perf_count); - return Status::OK(); - } -}; - -template -class AlgoIterator { - public: - AlgoIterator(const ConvArgs& args) : args_(args) {} - - static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { - perf_results.resize(1); - perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; - if (args.params.data_type == CUDNN_DATA_HALF) { - perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; - } else { - perf_results[0].mathType = CUDNN_DEFAULT_MATH; - } - CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); - return Status::OK(); - } - - Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, std::function f) { - auto& cache = AlgoSearch::Cache(); - - if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { - return Status::OK(); - } - - std::vector perf_results; - ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault - ? OnlyDefaultAlgorithm(args_, perf_results) - : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); - for (auto& algo_perf : perf_results) { - if (f(algo_perf) == Status::OK()) { - cache.Insert(args_.params, algo_perf); - return Status::OK(); - } - } - ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); - return Status::OK(); - } - - private: - const ConvArgs& args_; -}; - template Status ConvGrad::PrepareArgs(const Tensor& x, const Tensor& dY, const Tensor& w, Tensor* dB, Tensor* dX, Tensor* dW, cudnnHandle_t cudnn_handle) const { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h index 5d0c123fd9358..9bbcd5b30d168 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_grad.h @@ -3,47 +3,11 @@ #pragma once -#include "core/providers/cuda/cudnn_common.h" -#include "core/providers/cpu/nn/conv_attributes.h" -#include "core/providers/cuda/nn/conv.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" namespace onnxruntime { namespace cuda { -// cuDNN only takes 4D or 5D x tensor. -static constexpr int MAX_DIM = 3; - -struct ConvParams { - int8_t device_id; - cudnnDataType_t data_type; - int input_size[2 + MAX_DIM]; - uint8_t input_dim; - int weight_size[2 + MAX_DIM]; - int padding[MAX_DIM * 2]; - int stride[MAX_DIM]; - int dilation[MAX_DIM]; - int64_t groups; - int algo_mode; -}; - -struct ConvArgs { - // Update needed if x or w's dims changed. - TensorShapeVector last_x_dims; - TensorShapeVector last_w_dims; - - cudnnHandle_t handle; - ConvParams params; - CudnnTensor x_tensor, y_tensor, b_tensor; - CudnnFilterDescriptor w_desc; - CudnnConvolutionDescriptor conv_desc; - const void* x_data; - const void* w_data; - const void* dy_data; - void* dx_data; - void* dw_data; - void* db_data; -}; - template class ConvGrad final : public CudaKernel { public: diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc new file mode 100644 index 0000000000000..5dc16c68f6210 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -0,0 +1,275 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +#include "core/platform/ort_mutex.h" +#include "core/providers/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime::cuda { + +namespace { + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdDataAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardDataWorkspaceSize(args.handle, args.w_desc, args.y_tensor, args.conv_desc, + args.x_tensor, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_BwdFilterAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionBackwardFilterWorkspaceSize(args.handle, args.x_tensor, args.y_tensor, args.conv_desc, + args.w_desc, algo, workspace_size); +} + +cudnnStatus_t GetWorkspaceSize(const ConvArgs& args, T_FwdAlgo algo, size_t* workspace_size) { + return cudnnGetConvolutionForwardWorkspaceSize(args.handle, args.x_tensor, args.w_desc, args.conv_desc, + args.y_tensor, algo, workspace_size); +} + +template +size_t GetMaxWorkspaceSize(const ConvArgs& args, const T_Algo* algo, int n_algo) { + // Calling cudaMemGetInfo is not ideal, but our cuda allocator doesn't have a way to get this info. + size_t free, total; + CUDA_CALL_THROW(cudaMemGetInfo(&free, &total)); + // Assuming 10% of fragmentation. + free = static_cast(static_cast(free) * 0.9); + size_t max_workspace_size = 0; + for (int i = 0; i < n_algo; i++) { + cudnnStatus_t status; + size_t workspace_size; + status = GetWorkspaceSize(args, algo[i], &workspace_size); + if (CUDNN_STATUS_SUCCESS != status || workspace_size == 0 || workspace_size < max_workspace_size || + workspace_size > free) + continue; + max_workspace_size = workspace_size; + } + + return max_workspace_size; +} + +template +std::vector GetValidAlgorithms(const T_Perf* perf_results, int n_algo) { + std::vector result; + result.reserve(n_algo); + for (int i = 0; i < n_algo; i++) { + T_Perf perf = perf_results[i]; + if (perf.status == CUDNN_STATUS_SUCCESS) { + result.emplace_back(perf); + } + } + ORT_ENFORCE(result.size() > 0, "No valid convolution algorithms available in CuDNN"); + // TODO: This is a cuDNN bug that gave wrong results in certain strided convolution gradient setups + // when cuDNN version < 7.5. Need to add handling for such special case. + return result; +} + +template +struct AlgoPerfCache { + mutable OrtMutex mutex; + std::unordered_map map; + + bool Find(const ConvParams& params, T_Perf* result) { + std::lock_guard guard(mutex); + auto it = map.find(params); + if (it == map.end()) { + return false; + } + *result = it->second; + return true; + } + + void Insert(const ConvParams& params, const T_Perf& algo_perf) { + std::lock_guard guard(mutex); + map[params] = algo_perf; + } +}; + +// TODO: Currently we use global AlgoPerfCache for ConvGrad and ConvTransposeGrad only. +// Conv's perf cache is still per node. +// Need to apply such global cache for Conv, and move some shared code from here to conv.h/cc. +AlgoPerfCache bwd_data_algos; +AlgoPerfCache bwd_filter_algos; +AlgoPerfCache fwd_algos; + +template +struct AlgoSearch {}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_data_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdDataAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_DATA_ALGO_0, CUDNN_CONVOLUTION_BWD_DATA_ALGO_1, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + int perf_count; + std::unique_ptr candidates = std::make_unique(num_algos); + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardDataAlgorithm_v7(args.handle, args.w_desc, args.y_tensor, + args.conv_desc, args.x_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx( + args.handle, args.w_desc, args.w_data, args.y_tensor, args.dy_data, args.conv_desc, args.x_tensor, + args.dx_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + static AlgoPerfCache& Cache() { return bwd_filter_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_BwdFilterAlgo algos[] = { + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, + }; + + // NOTE: - 1 because ALGO_WINOGRAD is not implemented. + static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionBackwardFilterAlgorithm_v7(args.handle, args.x_tensor, args.y_tensor, + args.conv_desc, args.w_desc, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 ? nullptr : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardFilterAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, args.w_desc, + args.dw_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +template <> +struct AlgoSearch { + static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static AlgoPerfCache& Cache() { return fwd_algos; } + static Status FindAlgorithms(const ConvArgs& args, const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::vector& perf_results) { + static const T_FwdAlgo algos[] = { + CUDNN_CONVOLUTION_FWD_ALGO_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_FFT, + CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM, + CUDNN_CONVOLUTION_FWD_ALGO_DIRECT, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD, + CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED, + }; + + static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; + ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); + int perf_count; + if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { + CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(args.handle, args.x_tensor, args.w_desc, + args.conv_desc, args.y_tensor, num_algos, + &perf_count, candidates.get())); + } else if (args.params.algo_mode == OrtCudnnConvAlgoSearchExhaustive) { + size_t max_workspace_size = provider->GetCudnnConvUseMaxWorkspace() ? GetMaxWorkspaceSize(args, algos, num_algos) + : AlgoSearchWorkspaceSize; + // Use GetTransientScratchBuffer() so the workspace can be freed instead of cached. + // Because the benchmarking uses a huge amount of memory, e.g. a few GBs. + IAllocatorUniquePtr workspace = max_workspace_size == 0 + ? nullptr + : IAllocator::MakeUniquePtr(allocator, max_workspace_size, true); + CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx( + args.handle, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, args.y_tensor, + args.y_data, num_algos, &perf_count, candidates.get(), workspace.get(), max_workspace_size)); + } else { + ORT_ENFORCE(false, "Algo mode should be EXHAUSTIVE (0) or HEURISTIC (1), but got ", args.params.algo_mode); + } + perf_results = GetValidAlgorithms(candidates.get(), perf_count); + return Status::OK(); + } +}; + +} // namespace + +size_t ConvParamsHash::operator()(const ConvParams& conv_params) const { + auto ptr = reinterpret_cast(&conv_params); + uint32_t value = 0x811C9DC5; + for (int i = 0; i < static_cast(sizeof(ConvParams)); ++i) { + value ^= ptr[i]; + value *= 0x01000193; + } + return static_cast(value); +} + +bool ConvParamsEqual::operator()(const ConvParams& a, const ConvParams& b) const { + auto ptr1 = reinterpret_cast(&a); + auto ptr2 = reinterpret_cast(&b); + return memcmp(ptr1, ptr2, sizeof(ConvParams)) == 0; +} + +template +Status AlgoIterator::OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results) { + perf_results.resize(1); + perf_results[0].algo = AlgoSearch::DEFAULT_ALGO; + if (args.params.data_type == CUDNN_DATA_HALF) { + perf_results[0].mathType = CUDNN_TENSOR_OP_MATH; + } else { + perf_results[0].mathType = CUDNN_DEFAULT_MATH; + } + CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(args, perf_results[0].algo, &(perf_results[0].memory))); + return Status::OK(); +} + +template +Status AlgoIterator::TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f) { + auto& cache = AlgoSearch::Cache(); + + if (T_Perf algo_perf; cache.Find(args_.params, &algo_perf) && f(algo_perf) == Status::OK()) { + return Status::OK(); + } + + std::vector perf_results; + ORT_RETURN_IF_ERROR(args_.params.algo_mode == OrtCudnnConvAlgoSearchDefault + ? OnlyDefaultAlgorithm(args_, perf_results) + : AlgoSearch::FindAlgorithms(args_, provider, allocator, perf_results)); + for (auto& algo_perf : perf_results) { + if (f(algo_perf) == Status::OK()) { + cache.Insert(args_.params, algo_perf); + return Status::OK(); + } + } + ORT_ENFORCE(false, "Unable to find a valid cuDNN algorithm to run convolution."); + return Status::OK(); +} + +template class AlgoIterator; +template class AlgoIterator; +template class AlgoIterator; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h new file mode 100644 index 0000000000000..a2d4bf3bdc006 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.h @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cudnn_common.h" +#include "core/providers/cuda/nn/conv.h" + +// The AlgoPerfCache and AlgoSearch here for Conv/ConvGrad/ConvTransposeGrad is adapted from PyTorch's implementation +// in aten/src/ATen/native/cudnn/Conv_v7.cpp. + +namespace onnxruntime::cuda { + +using T_BwdDataPerf = cudnnConvolutionBwdDataAlgoPerf_t; +using T_BwdDataAlgo = cudnnConvolutionBwdDataAlgo_t; +using T_BwdFilterPerf = cudnnConvolutionBwdFilterAlgoPerf_t; +using T_BwdFilterAlgo = cudnnConvolutionBwdFilterAlgo_t; +using T_FwdAlgo = cudnnConvolutionFwdAlgo_t; +using T_FwdPerf = cudnnConvolutionFwdAlgoPerf_t; + +// cuDNN only takes 4D or 5D x tensor. +static constexpr int MAX_DIM = 3; + +struct ConvParams { + int8_t device_id; + cudnnDataType_t data_type; + int input_size[2 + MAX_DIM]; + uint8_t input_dim; + int weight_size[2 + MAX_DIM]; + int padding[MAX_DIM * 2]; + int stride[MAX_DIM]; + int dilation[MAX_DIM]; + int64_t groups; + int algo_mode; +}; + +struct ConvArgs { + // Update needed if x or w's dims changed. + TensorShapeVector last_x_dims; // Input to the convolution + TensorShapeVector last_w_dims; // Weights of the convolution + + cudnnHandle_t handle; + ConvParams params; + CudnnTensor x_tensor, y_tensor, b_tensor; + CudnnFilterDescriptor w_desc; + CudnnConvolutionDescriptor conv_desc; + const void* x_data; + const void* w_data; + const void* dy_data; + void* y_data; + void* dx_data; + void* dw_data; + void* db_data; +}; + +struct ConvParamsHash { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + size_t operator()(const ConvParams& conv_params) const; +}; + +struct ConvParamsEqual { + // ConvParams must be a POD because we read out its memory constant as char* when hashing. + static_assert(std::is_pod::value, "ConvParams is not POD"); + + bool operator()(const ConvParams& a, const ConvParams& b) const; +}; + +template +class AlgoIterator { + public: + AlgoIterator(const ConvArgs& args) : args_(args) {} + + Status TryAll(const CUDAExecutionProvider* provider, const AllocatorPtr& allocator, + std::function f); + + static Status OnlyDefaultAlgorithm(const ConvArgs& args, std::vector& perf_results); + + private: + const ConvArgs& args_; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc new file mode 100644 index 0000000000000..5f7206fc121ec --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -0,0 +1,308 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "orttraining/training_ops/cuda/nn/conv_transpose_grad.h" + +namespace onnxruntime::cuda { + +#define REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX(ConvTransposeGrad, kMSDomain, 1, T, kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ConvTransposeGrad); + +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(float) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(double) +REGISTER_CONVTRANSPOSE_GRADIENT_KERNEL_TYPED(MLFloat16) + +template +Status ConvTransposeGrad::ComputeInternal(OpKernelContext* context) const { + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* W = context->Input(2); + Tensor* dX = context->Output(0, X->Shape()); + Tensor* dW = context->Output(1, W->Shape()); + Tensor* dB = context->Output(2, {W->Shape()[1] * conv_attrs_.group}); + + if (dX) { + ORT_RETURN_IF_ERROR(PrepareConvForwardArgs(*dY, *W, *dX, GetCudnnHandle(context), args_dx_)); + ORT_RETURN_IF_ERROR(ComputeInputGradient(context->GetComputeStream(), args_dx_)); + } + + if (dW || dB) { + ORT_RETURN_IF_ERROR(PrepareConvBackwardFilterArgs(*dY, *W, *X, dW, dB, GetCudnnHandle(context), args_dw_)); + if (dW) ORT_RETURN_IF_ERROR(ComputeWeightGradient(context->GetComputeStream(), args_dw_)); + if (dB) ORT_RETURN_IF_ERROR(ComputeBiasGradient(args_dw_)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_FwdPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward( + args.handle, &one, args.x_tensor, args.x_data, args.w_desc, args.w_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const { + return AlgoIterator(args).TryAll( + static_cast(Info().GetExecutionProvider()), + Info().GetAllocator(OrtMemType::OrtMemTypeDefault), + [&](const T_BwdFilterPerf& algo_perf) -> Status { + const auto one = Consts::One; + const auto zero = Consts::Zero; + IAllocatorUniquePtr workspace = GetScratchBuffer(algo_perf.memory, stream); + CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(args.conv_desc, algo_perf.mathType)); + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardFilter( + args.handle, &one, args.x_tensor, args.x_data, args.y_tensor, args.dy_data, args.conv_desc, + algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); + return Status::OK(); + }); + return Status::OK(); +} + +template +Status ConvTransposeGrad::ComputeBiasGradient(const ConvArgs& args) const { + const auto one = Consts::One; + const auto zero = Consts::Zero; + CUDNN_RETURN_IF_ERROR(cudnnConvolutionBackwardBias(args.handle, &one, args.x_tensor, args.x_data, &zero, + args.b_tensor, args.db_data)); + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + args.w_data = reinterpret_cast(W.template Data()); + + const TensorShape& y_shape = Y.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.y_data = reinterpret_cast(Y.template MutableData()); + + args.dy_data = nullptr; + args.db_data = nullptr; + args.dx_data = nullptr; + args.dw_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + } + + return Status::OK(); +} + +template +Status ConvTransposeGrad::PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const { + const TensorShape& x_shape = X.Shape(); + auto x_dims = x_shape.AsShapeVector(); + args.x_data = reinterpret_cast(X.template Data()); + + const TensorShape& y_shape = dY.Shape(); + auto y_dims = y_shape.AsShapeVector(); + args.dy_data = reinterpret_cast(dY.template Data()); + + const TensorShape& w_shape = W.Shape(); + auto w_dims = w_shape.AsShapeVector(); + + args.y_data = nullptr; + args.dw_data = dW ? reinterpret_cast(dW->template MutableData()) : nullptr; + args.db_data = dB ? reinterpret_cast(dB->template MutableData()) : nullptr; + args.dx_data = nullptr; + args.w_data = nullptr; + + bool x_dims_changed = (args.last_x_dims != x_dims); + bool w_dims_changed = (args.last_w_dims != w_dims); + if (x_dims_changed || w_dims_changed) { + if (x_dims_changed) args.last_x_dims = x_dims; + if (w_dims_changed) args.last_w_dims = w_dims; + + ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(&X, &W)); + + TensorShapeVector kernel_shape; + ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(w_shape, kernel_shape)); + auto rank = kernel_shape.size(); + + ConvPadVector pads(conv_attrs_.pads); + if (pads.empty()) { + pads.resize(rank * 2, 0); + } + + TensorShapeVector dilations(conv_attrs_.dilations); + if (dilations.empty()) { + dilations.resize(rank, 1); + } + + TensorShapeVector strides(conv_attrs_.strides); + if (strides.empty()) { + strides.resize(rank, 1); + } + + const CUDAExecutionProvider* cuda_ep = + static_cast(this->Info().GetExecutionProvider()); + + if (rank < 2) { + if (cuda_ep->GetCudnnConv1dPadToNc1d()) { + x_dims.insert(x_dims.begin() + 2, 1); + y_dims.insert(y_dims.begin() + 2, 1); + w_dims.insert(w_dims.begin() + 2, 1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.begin(), 0); + kernel_shape.insert(kernel_shape.begin(), 1); + strides.insert(strides.begin(), 1); + dilations.insert(dilations.begin(), 1); + } else { + x_dims.push_back(1); + y_dims.push_back(1); + w_dims.push_back(1); + pads.insert(pads.begin() + rank, 0); + pads.insert(pads.end(), 0); + kernel_shape.push_back(1); + strides.push_back(1); + dilations.push_back(1); + } + } + + memset(&args.params, 0, sizeof(ConvParams)); + args.params.device_id = static_cast(cuda_ep->GetDeviceId()); + args.params.data_type = CudnnTensor::GetDataType(); + args.params.input_dim = static_cast(x_dims.size()); + for (size_t i = 0; i < x_dims.size(); i++) { + args.params.input_size[i] = static_cast(x_dims[i]); + args.params.weight_size[i] = static_cast(w_dims[i]); + } + for (size_t i = 0; i < rank; i++) { + args.params.padding[i] = static_cast(pads[i]); + args.params.padding[i + rank] = static_cast(pads[i + rank]); + args.params.stride[i] = static_cast(strides[i]); + args.params.dilation[i] = static_cast(dilations[i]); + } + args.params.groups = conv_attrs_.group; + int algo_mode = cuda_ep->GetCudnnConvAlgo(); + ORT_ENFORCE(algo_mode > -1 && algo_mode < 3, + "Algo mode should be EXHAUSTIVE (0), HEURISTIC (1) or DEFAULT (2), but got ", algo_mode); + args.params.algo_mode = algo_mode; + + args.handle = cudnn_handle; + ORT_RETURN_IF_ERROR(args.w_desc.Set(w_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.x_tensor.Set(x_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.y_tensor.Set(y_dims, args.params.data_type)); + ORT_RETURN_IF_ERROR(args.conv_desc.Set(kernel_shape.size(), pads, strides, dilations, + gsl::narrow_cast(conv_attrs_.group), CUDNN_CROSS_CORRELATION, + args.params.data_type)); + + if (dB) { + const auto& b_shape = dB->Shape(); + ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D"); + TensorShapeVector b_dims(2 + kernel_shape.size()); + b_dims[0] = 1; // N + b_dims[1] = b_shape[0]; // C + for (size_t i = 0; i < kernel_shape.size(); i++) + b_dims[2 + i] = 1; + + ORT_RETURN_IF_ERROR(args.b_tensor.Set(b_dims, CudnnTensor::GetDataType())); + } + } + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h new file mode 100644 index 0000000000000..72426323fefab --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/cuda/cuda_kernel.h" + +#include "core/providers/cpu/nn/conv_attributes.h" +#include "orttraining/training_ops/cuda/nn/conv_shared.h" + +namespace onnxruntime::cuda { + +template +class ConvTransposeGrad final : public CudaKernel { + public: + using CudaT = typename ToCudaType::MappedType; + + ConvTransposeGrad(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status ComputeWeightGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeInputGradient(onnxruntime::Stream* stream, const ConvArgs& args) const; + Status ComputeBiasGradient(const ConvArgs& args) const; + + Status PrepareConvForwardArgs(const Tensor& X, const Tensor& W, + Tensor& Y, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + Status PrepareConvBackwardFilterArgs(const Tensor& X, const Tensor& W, const Tensor& dY, + Tensor* dW, Tensor* dB, cudnnHandle_t cudnn_handle, + ConvArgs& args) const; + + ConvAttributes conv_attrs_; + mutable ConvArgs args_dx_; + mutable ConvArgs args_dw_; +}; + +} // namespace onnxruntime::cuda diff --git a/setup.py b/setup.py index c4bbb67947eb9..3a16c38aec916 100644 --- a/setup.py +++ b/setup.py @@ -7,6 +7,7 @@ import datetime import logging import platform +import shlex import subprocess import sys from glob import glob, iglob @@ -183,108 +184,37 @@ def run(self): dest = "onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so" logger.info("copying %s -> %s", source, dest) copyfile(source, dest) - result = subprocess.run( - ["patchelf", "--print-needed", dest], check=True, stdout=subprocess.PIPE, text=True - ) - dependencies = [ - "librccl.so", - "libamdhip64.so", - "librocblas.so", - "libMIOpen.so", - "libhsa-runtime64.so", - "libhsakmt.so", - ] + to_preload = [] to_preload_cuda = [] to_preload_tensorrt = [] to_preload_cann = [] - cuda_dependencies = [] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in dependencies: - if dependency in line: - to_preload.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - cuda_dependencies = [ - "libcublas.so", - "libcublasLt.so", - "libcudnn.so", - "libcudart.so", - "libcurand.so", - "libcufft.so", - "libnvToolsExt.so", - "libcupti.so", - ] - rocm_dependencies = [ - "librccl.so", - "libamdhip64.so", - "librocblas.so", - "libMIOpen.so", - "libhsa-runtime64.so", - "libhsakmt.so", - ] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cuda_dependencies + rocm_dependencies: - if dependency in line: - if dependency not in to_preload: - to_preload_cuda.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - tensorrt_dependencies = ["libnvinfer.so", "libnvinfer_plugin.so", "libnvonnxparser.so"] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cuda_dependencies + tensorrt_dependencies: - if dependency in line: - if dependency not in (to_preload + to_preload_cuda): - to_preload_tensorrt.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) - - dest = "onnxruntime/capi/libonnxruntime_providers_cann.so" - if path.isfile(dest): - result = subprocess.run( - ["patchelf", "--print-needed", dest], - check=True, - stdout=subprocess.PIPE, - text=True, - ) - cann_dependencies = ["libascendcl.so", "libacl_op_compiler.so", "libfmk_onnx_parser.so"] - args = ["patchelf", "--debug"] - for line in result.stdout.split("\n"): - for dependency in cann_dependencies: - if dependency in line: - if dependency not in to_preload: - to_preload_cann.append(line) - args.extend(["--remove-needed", line]) - args.append(dest) - if len(args) > 3: - subprocess.run(args, check=True, stdout=subprocess.PIPE) + + cuda_dependencies = [ + "libcublas.so.11", + "libcublasLt.so.11", + "libcudnn.so.8", + "libcudart.so.11.0", + "libcurand.so.10", + "libcufft.so.10", + ] + rocm_dependencies = [ + "librccl.so.1", + "libnuma.so.1", + "libamd_comgr.so.2", + "libdrm.so.2", + "librocblas.so.0", + "libdrm_amdgpu.so.1", + "libamdhip64.so.5", + "libroctracer64.so.4", + "libMIOpen.so.1", + "libtinfo.so.6", + "libelf.so.1", + "librocm_smi64.so.5", + "libhsa-runtime64.so.1", + ] + + tensorrt_dependencies = ["libnvinfer.so.8", "libnvinfer_plugin.so.8", "libnvonnxparser.so.8"] dest = "onnxruntime/capi/libonnxruntime_providers_openvino.so" if path.isfile(dest): @@ -308,10 +238,12 @@ def run(self): assert self.dist_dir is not None file = glob(path.join(self.dist_dir, "*linux*.whl"))[0] logger.info("repairing %s for manylinux1", file) + auditwheel_cmd = ["auditwheel", "-v", "repair", "-w", self.dist_dir, file] + for i in cuda_dependencies + rocm_dependencies + tensorrt_dependencies: + auditwheel_cmd += ["--exclude", i] + logger.info("Running {}".format(" ".join([shlex.quote(arg) for arg in auditwheel_cmd]))) try: - subprocess.run( - ["auditwheel", "repair", "-w", self.dist_dir, file], check=True, stdout=subprocess.PIPE - ) + subprocess.run(auditwheel_cmd, check=True, stdout=subprocess.PIPE) finally: logger.info("removing %s", file) remove(file) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 64dae354a9243..4c3e8b76e8029 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -96,7 +96,7 @@ def _openvino_verify_device_type(device_read): break def invalid_hetero_build(): - print("\nIf trying to build Hetero/Multi/Auto, specifiy the supported devices along with it.\n") + print("\nIf trying to build Hetero/Multi/Auto, specify the supported devices along with it.\n") print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") @@ -107,7 +107,7 @@ def invalid_hetero_build(): sys.exit("Wrong Build Type selected") if res is False: - print("\nYou have selcted wrong configuration for the build.") + print("\nYou have selected wrong configuration for the build.") print("pick the build type for specific Hardware Device from following options: ", choices) print("(or) from the following options with graph partitioning disabled: ", choices1) print("\n") @@ -166,6 +166,15 @@ def convert_arg_line_to_args(self, arg_line): help="Use parallel build. The optional value specifies the maximum number of parallel jobs. " "If the optional value is 0 or unspecified, it is interpreted as the number of CPUs.", ) + parser.add_argument( + "--nvcc_threads", + nargs="?", + default=-1, + type=int, + help="Maximum number of NVCC threads in each parallel job." + "If the value is unspecified, it will be computed based on available memory and number of parallel jobs.", + ) + parser.add_argument("--test", action="store_true", help="Run unit tests.") parser.add_argument("--skip_tests", action="store_true", help="Skip all tests.") parser.add_argument( @@ -422,7 +431,7 @@ def convert_arg_line_to_args(self, arg_line): parser.add_argument("--wasm_run_tests_in_browser", action="store_true", help="Run WebAssembly tests in browser") parser.add_argument( - "--enable_wasm_profiling", action="store_true", help="Enable WebAsselby profiling and preserve function names" + "--enable_wasm_profiling", action="store_true", help="Enable WebAssembly profiling and preserve function names" ) parser.add_argument( "--enable_wasm_debug_info", action="store_true", help="Build WebAssembly with DWARF format debug info" @@ -519,7 +528,7 @@ def convert_arg_line_to_args(self, arg_line): "--llvm_config", type=str, default="", - help="Path to llvm-config.exe for LLVM buit from sources. It is strongly needed for build on Windows", + help="Path to llvm-config.exe for LLVM built from sources. It is strongly needed for build on Windows", ) parser.add_argument( "--skip_onnx_tests", @@ -864,6 +873,43 @@ def normalize_arg_list(nested_list): return [i for j in nested_list for i in j] if nested_list else [] +def number_of_parallel_jobs(args): + return os.cpu_count() if args.parallel == 0 else args.parallel + + +def number_of_nvcc_threads(args): + if args.nvcc_threads >= 0: + return args.nvcc_threads + + nvcc_threads = 1 + try: + import psutil + + available_memory = psutil.virtual_memory().available + if isinstance(available_memory, int) and available_memory > 0: + if available_memory > 60 * 1024 * 1024 * 1024: + # When available memory is large enough, chance of OOM is small. + nvcc_threads = 4 + else: + # NVCC need a lot of memory to compile 8 flash attention cu files in Linux or 4 cutlass fmha cu files in Windows. + # Here we select number of threads to ensure each thread has enough memory (>= 4 GB). For example, + # Standard_NC4as_T4_v3 has 4 CPUs and 28 GB memory. When parallel=4 and nvcc_threads=2, + # total nvcc threads is 4 * 2, which is barely able to build in 28 GB memory so we will use nvcc_threads=1. + memory_per_thread = 4 * 1024 * 1024 * 1024 + fmha_cu_files = 4 if is_windows() else 8 + fmha_parallel_jobs = min(fmha_cu_files, number_of_parallel_jobs(args)) + nvcc_threads = max(1, int(available_memory / (memory_per_thread * fmha_parallel_jobs))) + print( + f"nvcc_threads={nvcc_threads} to ensure memory per thread >= 4GB for available_memory={available_memory} and fmha_parallel_jobs={fmha_parallel_jobs}" + ) + except ImportError: + print( + "Failed to import psutil. Please `pip install psutil` for better estimation of nvcc threads. Use nvcc_threads=1" + ) + + return nvcc_threads + + def generate_build_tree( cmake_path, source_dir, @@ -1028,7 +1074,8 @@ def generate_build_tree( if args.use_migraphx: cmake_args.append("-Donnxruntime_MIGRAPHX_HOME=" + migraphx_home) if args.use_cuda: - cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(args.parallel)) + nvcc_threads = number_of_nvcc_threads(args) + cmake_args.append("-Donnxruntime_NVCC_THREADS=" + str(nvcc_threads)) if args.use_rocm: cmake_args.append("-Donnxruntime_ROCM_HOME=" + rocm_home) cmake_args.append("-Donnxruntime_ROCM_VERSION=" + args.rocm_version) @@ -2240,6 +2287,8 @@ def main(): args = parse_arguments() + print(args) + if os.getenv("ORT_BUILD_WITH_CACHE") == "1": args.use_cache = True @@ -2525,7 +2574,7 @@ def main(): if args.build: if args.parallel < 0: raise BuildError(f"Invalid parallel job count: {args.parallel}") - num_parallel_jobs = os.cpu_count() if args.parallel == 0 else args.parallel + num_parallel_jobs = number_of_parallel_jobs(args) build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, args.target) if args.test: diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index ea998963d956b..608112181bd4b 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -195,7 +195,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda - buildparameter: --use_cuda --cuda_version=11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} + buildparameter: --use_cuda --cuda_version=11.8 --cuda_home=$(Agent.TempDirectory)\v11.8 --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" ${{parameters.AdditionalBuildFlag}} runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -211,7 +211,7 @@ stages: buildArch: x64 msbuildPlatform: x64 packageName: x64-tensorrt - buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80" + buildparameter: --use_tensorrt --tensorrt_home="C:\local\TensorRT-8.6.1.6.Windows10.x86_64.cuda-11.8" --cuda_version=11.8 --cuda_home="C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8" --enable_onnx_tests --enable_wcos --build_java --cmake_extra_defines "CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80" runTests: ${{ parameters.RunOnnxRuntimeTests }} buildJava: true java_artifact_id: onnxruntime_gpu @@ -356,19 +356,32 @@ stages: - checkout: self submodules: false - template: templates/set-version-number-variables-step.yml - - task: DownloadPipelineArtifact@2 - displayName: 'Download Final Jar' - inputs: - buildType: 'current' - artifactName: 'onnxruntime-java-gpu' - targetPath: '$(Build.BinariesDirectory)/final-jar' - - task: Bash@3 + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Final Jar' + ArtifactName: onnxruntime-java-gpu + TargetPath: '$(Build.BinariesDirectory)/final-jar' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - bash: | + docker run --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + onnxruntimeubi8packagestest \ + /bin/bash /onnxruntime_src/tools/ci_build/github/linux/java_linux_final_test.sh -r /build -v $(OnnxRuntimeVersion) displayName: 'Test' - inputs: - targetType: filePath - filePath: 'tools/ci_build/github/linux/java_linux_final_test.sh' - arguments: '-r $(Build.BinariesDirectory) -v $(OnnxRuntimeVersion)' - template: templates/component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index b784ef72d6517..1e3d20b857d5c 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -45,10 +45,6 @@ stages: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -74,11 +70,11 @@ stages: inputs: script: | mkdir -p $HOME/.onnx - mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ --volume $(ORT_CACHE_DIR):/cache \ -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ @@ -90,135 +86,20 @@ stages: set -ex; \ ccache -s; \ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ - --build_dir /build --cmake_generator Ninja \ + --build_dir /build --cmake_generator 'Ninja' \ --config Debug Release \ --skip_submodule_sync \ --build_shared_lib \ --parallel \ --build_wheel \ --build_csharp \ - --enable_onnx_tests \ - --enable_transformers_tool_test \ + --enable_onnx_tests --enable_symbolic_shape_infer_tests \ --use_cache \ - --build_java --build_nodejs --update --build --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON; \ + --build_java --build_nodejs --cmake_extra_defines onnxruntime_BUILD_BENCHMARKS=ON; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: UseDotNet@2 - displayName: "Setup dotnet" - inputs: - version: '6.0.408' - - - task: DotNetCoreCLI@2 - displayName: "Restore C# packages" - inputs: - command: 'restore' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - - # the props file was generated with docker container paths. convert to the 'real' path by replacing the - # the container path of '/build'. The '>' prefix is to match the closing angle bracket of the tag. - # e.g. /build/... so we only match the start of a path. - # We use powershell so we don't need extra escaping of the '/' chars in the path. - - task: CmdLine@2 - displayName: 'Update props from docker path to local and create models link' - inputs: - script: | - pwsh -Command '(Get-Content $(Build.SourcesDirectory)/csharp/Directory.Build.props) -replace ">/build", ">$(Build.BinariesDirectory)" | Set-Content $(Build.SourcesDirectory)/csharp/Directory.Build.props' - cat $(Build.SourcesDirectory)/csharp/Directory.Build.props - ln -s /data/models $(Build.BinariesDirectory)/models - - - task: DotNetCoreCLI@2 - displayName: 'dotnet build C# sln' - inputs: - command: 'build' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - - - task: DotNetCoreCLI@2 - displayName: 'dotnet test C#' - inputs: - command: 'test' - projects: '$(Build.SourcesDirectory)/csharp/OnnxRuntime.DesktopOnly.CSharp.sln' - # extra logging so all tests are listed in output to validate what's actually run - arguments: '-f net6.0 --no-build -l "console;verbosity=normal"' - workingDirectory: $(Build.SourcesDirectory)/csharp - - - task: CmdLine@2 - displayName: 'Install python deps and run java tests' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - cd $(Build.SourcesDirectory)/java - $(Build.SourcesDirectory)/java/gradlew "cmakeCheck" "-DcmakeBuildDir=$(Build.BinariesDirectory)/Release" - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --build_nodejs - --ctest_path "" - - - task: CmdLine@2 - displayName: 'Install Debug python package' - inputs: - script: | - set -e -x - rm -rf $(Build.BinariesDirectory)/Debug/onnxruntime $(Build.BinariesDirectory)/Debug/pybind11 - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml -qq - python3 -m pip install $(Build.BinariesDirectory)/Debug/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Debug unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Debug - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Debug - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_transformers_tool_test - --build_nodejs - --ctest_path "" - - - task: PythonScript@0 - displayName: 'Symbolic shape infer' - inputs: - scriptPath: $(Build.BinariesDirectory)/Release/onnxruntime_test_python_symbolic_shape_infer.py - workingDirectory: $(Build.BinariesDirectory)/Release - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: @@ -245,8 +126,11 @@ stages: - stage: arm64_test dependsOn: ['arm64_build'] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-10/root + ld_library_path_arg: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-10/root/usr/bin:' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml index 5dc8fffbfecf8..461a62496c3b4 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -50,10 +50,6 @@ jobs: clean: true submodules: recursive - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu diff --git a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml index c1f1c39c855c0..f0c4422c7eac9 100644 --- a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml @@ -33,5 +33,3 @@ jobs: JobName: 'Linux_CI_Multi_GPU_TensorRT_Dev' # The latest TensorRT container only supports ubuntu20.04 and python 3.8 RunDockerBuildArgs: '-o ubuntu20.04 -d tensorrt -x "--enable_multi_device_test"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 2938b87ec6420..0264086c12ddc 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -32,6 +32,4 @@ jobs: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 53596a5ad50fd..12696e166a2e5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -103,6 +103,7 @@ jobs: ./build/Release/onnx_test_runner -e qnn \ -v -j 1 -c 1 -i "backend_path|$(QNN_SDK_ROOT)/lib/x86_64-linux-clang/libQnnHtp.so" \ /data/qdq_models + enabled: false - task: CmdLine@2 displayName: Run QDQ model tests with context cache enabled diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index cbe4e805bb219..b07d9a6089c17 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -50,19 +50,52 @@ stages: script: | ln -sf /data/models $(Build.BinariesDirectory) - - task: Bash@3 - displayName: 'Run Package Test' - inputs: - targetType: filePath - filePath: '$(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh' - arguments: '$(Build.BinariesDirectory)/nuget-artifact $(NuGetPackageVersionNumber)' - workingDirectory: $(Build.BinariesDirectory) - env: - OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) - DisableContribOps: $(DisableContribOps) - DisableMlOps: $(DisableMlOps) - IsReleaseBuild: $(IsReleaseBuild) - PACKAGENAME: ${{ parameters.NugetPackageName }} + - ${{if contains(parameters.StageSuffix , 'GPU') }}: + - template: ../../templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker/ + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimepackagestest + - bash: | + docker run --rm \ + --gpus all \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + -e BUILD_SOURCESDIRECTORY='/onnxruntime_src' \ + -e OnnxRuntimeBuildDirectory='/build' \ + -e DisableContribOps='$(DisableContribOps)' \ + -e DisableMlOps='$(DisableMlOps)' \ + -e IsReleaseBuild='$(IsReleaseBuild)' \ + -e PACKAGENAME='${{ parameters.NugetPackageName }}' \ + onnxruntimepackagestest \ + /bin/bash -c " + set -ex; \ + pushd /build; \ + bash /onnxruntime_src/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh /build/nuget-artifact $(NuGetPackageVersionNumber); \ + popd + " + displayName: 'Run Package Test' + - ${{ else }}: + - task: CmdLine@2 + displayName: 'Create symlink for test models' + inputs: + script: | + ln -sf /data/models $(Build.BinariesDirectory) + - task: Bash@3 + displayName: 'Run Package Test' + inputs: + targetType: filePath + filePath: '$(Build.SourcesDirectory)/csharp/test/Microsoft.ML.OnnxRuntime.EndToEndTests/runtest.sh' + arguments: '$(Build.BinariesDirectory)/nuget-artifact $(NuGetPackageVersionNumber)' + workingDirectory: $(Build.BinariesDirectory) + env: + OnnxRuntimeBuildDirectory: $(Build.BinariesDirectory) + DisableContribOps: $(DisableContribOps) + DisableMlOps: $(DisableMlOps) + IsReleaseBuild: $(IsReleaseBuild) + PACKAGENAME: ${{ parameters.NugetPackageName }} - template: ../../templates/component-governance-component-detection-steps.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index f5b221f23f8c4..355faa8b985ee 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -51,10 +51,6 @@ jobs: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -85,6 +81,7 @@ jobs: mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ + --volume /data/models:/build/models:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ @@ -107,53 +104,11 @@ jobs: --enable_onnx_tests \ --enable_training \ --use_cache \ - --build_java --build_nodejs --update --build; \ + --build_java --build_nodejs; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: CmdLine@2 - displayName: 'Install python deps and run java tests' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - cd $(Build.SourcesDirectory)/java - $(Build.SourcesDirectory)/java/gradlew "cmakeCheck" "-DcmakeBuildDir=$(Build.BinariesDirectory)/Release" - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_training - --build_nodejs - --ctest_path "" - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml index 16d70a58a0827..da0a2a6026190 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml @@ -18,7 +18,6 @@ jobs: parameters: AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' JobName: 'Onnxruntime_Linux_GPU_Training' - SubmoduleCheckoutMode: 'recursive' RunDockerBuildArgs: > -o ubuntu20.04 -d gpu -t onnxruntime_orttraining_ortmodule_tests_image @@ -26,24 +25,16 @@ jobs: -e -x " --enable_training - --config $(buildConfig) + --config Release --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 --build_wheel --enable_nvtx_profile --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 " - DoNugetPack: 'false' RunInjectedPipeline: 'true' InjectedPipeline: 'orttraining-linux-gpu-ortmodule-test-ci-pipeline.yml' DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - BuildConfig: $(buildConfig) - ArtifactName: 'drop-linux' TimeoutInMinutes: 140 # Enable unreleased onnx opsets in CI builds # This facilitates testing the implementation for the new opsets AllowReleasedOpsetOnly: '0' - Strategy: - maxParallel: 2 - matrix: - Release: - buildConfig: Release diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml index 8806707d21317..ac551a53cddaa 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-cpu.yml @@ -1,14 +1,5 @@ trigger: none -variables: - - name: isMain - value: ${{ eq(variables['Build.SourceBranch'], 'refs/heads/main') }} - - name: finalStorage - ${{ if eq(variables['isMain'], 'true') }}: - value: '--final_storage' - ${{ else }}: - value: '' - resources: repositories: - repository: manylinux @@ -39,14 +30,6 @@ stages: PythonVersion: '3.11' steps: - - task: CmdLine@2 - displayName: 'check variables' - inputs: - script: | - echo "Branch is "${{ variables['Build.SourceBranch'] }} && \ - echo "isMain is "${{ variables['isMain'] }} && \ - echo "final_storage is "${{ variables['finalStorage'] }} - - checkout: self clean: true submodules: recursive @@ -102,17 +85,6 @@ stages: inputs: ArtifactName: onnxruntime_training_cpu - - task: CmdLine@2 - condition: succeeded() - displayName: 'Upload wheel' - inputs: - script: | - files=($(Build.ArtifactStagingDirectory)/Release/dist/*.whl) && \ - echo ${files[0]} && \ - echo ${{ variables['finalStorage'] }} && \ - tools/ci_build/upload_python_package_to_azure_storage.py \ - --python_wheel_path ${files[0]} ${{ variables['finalStorage'] }} - - template: templates/component-governance-component-detection-steps.yml parameters: condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index 872812a6aedd4..b858770583b3c 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -74,7 +74,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-MultiA10 @@ -93,7 +92,6 @@ stages: isX86: false job_name_suffix: x64_mimalloc RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -113,7 +111,6 @@ stages: isX86: false job_name_suffix: x64_no_memory_profiling RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -133,7 +130,6 @@ stages: isX86: false job_name_suffix: x64_minimal_no_exception RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -153,7 +149,6 @@ stages: isX86: false job_name_suffix: x64_debug_node_input_output RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index c684e08ba1258..2db68d3bf1952 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -3,24 +3,38 @@ resources: - pipeline: build source: 'Python packaging pipeline' trigger: true + branch: rel-1.16.0 # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + #TODO: Remove the following dependency. Running python tests should not need to use manylinux. + repositories: + - repository: manylinux # The name used to reference this repository in the checkout step + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - stage: Linux_Test_CPU_x86_64_stage jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' - device: 'CPU' + base_image: 'centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-11/root + ld_library_path_arg: /opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-11/root/usr/bin:' - stage: Linux_Test_CPU_aarch64_stage dependsOn: [] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'aiinfra-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/centos:7' + devtoolset_rootpath: /opt/rh/devtoolset-10/root + ld_library_path_arg: /opt/rh/devtoolset-10/root/usr/lib64:/opt/rh/devtoolset-10/root/usr/lib:/opt/rh/devtoolset-10/root/usr/lib64/dyninst:/opt/rh/devtoolset-10/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/devtoolset-10/root/usr/bin:' - stage: Packages_Somking_Test dependsOn: [] @@ -31,19 +45,6 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_64_Wheels - itemPattern: '*/*win_amd64.whl' - machine_pool: - vmImage: 'windows-2022' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_32_Wheels - itemPattern: '*/*win32.whl' - python_arch: 'x86' - machine_pool: - vmImage: 'windows-2022' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels @@ -61,7 +62,7 @@ stages: - Linux_Test_CPU_aarch64_stage - Packages_Somking_Test jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index 74007d9b55084..21cd3a44e8924 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -495,7 +495,7 @@ stages: PackageType: 'nuget' PackagePath: '$(Build.ArtifactStagingDirectory)' PackageName: 'Microsoft.ML.OnnxRuntime.*nupkg' - PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx.10.14-x64' + PlatformsSupported: 'win-x64,win-x86,linux-x64,linux-arm64,osx-x64' VerifyNugetSigning: false - task: PublishPipelineArtifact@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/compliance.yml b/tools/ci_build/github/azure-pipelines/templates/compliance.yml index 04d999b556caa..0dfe398c8b836 100644 --- a/tools/ci_build/github/azure-pipelines/templates/compliance.yml +++ b/tools/ci_build/github/azure-pipelines/templates/compliance.yml @@ -18,27 +18,6 @@ steps: arguments: 'analyze $(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\*.dll --recurse --verbose' continueOnError: true -- task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - -# Manually set msBuildCommandline so that we can also set CAExcludePath -- task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - userProvideBuildInfo: msBuildInfo - msBuildArchitecture: x64 - msBuildVersion: 17.0 - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln" /p:platform="${{parameters.msbuildPlatform}}" /p:configuration="RelWithDebInfo" /p:CAExcludePath="$(Build.BinariesDirectory);$(Build.SourcesDirectory)\cmake;C:\program files (x86)" /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - task: SdtReport@2 displayName: 'Create Security Analysis Report' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml index 0f4e0553d05bf..a83451a1b33d9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml +++ b/tools/ci_build/github/azure-pipelines/templates/flex-downloadPipelineArtifact.yml @@ -18,7 +18,7 @@ parameters: steps: - task: DownloadPipelineArtifact@2 - displayName: ${{ parameters.StepName }}} + displayName: ${{ parameters.StepName }} inputs: artifactName: ${{ parameters.ArtifactName}} targetPath: '${{ parameters.TargetPath }}' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 1cd21ea1991f1..b05602a57b166 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -34,11 +34,6 @@ parameters: type: boolean default: true -- name: RunStaticCodeAnalysis - displayName: Run Static Code Analysis - type: boolean - default: true - - name: ORT_EP_NAME type: string @@ -105,7 +100,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' force32bit: ${{ parameters.isX86 }} @@ -309,49 +304,6 @@ jobs: workingDirectory: '$(Build.BinariesDirectory)\${{ parameters.BuildConfig }}\${{ parameters.BuildConfig }}' displayName: 'Run tests' - - - ${{ if eq(parameters.RunStaticCodeAnalysis, true) }}: - - task: DeleteFiles@1 - displayName: 'Delete binaries files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - - - # Manually set msBuildCommandline so that we can also set CAExcludePath - # build_dir must be a sub folder of $(Build.SourcesDirectory) - # TODO: move this step to a CPU-only machine to save GPU resources. - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config RelWithDebInfo --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --build_shared_lib --update --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_onnx_tests ${{ parameters.additionalBuildFlags }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\RelWithDebInfo\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=${{ parameters.msbuildPlatform }} /p:configuration=RelWithDebInfo /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: PublishSecurityAnalysisLogs@3 - displayName: 'Publish Security Analysis Logs' - continueOnError: true - - - task: PostAnalysis@2 - displayName: 'Guardian Break v2' - inputs: - GdnBreakGdnToolSDLNativeRulesSeverity: Note - GdnBreakGdnToolSDLNativeRules: true - - - - ${{ if eq(parameters.RunOnnxRuntimeTests, true) }}: - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 05b2dee77e689..7b9788d90b17d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml @@ -1,23 +1,14 @@ parameters: AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' StageName : 'Linux_CI_Dev' - SubmoduleCheckoutMode: '' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' - DoNodejsPack: 'false' - DoNugetPack: 'false' NuPackScript: '' RunInjectedPipeline: 'false' InjectedPipeline: '' DockerImageTag: '' - BuildConfig: '' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 # Controls whether unreleased onnx opsets are allowed. Default is set to 1 AllowReleasedOpsetOnly: '1' - # to inject strategy, you need to pass in the whole yaml structure - - # https://docs.microsoft.com/en-us/azure/devops/pipelines/yaml-schema?view=azure-devops&tabs=schema#strategies - # see example in orttraining-linux-gpu-ci-pipeline.yml - Strategy: '' jobs: - job: ${{ parameters.StageName }} @@ -28,16 +19,8 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} skipComponentGovernanceDetection: true pool: ${{ parameters.AgentPool }} - ${{ if ne(parameters.Strategy, '') }}: - strategy: - ${{ parameters.Strategy }} steps: - checkout: self - ${{ if ne(parameters.SubmoduleCheckoutMode, '') }}: - submodules: ${{ parameters.SubmoduleCheckoutMode }} - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - template: run-docker-build-steps.yml parameters: RunDockerBuildArgs: '${{ parameters.RunDockerBuildArgs }}' @@ -48,31 +31,10 @@ jobs: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - ${{ if eq(parameters['DoNugetPack'], 'true') }}: - - script: | - ${{ parameters.NuPackScript }} - displayName: 'Create Artifacts' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - script: | - npm pack - cp $(Build.SourcesDirectory)/js/node/onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - cp -R $(Build.SourcesDirectory)/js/node/prebuilds $(Build.ArtifactStagingDirectory)/prebuilds - workingDirectory: '$(Build.SourcesDirectory)/js/node' - displayName: 'Create NPM Package' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - ${{ if eq(parameters['RunInjectedPipeline'], 'true') }}: - template: | ${{ parameters.InjectedPipeline }} parameters: DockerImageTag: ${{ parameters.DockerImageTag }} - BuildConfig: ${{ parameters.BuildConfig }} + BuildConfig: Release - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 5b9ffac6fabb0..e3cfa417d8943 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -46,7 +46,7 @@ stages: docker run --gpus all -e CC=/opt/rh/devtoolset-11/root/usr/bin/cc -e CXX=/opt/rh/devtoolset-11/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e NVIDIA_VISIBLE_DEVICES=all --rm --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ --volume /data/models:/build/models:ro --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecuda118xtrt86build \ /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py --build_dir /build --config Release \ - --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.buildJavaOption }} --use_tensorrt --cuda_version=$(CUDA_VERSION) --cuda_home=/usr/local/cuda-$(CUDA_VERSION) --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines CMAKE_CUDA_HOST_COMPILER=/opt/rh/devtoolset-11/root/usr/bin/cc 'CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80' + --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.buildJavaOption }} --use_tensorrt --cuda_version=$(CUDA_VERSION) --cuda_home=/usr/local/cuda-$(CUDA_VERSION) --cudnn_home=/usr --tensorrt_home=/usr --cmake_extra_defines CMAKE_CUDA_HOST_COMPILER=/opt/rh/devtoolset-11/root/usr/bin/cc 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' workingDirectory: $(Build.SourcesDirectory) - ${{ if eq(parameters.buildJava, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 93945a1cb5e96..4ee442a122878 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -79,7 +79,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 @@ -94,7 +94,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index adfcd98e37230..f5e5435cfaca0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -50,7 +50,7 @@ jobs: versionSpec: 3.11 - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml index 76fbf55331b07..79feae8cf517c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/templates/publish-nuget.yml @@ -57,7 +57,7 @@ stages: REM use a single .csv file to put the data echo os,arch,build_config,size > $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\linux-arm64\native\libonnxruntime.so | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo linux,aarch64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt - 7z.exe l -slt %%~ni.zip runtimes\osx.10.14-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt + 7z.exe l -slt %%~ni.zip runtimes\osx-x64\native\libonnxruntime.dylib | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo osx,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x64\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x64,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt 7z.exe l -slt %%~ni.zip runtimes\win-x86\native\onnxruntime.dll | findstr /R /C:"^Size = [0-9]*" | for /F "tokens=3" %%a in ('more') do if not "%%a" == "" echo win,x86,default,%%a >> $(Build.BinariesDirectory)\binary_size_data.txt ) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index cee3bd9c9e968..8d5ca19a73535 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -39,36 +39,22 @@ jobs: versionSpec: $(PythonVersion) architecture: ${{ parameters.python_arch }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime' - targetPath: '$(Build.BinariesDirectory)/whl' - itemPattern: ${{parameters.itemPattern}} - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific + - download: build # pipeline resource identifier. + artifact: 'onnxruntime' - task: Bash@3 inputs: targetType: 'inline' script: | set -ex - files=(whl/*.whl) + files=(*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Build.BinariesDirectory)/whl" $PYTHON_PACKAGE_NAME - pip show $PYTHON_PACKAGE_NAME - python -c "import onnxruntime as ort; print(ort.__version__)" - workingDirectory: $(Build.BinariesDirectory) + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime displayName: Test Package Installation - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml new file mode 100644 index 0000000000000..00e40b520e867 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -0,0 +1,111 @@ +parameters: +- name: arch + type: string + +- name: base_image + type: string + +- name: devtoolset_rootpath + type: string + +- name: ld_library_path_arg + type: string + +- name: prepend_path + type: string + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + - download: current # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: current # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu + Context: tools/ci_build/github/linux/docker/inference/x64/python/cpu + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ parameters.base_image }} --build-arg PLATFORM=${{ parameters.arch }} --build-arg PREPEND_PATH=${{ parameters.prepend_path }} --build-arg LD_LIBRARY_PATH_ARG=${{ parameters.ld_library_path_arg }} --build-arg DEVTOOLSET_ROOTPATH=${{ parameters.devtoolset_rootpath }}" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + ${{ if eq(parameters.arch, 'aarch64') }}: + UpdateDepsTxt: false + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d CPU -c ${{parameters.cmake_build_type}} -i onnxruntimecpubuildpython${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml new file mode 100644 index 0000000000000..c521245b33a77 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -0,0 +1,93 @@ +parameters: +- name: arch + type: string + +- name: device + type: string + values: + - CPU + - GPU + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + # - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-gpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--network=host --build-arg POLICY=manylinux2014 --build-arg PLATFORM=x86_64 --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/devtoolset-11/root --build-arg PREPEND_PATH=/opt/rh/devtoolset-11/root/usr/bin: --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/devtoolset-11/root/usr/lib64:/opt/rh/devtoolset-11/root/usr/lib:/opt/rh/devtoolset-11/root/usr/lib64/dyninst:/opt/rh/devtoolset-11/root/usr/lib/dyninst:/usr/local/lib64 --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml deleted file mode 100644 index 8ddc917e8591e..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: arch - type: string - -- name: device - type: string - -- name: machine_pool - type: string - -- name: extra_job_id - type: string - default: '' - -- name: python_wheel_suffix - type: string - default: '' - - -# TODO: Ideally it should fetch information from the build that triggers it -- name: cmake_build_type - type: string - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: timeout - type: number - default: 120 - -jobs: -- job: Linux_Test_${{ parameters.device }}${{ parameters.extra_job_id }}_${{ parameters.arch }} - timeoutInMinutes: ${{ parameters.timeout }} - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: ${{ parameters.machine_pool }} - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'drop-linux-${{ lower(parameters.device) }}-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/${{parameters.cmake_build_type}}' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime${{ parameters.python_wheel_suffix }}' - targetPath: '$(Build.BinariesDirectory)/whl' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - - task: Bash@3 - displayName: 'Bash Script' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_tests.sh - arguments: -d ${{ parameters.device }} -c ${{parameters.cmake_build_type}} - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 568ab6c8a8ba9..a872228f5d1e7 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -246,24 +246,6 @@ stages: workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run Python Tests' - #Skip it for 32 bits x86 build. Currently the scan tool has a bug: it doesn't allow me use 64 bits link.exe - #in 32 bits Win32 build. I tried all the settings but they all don't work. - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind --enable_onnx_tests --parallel $(TelemetryOption) --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - task: TSAUpload@2 displayName: 'TSA upload' condition: and(and (succeeded(), and(eq(variables['buildArch'], 'x64'), eq(variables['PythonVersion'], '3.8'))), eq(variables['Build.SourceBranch'], 'refs/heads/main')) diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml index ef938a634554a..919749cac15b6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-gpu.yml @@ -22,65 +22,6 @@ parameters: default: '' jobs: -- ${{ if eq(parameters.PYTHON_VERSION, '3.8') }}: - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_StaticAnalysis - timeoutInMinutes: 240 - workspace: - clean: all - pool: onnxruntime-Win-CPU-2022 - steps: - - checkout: self - clean: true - submodules: none - - task: UsePythonVersion@0 - inputs: - versionSpec: 3.8 - addToPath: true - architecture: 'x64' - - task: onebranch.pipeline.tsaoptions@1 - displayName: 'OneBranch TSAOptions' - inputs: - tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' - appendSourceBranchName: false - - - template: download-deps.yml - - - template: jobs/set-winenv.yml - parameters: - EnvSetupScript: ${{ parameters.ENV_SETUP_SCRIPT }} - DownloadCUDA: true - - - task: PythonScript@0 - displayName: 'Update deps.txt' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/replace_urls_in_deps.py - arguments: --new_dir $(Build.BinariesDirectory)/deps - workingDirectory: $(Build.BinariesDirectory) - - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --build_dir $(Build.SourcesDirectory)\b --skip_submodule_sync --cmake_generator "Visual Studio 17 2022" --enable_pybind ${{ parameters.BUILD_PY_PARAMETERS }} ${{ parameters.EP_BUILD_FLAGS }} --update --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON onnxruntime_ENABLE_LTO=OFF' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.SourcesDirectory)\b\Debug\onnxruntime.sln" /p:RunCodeAnalysis=true /p:platform=x64 /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.SourcesDirectory)\b#$(Build.SourcesDirectory)\cmake#C:\program files#C:\program files (x86)#C:\program files' - rulesetName: Custom - customRuleset: $(Build.SourcesDirectory)\cmake\Sdl.ruleset - publishXML: true - - - task: SdtReport@2 - displayName: 'Create Security Analysis Report' - inputs: - SDLNativeRules: true - - - task: TSAUpload@2 - displayName: 'TSA upload' - condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) - inputs: - GdnPublishTsaOnboard: false - GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' - - - job: Win_py_${{ parameters.EP_NAME }}_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} timeoutInMinutes: 240 workspace: diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8c54e71448992..e63939ae0114c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,7 +80,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - script: brew install coreutils ninja npm yarn diff --git a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml index 15254ce4d1d5b..8bb3026520534 100644 --- a/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/stages/mac-ios-packaging-build-stage.yml @@ -116,7 +116,8 @@ stages: xcodeDeveloperDir: '/Applications/Xcode_${{ variables.xcodeVersion }}.app/Contents/Developer' signingOption: 'manual' signingIdentity: '$(APPLE_CERTIFICATE_SIGNING_IDENTITY)' - provisioningProfileName: 'iOS Team Provisioning Profile' + provisioningProfileName: 'temporary *' # temporary name, change it back to the original below later + #provisioningProfileName: 'iOS Team Provisioning Profile' args: '-derivedDataPath $(Build.BinariesDirectory)/app_center_test/ios_package_test/DerivedData' workingDirectory: '$(Build.BinariesDirectory)/app_center_test/ios_package_test/' displayName: 'Build App Center iPhone arm64 tests' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index 4494fd36b336e..96e6ff89cd4f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index e0a85cc1973bc..55535173bc341 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -74,7 +74,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index f6da7bb857b7d..8d28b4ce580b4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -101,7 +101,7 @@ stages: - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: jobs/set-winenv.yml parameters: @@ -263,25 +263,6 @@ stages: AnalyzeTargetGlob: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\**\*.dll' continueOnError: true - - task: DeleteFiles@1 - displayName: 'Delete files from $(Build.BinariesDirectory)\RelWithDebInfo' - inputs: - SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' - Contents: | - **/*.obj - **/*.pdb - **/*.dll - - #Manually set msBuildCommandline so that we can also set CAExcludePath - - task: SDLNativeRules@3 - displayName: 'Run the PREfast SDL Native Rules for MSBuild' - condition: and (succeeded(), eq(variables['msbuildPlatform'], 'x64')) - inputs: - msBuildArchitecture: amd64 - setupCommandlines: 'python $(Build.SourcesDirectory)\tools\ci_build\build.py --config Debug --disable_rtti --build_dir $(Build.BinariesDirectory) --skip_submodule_sync --build_shared_lib --update --cmake_generator "$(VSGenerator)" --enable_onnx_tests $(TelemetryOption) ${{ parameters.buildparameter }} --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON' - msBuildCommandline: '"C:\Program Files\Microsoft Visual Studio\2022\Enterprise\MSBuild\Current\Bin\amd64\msbuild.exe" "$(Build.BinariesDirectory)\Debug\onnxruntime.sln" /p:platform="$(MsbuildPlatform)" /p:configuration=Debug /p:VisualStudioVersion="17.0" /m /p:PreferredToolArchitecture=x64' - excludedPaths: '$(Build.BinariesDirectory)#$(Build.SourcesDirectory)\cmake#C:\program files (x86)' - - task: PostAnalysis@2 inputs: GdnBreakAllTools: false diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 9d36e2dbe4944..406683af80222 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -74,7 +74,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index 713396dd64532..90fc30141a916 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -72,7 +72,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 723567389579d..f7876f15029c1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -33,7 +33,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index 0eb78412de418..58c0b8d353582 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -27,7 +27,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: BatchScript@1 displayName: 'setup env' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 7f71f41484b27..b7e3ce7940516 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -44,7 +44,6 @@ stages: isX86: false job_name_suffix: x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -66,7 +65,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -86,7 +84,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: DNNL GenerateDocumentation: false @@ -108,7 +105,6 @@ stages: isX86: false job_name_suffix: x64_release RunOnnxRuntimeTests: true - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: XNNPACK GenerateDocumentation: false @@ -129,7 +125,6 @@ stages: job_name_suffix: x64_release_winml RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} # WinML has many warnings - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU @@ -150,7 +145,6 @@ stages: isX86: true job_name_suffix: x86_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false @@ -170,7 +164,6 @@ stages: isX86: false job_name_suffix: training_x64_debug RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -190,7 +183,6 @@ stages: isX86: false job_name_suffix: training_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: true isTraining: true ORT_EP_NAME: CPU GenerateDocumentation: false @@ -210,7 +202,6 @@ stages: isX86: false job_name_suffix: ort_training_apis_x64_release RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: true ORT_EP_NAME: CPU @@ -231,7 +222,6 @@ stages: isX86: false job_name_suffix: x64_release_azure RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false EnablePython: false isTraining: false ORT_EP_NAME: CPU diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 7ab55a5d803ce..806ed797f88f1 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -49,7 +49,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-A10 @@ -67,7 +66,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: CUDA WITH_CACHE: true # Some unit tests crash on A10 GPUs. So this job still needs to use A10. @@ -87,7 +85,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: ${{ parameters.RunOnnxRuntimeTests }} - RunStaticCodeAnalysis: false ORT_EP_NAME: DML WITH_CACHE: true MachinePool: onnxruntime-Win2022-GPU-dml-A10 @@ -106,7 +103,6 @@ stages: isX86: false job_name_suffix: x64_RelWithDebInfo RunOnnxRuntimeTests: false - RunStaticCodeAnalysis: false GenerateDocumentation: true ORT_EP_NAME: CUDA # It doesn't really matter which EP is selected here since this stage is for documentation. WITH_CACHE: true diff --git a/tools/ci_build/github/linux/build_cuda_c_api_package.sh b/tools/ci_build/github/linux/build_cuda_c_api_package.sh index 271f010a9d1c2..5ce1c2d2e80f8 100755 --- a/tools/ci_build/github/linux/build_cuda_c_api_package.sh +++ b/tools/ci_build/github/linux/build_cuda_c_api_package.sh @@ -1,10 +1,13 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + export CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" export CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" docker run --gpus all -e CFLAGS -e CXXFLAGS -e NVIDIA_VISIBLE_DEVICES=all --rm --volume \ $BUILD_SOURCESDIRECTORY:/onnxruntime_src --volume $BUILD_BINARIESDIRECTORY:/build \ --volume /data/models:/build/models:ro --volume /data/onnx:/data/onnx:ro -e NIGHTLY_BUILD onnxruntimecuda11centosbuild \ python3 /onnxruntime_src/tools/ci_build/build.py --build_java --build_dir /build --config Release \ ---skip_submodule_sync --parallel --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ +--skip_submodule_sync --parallel --nvcc_threads=1 --build_shared_lib --use_cuda --cuda_version=$CUDA_VERSION \ --cuda_home=/usr/local/cuda-$CUDA_VERSION --cudnn_home=/usr/local/cuda-$CUDA_VERSION \ ---cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80' +--cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=60;61;70;75;80' diff --git a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh b/tools/ci_build/github/linux/build_linux_arm64_python_package.sh index 58d7d32ac4b5f..a1a0d428c68f9 100755 --- a/tools/ci_build/github/linux/build_linux_arm64_python_package.sh +++ b/tools/ci_build/github/linux/build_linux_arm64_python_package.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -x # This script invokes build.py @@ -62,7 +65,7 @@ fi if [ "$BUILD_DEVICE" == "GPU" ]; then #Enable CUDA and TRT EPs. ONNXRUNTIME_CUDA_VERSION="11.8" - BUILD_ARGS+=("--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") + BUILD_ARGS+=("--nvcc_threads=1" "--use_cuda" "--use_tensorrt" "--cuda_version=$ONNXRUNTIME_CUDA_VERSION" "--tensorrt_home=/usr" "--cuda_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cudnn_home=/usr/local/cuda-$ONNXRUNTIME_CUDA_VERSION" "--cmake_extra_defines" "CMAKE_CUDA_ARCHITECTURES=52;60;61;70;75;80") fi export CFLAGS diff --git a/tools/ci_build/github/linux/build_yocto.sh b/tools/ci_build/github/linux/build_yocto.sh index e948a105c0cce..fab5173353d20 100755 --- a/tools/ci_build/github/linux/build_yocto.sh +++ b/tools/ci_build/github/linux/build_yocto.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -o -x SCRIPT_DIR="$( dirname "${BASH_SOURCE[0]}" )" TARGET_FOLDER="/datadrive/ARM" diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index 5832b0ec2ee65..af8f0ecb9c125 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts r:a:l:c:s:t: parameter_Option @@ -44,6 +46,7 @@ fi cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include diff --git a/tools/ci_build/github/linux/create_package.sh b/tools/ci_build/github/linux/create_package.sh index ed012a5abcc6f..305d261838bb1 100755 --- a/tools/ci_build/github/linux/create_package.sh +++ b/tools/ci_build/github/linux/create_package.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e SCRIPT=`realpath $0` SCRIPT_DIR=`dirname $SCRIPT` diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu index 033afde6aa93c..561df220afe38 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -137,9 +135,7 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ build_scripts/requirements3.10.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 index dc52fb51d6389..8a092c437ae7e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 index 303e83eb23bca..68b779e6f13d6 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_4 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 index d17e4b24582fe..dfc9e819ade3c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_6_tensorrt8_5 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 index bcdc24d5eb61e..6e27db4eb3c7e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_cuda11_8_tensorrt8_6 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -147,7 +145,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.7.txt \ build_scripts/requirements3.8.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm index 9f7575d62e6c7..036d2610440b2 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_rocm @@ -52,7 +52,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -61,7 +60,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -164,7 +162,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 index 5d774460073ed..c3c7213212f5b 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_training_cuda11_8 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 new file mode 100644 index 0000000000000..cdf504c8e3b03 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -0,0 +1,45 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to Test ONNX Runtime on UBI8 with CUDA 11.8 and TensorRT 8.6 + +# Build base image with required system packages +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 AS base + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} + +RUN dnf install -y bash wget &&\ + dnf clean dbcache + +# Install python3 +RUN dnf install -y \ + python3.8 \ + python38-pip \ + python38-wheel &&\ + cd /usr/local/bin &&\ + ln -s /usr/bin/python3 python3.8 &&\ + ln -s /usr/bin/pip3 pip3.8; + +RUN pip3 install --upgrade pip +RUN pip3 install setuptools>=41.0.0 + +# Install TensorRT +RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 +RUN v="8.6.1.6-1+cuda11.8" &&\ + dnf downgrade -y libnvinfer8-${v} libnvinfer8-${v} libnvonnxparsers8-${v} libnvparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-lean8-${v} libnvinfer-vc-plugin8-${v} libnvinfer-dispatch8-${v} &&\ + dnf install -y dnf-plugin-versionlock &&\ + dnf versionlock libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 +RUN dnf clean dbcache + + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && /tmp/scripts/install_java.sh && rm -rf /tmp/scripts + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 new file mode 100644 index 0000000000000..c211fa9b9e2b8 --- /dev/null +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 @@ -0,0 +1,53 @@ +# -------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------- +# Dockerfile to run ONNXRuntime with TensorRT integration + +# Build base image with required system packages +FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base + +# The local directory into which to build and install CMAKE +ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code + +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV DEBIAN_FRONTEND=noninteractive + +RUN apt-get update &&\ + apt-get install -y sudo git bash unattended-upgrades wget +RUN unattended-upgrade + +# Install python3 +RUN apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + python3-dev \ + python3-wheel &&\ + cd /usr/local/bin &&\ + ln -s /usr/bin/python3 python &&\ + ln -s /usr/bin/pip3 pip; + +RUN pip install --upgrade pip +RUN pip install setuptools>=41.0.0 + +# Install TensorRT +RUN v="8.6.1.6-1+cuda11.8" &&\ + apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ + apt-get update &&\ + sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ + python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} + +# Install Valgrind +RUN apt-get install -y valgrind + +ADD scripts /tmp/scripts +RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts + +# Build final image from base. +FROM base as final +ARG BUILD_USER=onnxruntimedev +ARG BUILD_UID=1000 +RUN adduser --uid $BUILD_UID $BUILD_USER +WORKDIR /home/$BUILD_USER +USER $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu index 8869a789028e0..691e45e743a11 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2014_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -132,7 +130,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index c0c6505ca010d..8a9c4dac1dd58 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -6,5 +6,5 @@ setuptools>=41.4.0 wheel git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/manylinux.patch b/tools/ci_build/github/linux/docker/manylinux.patch index 7750118d01bb6..1a92b4c094765 100644 --- a/tools/ci_build/github/linux/docker/manylinux.patch +++ b/tools/ci_build/github/linux/docker/manylinux.patch @@ -50,6 +50,17 @@ index 961e34d..55ae11b 100755 make install > /dev/null } +diff --git a/finalize.sh b/finalize.sh +index 621eab9..4cbcf90 100755 +--- a/finalize.sh ++++ b/finalize.sh +@@ -86,6 +86,3 @@ clean_pyc /opt/_internal + rm -rf /root/.cache + + hardlink -cv /opt/_internal +- +-# update system packages +-LC_ALL=C ${MY_DIR}/update-system-packages.sh diff --git a/install-entrypoint.sh b/install-entrypoint.sh index 9ef1e99..ec52833 100755 --- a/install-entrypoint.sh @@ -65,7 +76,7 @@ index 9ef1e99..ec52833 100755 +fi \ No newline at end of file diff --git a/install-runtime-packages.sh b/install-runtime-packages.sh -index 137d2e2..21b60a7 100755 +index 137d2e2..7a17e16 100755 --- a/install-runtime-packages.sh +++ b/install-runtime-packages.sh @@ -73,9 +73,11 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then @@ -83,3 +94,15 @@ index 137d2e2..21b60a7 100755 elif [ "${AUDITWHEEL_ARCH}" == "aarch64" ] || [ "${AUDITWHEEL_ARCH}" == "ppc64le" ] || [ "${AUDITWHEEL_ARCH}" == "s390x" ]; then # Software collection (for devtoolset-10) yum -y install centos-release-scl-rh +@@ -121,11 +123,6 @@ else + exit 1 + fi + +-# update system packages, we already updated them but +-# the following script takes care of cleaning-up some things +-# and since it's also needed in the finalize step, everything's +-# centralized in this script to avoid code duplication +-LC_ALL=C ${MY_DIR}/update-system-packages.sh + + if [ "${BASE_POLICY}" == "manylinux" ]; then + # we'll be removing libcrypt.so.1 later on diff --git a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh new file mode 100755 index 0000000000000..d89a5e84c1ac1 --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh @@ -0,0 +1,24 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set -e -x + +if [ -f /etc/redhat-release ]; then + rpm -Uvh https://packages.microsoft.com/config/centos/7/packages-microsoft-prod.rpm + yum install -y dotnet-sdk-6.0 +elif [ -f /etc/os-release ]; then + # Get Ubuntu version + declare repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) + # Download Microsoft signing key and repository + wget https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + # Install Microsoft signing key and repository + dpkg -i packages-microsoft-prod.deb + # Clean up + rm packages-microsoft-prod.deb + # Update packages + apt-get update && apt-get install -y dotnet-sdk-6.0 +else + echo "Unsupported OS" + exit 1 +fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_java.sh b/tools/ci_build/github/linux/docker/scripts/install_java.sh new file mode 100755 index 0000000000000..d11e29f693b8b --- /dev/null +++ b/tools/ci_build/github/linux/docker/scripts/install_java.sh @@ -0,0 +1,12 @@ +#!/bin/bash +set -e -x + +if [ -f /etc/redhat-release ]; then + dnf install -y java-11-openjdk-devel \ + && dnf clean dbcache +elif [ -f /etc/os-release ]; then + apt-get update && apt-get install -y openjdk-11-jdk +else + echo "Unsupported OS" + exit 1 +fi diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh index c34abbd2ba873..e569e58d549f2 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh @@ -3,7 +3,7 @@ set -e -x # Development tools and libraries if [ -f /etc/redhat-release ]; then - yum update && yum -y install graphviz + yum -y install graphviz os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) elif [ -f /etc/os-release ]; then apt-get update && apt-get install -y graphviz @@ -13,6 +13,9 @@ else exit 1 fi +# Install dotnet +source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)/install_dotnet.sh + if [ ! -d "/opt/conda/bin" ]; then PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") else diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index c8ff7a804e1df..6b8003c01c24d 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -6,6 +6,6 @@ setuptools>=41.4.0 wheel git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers neural-compressor>=2.2.1 diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 2248652c98043..9dbe856753faa 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -7,7 +7,7 @@ setuptools>=41.4.0 wheel>=0.35.1 git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx argparse -sympy==1.10.1 +sympy==1.12 flatbuffers protobuf==3.20.2 packaging diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt index 202d43befcca4..891291b6fa733 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage2/requirements.txt @@ -2,7 +2,7 @@ pandas scikit-learn numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' -transformers==v4.4.2 +transformers==v4.16.1 rsa==4.9 tensorboard>=2.2.0,<2.5.0 h5py diff --git a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh index 9492b7bcf59a6..a442a93203174 100755 --- a/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh +++ b/tools/ci_build/github/linux/extract_and_bundle_gpu_package.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts a: parameter_Option diff --git a/tools/ci_build/github/linux/java_copy_strip_binary.sh b/tools/ci_build/github/linux/java_copy_strip_binary.sh index 329c1b0ab9b9e..8004e37a73fa0 100755 --- a/tools/ci_build/github/linux/java_copy_strip_binary.sh +++ b/tools/ci_build/github/linux/java_copy_strip_binary.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts r:a:l:n:c:h:v: parameter_Option diff --git a/tools/ci_build/github/linux/run_python_dockerbuild.sh b/tools/ci_build/github/linux/run_python_dockerbuild.sh index 18ac6482827f9..2ab01eacada34 100755 --- a/tools/ci_build/github/linux/run_python_dockerbuild.sh +++ b/tools/ci_build/github/linux/run_python_dockerbuild.sh @@ -1,4 +1,7 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + set -e -x BUILD_CONFIG="Release" diff --git a/tools/ci_build/github/linux/run_python_dockertest.sh b/tools/ci_build/github/linux/run_python_dockertest.sh new file mode 100755 index 0000000000000..7b080a9047d9a --- /dev/null +++ b/tools/ci_build/github/linux/run_python_dockertest.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +set -e -x +BUILD_CONFIG="Release" + +while getopts "i:d:x:c:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +d) DEVICE=${OPTARG};; +c) BUILD_CONFIG=${OPTARG};; +esac +done + +if [ $DEVICE = "GPU" ]; then + ADDITIONAL_DOCKER_PARAMETER="--gpus all" +fi + +mkdir -p $HOME/.onnx +docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ + --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -w /onnxruntime_src \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + $ADDITIONAL_DOCKER_PARAMETER \ + $DOCKER_IMAGE tools/ci_build/github/linux/run_python_tests.sh -d $DEVICE -c $BUILD_CONFIG diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index 90362a3315e06..cecb790b19b2a 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,8 @@ c) BUILD_CONFIG=${OPTARG};; esac done -cd $BUILD_BINARIESDIRECTORY +export PATH=/opt/python/cp38-cp38/bin:$PATH +cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) @@ -23,7 +24,7 @@ PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') echo "Package name:$PYTHON_PACKAGE_NAME" -BUILD_ARGS="--build_dir $BUILD_BINARIESDIRECTORY --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " +BUILD_ARGS="--build_dir /build --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " ARCH=$(uname -m) @@ -34,20 +35,15 @@ fi if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=11.8 --tensorrt_home=/usr --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8" fi -# We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source -sudo rm -rf /build /onnxruntime_src -sudo ln -s $BUILD_SOURCESDIRECTORY /onnxruntime_src -python3 -m pip uninstall -y $PYTHON_PACKAGE_NAME ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq +python3 -m pip install --upgrade pip # Install the packages that are needed for installing the onnxruntime python package -python3 -m pip install -r $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/requirements.txt +python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -# Install the latest ONNX release which may contain not fixed bugs. However, it is what most people use. -python3 -m pip install onnx pytest +python3 -m pip install pytest # The "--no-index" flag is crucial. The local whl folder is just an additional source. Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" -python3 -m pip install --no-index --find-links $BUILD_BINARIESDIRECTORY/whl $PYTHON_PACKAGE_NAME -ln -s /data/models $BUILD_BINARIESDIRECTORY -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG +python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME +cd /build/$BUILD_CONFIG # Restore file permissions xargs -a perms.txt chmod a+x -python3 $BUILD_SOURCESDIRECTORY/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' +python3 /onnxruntime_src/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' diff --git a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh index 56f5ff9f9eac0..ee4339c24d399 100755 --- a/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh +++ b/tools/ci_build/github/linux/test_custom_ops_pytorch_export.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. pip3 install --user --upgrade pip diff --git a/tools/ci_build/github/linux/upload_ortsrv_binaries.sh b/tools/ci_build/github/linux/upload_ortsrv_binaries.sh index 9d6f9406e4181..dcbea27847a8a 100755 --- a/tools/ci_build/github/linux/upload_ortsrv_binaries.sh +++ b/tools/ci_build/github/linux/upload_ortsrv_binaries.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e -o -x while getopts a:r:i:c:p:b: parameter_Option diff --git a/tools/ci_build/github/linux/yocto_build_toolchain.sh b/tools/ci_build/github/linux/yocto_build_toolchain.sh index 26d4f583487c6..a4e7fe7a38255 100755 --- a/tools/ci_build/github/linux/yocto_build_toolchain.sh +++ b/tools/ci_build/github/linux/yocto_build_toolchain.sh @@ -1,4 +1,6 @@ #!/bin/bash +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. set -e YOCTO_VERSION="4.19" diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements.txt index 620da1afa1f00..96659d70af81f 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements.txt @@ -1,7 +1,8 @@ -# packages used by transformers tool test +# packages used by transformers python unittest (only enabled in Linux CPU CI Pipeline) packaging protobuf==3.20.2 numpy==1.24.0 coloredlogs==15.0 transformers==4.30.0 psutil +einops \ No newline at end of file diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index 9dc36633a553e..3aba1d0577f9c 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -67,7 +67,7 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs, is_versioned_dylib = re.match(r".*[\.\d+]+\.dylib$", child_file.name) if child_file.is_file() and child_file.suffix == ".dylib" and not is_versioned_dylib: files_list.append( - '' % cpu_arch + '' % cpu_arch ) for cpu_arch in ["x64", "aarch64"]: if child.name == get_package_name("linux", cpu_arch, ep, is_training_package): diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 43b904ce77ad0..b2dd331ccef2c 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -379,13 +379,6 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin name += tokenizedModelPath[tokenizedModelPath.size() - 2] += "_"; // model name name += tokenizedModelPath[tokenizedModelPath.size() - 3]; // opset version - // To introduce models from model zoo, the model path is structured like this "///?.onnx" - std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; - // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. - if (source != "models") { - name += "_" + source; - } - std::replace_if( name.begin(), name.end(), [](char c) { return !google::protobuf::ascii_isalnum(c); }, '_' ); @@ -404,6 +397,13 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin ModifyNameIfDisabledTest(/*inout*/ name, deviceKind); } + // To introduce models from model zoo, the model path is structured like this "///?.onnx" + std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; + // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. + if (source != "models") { + name += "_" + source; + } + return name; }