diff --git a/.github/workflows/sca.yml b/.github/workflows/sca.yml
deleted file mode 100644
index 1416f5a4d33a9..0000000000000
--- a/.github/workflows/sca.yml
+++ /dev/null
@@ -1,133 +0,0 @@
-name: Windows_SCA
-on:
- push:
- branches:
- - main
- - rel-*
- pull_request:
-
-concurrency:
- group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
- cancel-in-progress: true
-
-env:
- AZCOPY_AUTO_LOGIN_TYPE: MSI
- AZCOPY_MSI_CLIENT_ID: 63b63039-6328-442f-954b-5a64d124e5b4
-
-jobs:
- Onnxruntime-SCA-training-CUDA:
- runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
- steps:
- - uses: actions/checkout@v3
- with:
- submodules: false
- - uses: actions/setup-python@v4
- with:
- python-version: '3.11.x'
- architecture: 'x64'
-
- - uses: actions/setup-node@v3
- with:
- node-version: 18
-
- - name: Download cuda
- run: azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" cuda_sdk
-
-
- - name: Delete build folder
- run: |
- if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
- &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug
-
- # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
- - name: Build code
- env:
- CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
- run: python tools\ci_build\build.py --windows_sdk_version 10.0.22621.0 --enable_training --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --enable_pybind --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --use_cuda --cuda_home=${{ github.workspace }}\cuda_sdk\v11.8 --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75
-
- - name: Generate sarif
- working-directory: D:\b
- run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
-
- - name: Upload SARIF to GitHub
- uses: github/codeql-action/upload-sarif@v2
- continue-on-error: true
- with:
- sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
- category: VS_SCA
-
- # No python
- Onnxruntime-SCA-win32-WINML-x64:
- runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
- steps:
- - uses: actions/checkout@v3
- with:
- submodules: false
- - uses: actions/setup-python@v4
- with:
- python-version: '3.11.x'
- architecture: 'x64'
-
- - uses: actions/setup-node@v3
- with:
- node-version: 18
-
- - name: Delete build folder
- run: |
- if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
- &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x64 -install_prefix D:\b\Debug\installed -build_config Debug
-
- # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
- - name: Build code
- env:
- CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
- run: python tools\ci_build\build.py --build_java --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib
-
- - name: Generate sarif
- working-directory: D:\b
- run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
-
- - name: Upload SARIF to GitHub
- uses: github/codeql-action/upload-sarif@v2
- continue-on-error: true
- with:
- sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
- category: VS_SCA_WIN32_WINML_X64
-
- # No java, No python
- Onnxruntime-SCA-win32-WINML-x86:
- runs-on: ["self-hosted", "1ES.Pool=onnxruntime-github-vs2022-mms"]
- steps:
- - uses: actions/checkout@v3
- with:
- submodules: false
- - uses: actions/setup-python@v4
- with:
- python-version: '3.11.x'
- architecture: 'x86'
-
- - uses: actions/setup-node@v3
- with:
- node-version: 18
-
- - name: Delete build folder
- run: |
- if (Test-Path D:\b) { Remove-Item -Recurse -Force D:\b }
- &tools\ci_build\github\windows\install_third_party_deps.ps1 -cpu_arch x86 -install_prefix D:\b\Debug\installed -build_config Debug
-
- # The build machine doesn't have a GPU. So the value of CMAKE_CUDA_ARCHITECTURES doesn't matter.
- - name: Build code
- env:
- CAExcludePath: 'C:\Program Files;D:\b;${{ github.workspace }}\cmake'
- run: python tools\ci_build\build.py --compile_no_warning_as_error --config Debug --build_dir D:\b --skip_submodule_sync --build_csharp --update --build --parallel --cmake_generator "Visual Studio 17 2022" --build_shared_lib --cmake_extra_defines onnxruntime_USE_CUSTOM_STATIC_ANALYSIS_RULES=ON --cmake_extra_defines onnxruntime_ENABLE_STATIC_ANALYSIS=ON --cmake_extra_defines onnxruntime_REDIRECT_STATIC_ANALYSIS_OUTPUTS_TO_FILE=ON --ms_experimental --use_dml --use_winml --disable_rtti --enable_wcos --build_shared_lib
-
- - name: Generate sarif
- working-directory: D:\b
- run: npx @microsoft/sarif-multitool merge *.sarif --recurse --output-directory=${{ github.workspace }}\output --output-file=MergeResult.sarif --merge-runs && dir ${{ github.workspace }}\output
-
- - name: Upload SARIF to GitHub
- uses: github/codeql-action/upload-sarif@v2
- continue-on-error: true
- with:
- sarif_file: ${{ github.workspace }}\output\MergeResult.sarif
- category: VS_SCA_WIN32_WINML_X86
diff --git a/.gitmodules b/.gitmodules
index 036a248070855..7bb49e98bfec1 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -8,6 +8,3 @@
path = cmake/external/emsdk
url = https://github.com/emscripten-core/emsdk.git
branch = 3.1.44
-[submodule "cmake/external/onnxruntime-extensions"]
- path = cmake/external/onnxruntime-extensions
- url = https://github.com/microsoft/onnxruntime-extensions.git
diff --git a/ThirdPartyNotices.txt b/ThirdPartyNotices.txt
index 6f6faa3a2e56f..985eb645664c8 100644
--- a/ThirdPartyNotices.txt
+++ b/ThirdPartyNotices.txt
@@ -6230,3 +6230,37 @@ https://github.com/intel/neural-compressor
terms, and open source software license terms. These separate license terms
govern your use of the third party programs as set forth in the
"THIRD-PARTY-PROGRAMS" file.
+
+_____
+
+FlashAttention, https://github.com/Dao-AILab/flash-attention
+
+BSD 3-Clause License
+
+Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/VERSION_NUMBER b/VERSION_NUMBER
index 15b989e398fc7..4a02d2c3170bd 100644
--- a/VERSION_NUMBER
+++ b/VERSION_NUMBER
@@ -1 +1 @@
-1.16.0
+1.16.2
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index b01ed00350bb0..82a454791d159 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -84,7 +84,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
-option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
+cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
+option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
@@ -666,13 +667,16 @@ if (onnxruntime_USE_CUDA)
if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
+ set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
+ set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
+ set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
endif()
if (onnxruntime_USE_CUDA)
@@ -685,6 +689,11 @@ if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()
+ if (onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
+ message( STATUS "Enable memory efficient attention for CUDA EP")
+ list(APPEND ORT_PROVIDER_FLAGS -DUSE_MEMORY_EFFICIENT_ATTENTION=1)
+ list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_MEMORY_EFFICIENT_ATTENTION=1)
+ endif()
endif()
if (onnxruntime_USE_VITISAI)
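The USE_FLASH_ATTENTION and USE_MEMORY_EFFICIENT_ATTENTION entries appended to ORT_PROVIDER_FLAGS above become ordinary compile definitions, so CUDA EP sources can gate the two kernel families independently. A minimal C++ sketch of such a guard (the helper names here are illustrative, not the actual kernel entry points):

    // Illustrative sketch only; function names are hypothetical.
    // USE_FLASH_ATTENTION / USE_MEMORY_EFFICIENT_ATTENTION come from ORT_PROVIDER_FLAGS above.
    bool CanUseFlashAttention() {
    #if defined(USE_FLASH_ATTENTION)
      return true;   // compiled in: non-Windows CUDA build, contrib ops enabled, CUDA >= 11.6
    #else
      return false;  // option disabled, or Windows / non-CUDA / no-contrib-ops build
    #endif
    }

    bool CanUseMemoryEfficientAttention() {
    #if defined(USE_MEMORY_EFFICIENT_ATTENTION)
      return true;   // cutlass-based memory efficient attention kernels compiled in
    #else
      return false;
    #endif
    }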
diff --git a/cmake/external/cutlass.cmake b/cmake/external/cutlass.cmake
index 18ac668bb1592..8c5d81d638ced 100644
--- a/cmake/external/cutlass.cmake
+++ b/cmake/external/cutlass.cmake
@@ -1,4 +1,4 @@
-if (onnxruntime_USE_FLASH_ATTENTION)
+if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
include(FetchContent)
FetchContent_Declare(
cutlass
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index c087ad8f6d81e..8e412c7847b70 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -46,8 +46,8 @@ if (onnxruntime_BUILD_UNIT_TESTS)
FetchContent_Declare(
googletest
URL ${DEP_URL_googletest}
+ FIND_PACKAGE_ARGS 1.13.0...<2.0.0 NAMES GTest
URL_HASH SHA1=${DEP_SHA1_googletest}
- OVERRIDE_FIND_PACKAGE
)
endif()
@@ -528,4 +528,3 @@ endif()
FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR)
FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR)
-
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 5adfc7ba03923..03360ff30c4c4 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -529,7 +529,7 @@ if (onnxruntime_USE_CUDA)
target_link_libraries(${target} PRIVATE cuda)
endif()
- if (onnxruntime_USE_FLASH_ATTENTION)
+ if (onnxruntime_USE_FLASH_ATTENTION OR onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION)
include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index c8592a4019461..ecee52f642b1f 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -201,6 +201,10 @@ set(training_ops_excluded_files
"reduction/reduction_ops.cc" # no double type support
"cuda_training_kernels.cc"
"cuda_training_kernels.h"
+ "nn/conv_shared.cc"
+ "nn/conv_shared.h"
+ "nn/conv_transpose_grad.cc"
+ "nn/conv_transpose_grad.h"
)
function(auto_set_source_files_hip_language)
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index b374371446a90..86b44a6784817 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -743,7 +743,7 @@ internal static OrtValue CreateFromTensorObject(TensorBase value, out TensorElem
///
/// Creates an OrtValue that contains a string tensor of specified shape, and
/// containing empty strings. String tensors are always on CPU.
- /// Use FillStringTensorElement to assign individual elements values.
+ /// Use StringTensorSetElementAt to assign values to individual elements.
///
///
/// disposable OrtValue
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
index c52ca4d1a4631..ac790242409e3 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs
@@ -15,6 +15,7 @@ public struct OrtTrainingApi
public IntPtr LoadCheckpoint;
public IntPtr SaveCheckpoint;
public IntPtr CreateTrainingSession;
+ public IntPtr CreateTrainingSessionFromBuffer;
public IntPtr TrainingSessionGetTrainingModelOutputCount;
public IntPtr TrainingSessionGetEvalModelOutputCount;
public IntPtr TrainingSessionGetTrainingModelOutputName;
diff --git a/docs/python/README.rst b/docs/python/README.rst
index 7d978b0941235..bcf7c635afd82 100644
--- a/docs/python/README.rst
+++ b/docs/python/README.rst
@@ -8,6 +8,16 @@ For more information on ONNX Runtime, please see `aka.ms/onnxruntime
...
diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h
--- a/include/onnxruntime/core/framework/float8.h
+++ b/include/onnxruntime/core/framework/float8.h
...
  val = static_cast<uint8_t>((b & 0x80000000) >> 24);  // sign
- if ((b & 0x7fc00000) == 0x7fc00000) {
- val |= 0x7f;
- } else if ((b & 0x7fffffff) == 0x7f800000) {
+ if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 126;
} else {
val |= 0x7f;
}
+ } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
+ val |= 0x7f;
} else {
  uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23);  // exponent
  uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF);  // mantissa
if (e != 0) {
- if (e < 117) { // 0b1110101
- } else if (e < 118) { // 0b1110110
- val |= 1;
- if ((m >> 23) & 1) {
- // rounding
- val += 1;
+ if (e < 117) {
+ } else if (e < 121) {
+ // denormalized number
+ auto d = 120 - e;
+ if (d < 3) {
+ val |= 1 << (2 - d);
+ val |= m >> (21 + d);
+ } else if (m > 0) {
+ val |= 1;
}
- } else if (e < 121) { // 127 - 7 + 1 // 0b1111001
- auto d = 120 - e; // 0b1111000
- val |= 1 << (2 - d);
- val |= m >> (21 + d);
- if ((m >> (20 + d)) & 1) {
+ auto mask = 1 << (20 + d);
+ if ((m & mask) &&
+ ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
- } else if (e < 136) { // 127 + 8 + 1 // 0b10001000
- auto ex = e - 120; // 127 - 7
+ } else if (e < 136) {
+ // normalized number
+ auto ex = e - 120;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
@@ -83,7 +85,7 @@ struct Float8E4M3FN {
val &= 0xFE;
}
}
- if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7C000))) {
+ if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7E) {
// rounding
val += 1;
@@ -147,14 +149,22 @@ struct Float8E4M3FN {
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080
- explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast<const unsigned char*>(&value); }
+ explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) {
+ val = *reinterpret_cast<const unsigned char*>(&value);
+ }
 explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast<const __nv_fp8_e4m3*>(&val); }
#endif
};
-inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; }
-inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; }
+inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) {
+ return left.val == right.val;
+}
+inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) {
+ return left.val != right.val;
+}
+inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) {
+ return left.val < right.val;
+}
// User defined suffixes to make it easier to declare
// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char
@@ -164,9 +174,7 @@ inline Float8E4M3FN operator"" _f8e4m3fn(unsigned long long int v) {
 return Float8E4M3FN(narrow<uint8_t>(v), Float8E4M3FN::FromBits());
}
-inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) {
- return Float8E4M3FN(static_cast<float>(v), true);
-}
+inline Float8E4M3FN operator"" _f8e4m3fnp8(long double v) { return Float8E4M3FN(static_cast<float>(v), true); }
#endif
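The reordered checks above matter because the old test `(b & 0x7fc00000) == 0x7fc00000` only recognizes quiet NaNs; a NaN encoding with only low mantissa bits set (for example 0x7F800001) fell through to the numeric path. A small self-contained C++ sketch of the new classification order, not the library API:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Mirrors the order used above: the exact infinity pattern first,
    // then any remaining all-ones exponent is a NaN.
    const char* ClassifyForFp8(float f) {
      uint32_t b;
      std::memcpy(&b, &f, sizeof(b));
      if ((b & 0x7fffffff) == 0x7f800000) return "infinity";  // exponent all ones, mantissa zero
      if ((b & 0x7F800000) == 0x7F800000) return "NaN";       // exponent all ones, mantissa nonzero
      return "finite";
    }

    int main() {
      uint32_t bits = 0x7F800001;  // NaN that the old quiet-NaN-only mask 0x7fc00000 misses
      float f;
      std::memcpy(&f, &bits, sizeof(f));
      std::printf("%s\n", ClassifyForFp8(f));  // prints "NaN"
      return 0;
    }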
@@ -205,36 +213,38 @@ struct Float8E4M3FNUZ {
std::memcpy(&b, &v, sizeof(b));
 val = static_cast<uint8_t>((b & 0x80000000) >> 24);  // sign
- if ((b & 0x7fc00000) == 0x7fc00000) {
- val = 0x80;
- } else if ((b & 0x7fffffff) == 0x7f800000) {
+ if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
val |= 0x7F;
} else {
// infinity
val = 0x80;
}
+ } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
+ val = 0x80;
} else {
  uint8_t e = static_cast<uint8_t>((b & 0x7F800000) >> 23);  // exponent
  uint32_t m = static_cast<uint32_t>(b & 0x007FFFFF);  // mantissa
if (e != 0) {
if (e < 116) {
- } else if (e < 117) {
- val |= 1;
- if ((m >> 23) & 1) {
- // rounding
- val += 1;
- }
- } else if (e < 120) { // 127 - 8 + 1
+ } else if (e < 120) {
+ // denormalized number
auto d = 119 - e;
- val |= 1 << (2 - d);
- val |= m >> (21 + d);
- if ((m >> (20 + d)) & 1) {
+ if (d < 3) {
+ val |= 1 << (2 - d);
+ val |= m >> (21 + d);
+ } else if (m > 0) {
+ val |= 1;
+ }
+ auto mask = 1 << (20 + d);
+ if ((m & mask) &&
+ ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
- } else if (e < 135) { // 127 + 8
- auto ex = e - 119; // 127 - 7
+ } else if (e < 135) {
+ // normalized number
+ auto ex = e - 119;
if (ex == 0) {
val |= 0x4;
val |= m >> 21;
@@ -242,7 +252,7 @@ struct Float8E4M3FNUZ {
val |= ex << 3;
val |= m >> 20;
}
- if (m & 0x80000) {
+ if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
@@ -303,9 +313,15 @@ struct Float8E4M3FNUZ {
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
};
-inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; }
-inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; }
+inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
+ return left.val == right.val;
+}
+inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
+ return left.val != right.val;
+}
+inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) {
+ return left.val < right.val;
+}
// User defined suffixes to make it easier to declare
// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char
@@ -315,9 +331,7 @@ inline Float8E4M3FNUZ operator"" _f8e4m3p8fnuz(unsigned long long int v) {
 return Float8E4M3FNUZ(narrow<uint8_t>(v), Float8E4M3FNUZ::FromBits());
}
-inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) {
- return Float8E4M3FNUZ(static_cast<float>(v), true);
-}
+inline Float8E4M3FNUZ operator"" _f8e4m3fnuzp8(long double v) { return Float8E4M3FNUZ(static_cast<float>(v), true); }
#endif
@@ -357,32 +371,33 @@ struct Float8E5M2 {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));
- val = (b & 0x80000000) >> 24; // sign
- if ((b & 0x7fc00000) == 0x7fc00000) {
- val |= 0x7f;
- } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
+ val = (b & 0x80000000) >> 24; // sign
+ if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7B;
} else {
val |= 0x7C;
}
+ } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
+ val |= 0x7f;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa
if (e != 0) {
if (e < 110) {
- } else if (e < 111) {
- val |= 1;
- if ((m >> 23) & 1) {
- // rounding
- val += 1;
- }
- } else if (e < 113) { // 127 - 15 + 1
+ } else if (e < 113) {
+ // denormalized number
auto d = 112 - e;
- val |= 1 << (1 - d);
- val |= m >> (22 + d);
- if ((m >> (21 + d)) & 1) {
+ if (d < 2) {
+ val |= 1 << (1 - d);
+ val |= m >> (22 + d);
+ } else if (m > 0) {
+ val |= 1;
+ }
+ auto mask = 1 << (21 + d);
+ if ((m & mask) &&
+ ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
@@ -461,8 +476,12 @@ struct Float8E5M2 {
#endif
};
-inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; }
+inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) {
+ return left.val == right.val;
+}
+inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) {
+ return left.val != right.val;
+}
inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; }
// User defined suffixes to make it easier to declare
@@ -473,9 +492,7 @@ inline Float8E5M2 operator"" _f8e5m2fn(unsigned long long int v) {
 return Float8E5M2(narrow<uint8_t>(v), Float8E5M2::FromBits());
}
-inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) {
- return Float8E5M2(static_cast<float>(v), true);
-}
+inline Float8E5M2 operator"" _f8e5m2fnp8(long double v) { return Float8E5M2(static_cast<float>(v), true); }
#endif
@@ -513,40 +530,42 @@ struct Float8E5M2FNUZ {
uint32_t b;
std::memcpy(&b, &v, sizeof(b));
- val = (b & 0x80000000) >> 24; // sign
- if ((b & 0x7fc00000) == 0x7fc00000) {
- val = 0x80;
- } else if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
+ val = (b & 0x80000000) >> 24; // sign
+ if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
val |= 0x7F;
} else {
val = 0x80;
}
+ } else if ((b & 0x7F800000) == 0x7F800000) { // NaN
+ val = 0x80;
} else {
uint32_t e = (b & 0x7F800000) >> 23; // exponent
uint32_t m = b & 0x007FFFFF; // mantissa
if (e != 0) {
if (e < 109) {
- } else if (e < 110) {
- val |= 1;
- if ((m >> 23) & 1) {
- // rounding
- val += 1;
- }
- } else if (e < 112) { // 127 - 16 + 1
+ } else if (e < 112) {
+ // denormalized number
auto d = 111 - e;
- val |= 1 << (1 - d);
- val |= m >> (22 + d);
- if ((m >> (21 + d)) & 1) {
+ if (d < 2) {
+ val |= 1 << (1 - d);
+ val |= m >> (22 + d);
+ } else if (m > 0) {
+ val |= 1;
+ }
+ auto mask = 1 << (21 + d);
+ if ((m & mask) &&
+ ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) {
// rounding
val += 1;
}
- } else if (e < 143) { // 127 + 15 + 1
+ } else if (e < 143) {
+ // normalized number
auto ex = e - 111;
val |= ex << 2;
val |= m >> 21;
- if (m & 0x100000) {
+ if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) {
if ((val & 0x7F) < 0x7F) {
// rounding
val += 1;
@@ -554,7 +573,7 @@ struct Float8E5M2FNUZ {
val = 0x80;
}
}
- } else if ((e == 255) && (m == 0)) { // inf
+ } else if ((e == 255) && (m == 0)) {
val = 0x80;
} else if (saturate) {
val |= 0x7F;
@@ -605,9 +624,15 @@ struct Float8E5M2FNUZ {
inline ORT_HOST_DEVICE operator float() const { return ToFloat(); }
};
-inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; }
-inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; }
-inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; }
+inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
+ return left.val == right.val;
+}
+inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
+ return left.val != right.val;
+}
+inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) {
+ return left.val < right.val;
+}
// User defined suffixes to make it easier to declare
// initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char
@@ -617,9 +642,7 @@ inline Float8E5M2FNUZ operator"" _f8e5m2fnuz(unsigned long long int v) {
 return Float8E5M2FNUZ(narrow<uint8_t>(v), Float8E5M2FNUZ::FromBits());
}
-inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) {
- return Float8E5M2FNUZ(static_cast<float>(v), true);
-}
+inline Float8E5M2FNUZ operator"" _f8e5m2fnuzp8(long double v) { return Float8E5M2FNUZ(static_cast<float>(v), true); }
#endif
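The widened rounding test in the normalized E4M3 path above, `(m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))`, is round-half-to-even over the three mantissa bits that survive: round up only when the first dropped (guard) bit is set and either the kept value would otherwise stay odd or some lower (sticky) bit is set. A self-contained C++ sketch of just that rule, assuming the exponent carry is handled elsewhere as in the code above:

    #include <cstdint>
    #include <cstdio>

    // Round the low 23 bits of a float32 mantissa to the 3 bits kept by E4M3,
    // using the same guard/odd/sticky test as the normalized path above.
    uint32_t RoundMantissaTo3Bits(uint32_t m) {
      uint32_t kept = m >> 20;                               // top 3 mantissa bits survive
      bool guard = (m & 0x80000) != 0;                       // first dropped bit
      bool odd_or_sticky = (m & 0x100000) || (m & 0x7FFFF);  // kept LSB odd, or anything below the guard
      if (guard && odd_or_sticky) kept += 1;                 // may carry into the exponent in the real code
      return kept;
    }

    int main() {
      std::printf("%u\n", RoundMantissaTo3Bits(0x280000));  // exact tie, kept value even -> stays 2
      std::printf("%u\n", RoundMantissaTo3Bits(0x380000));  // exact tie, kept value odd  -> rounds to 4
      std::printf("%u\n", RoundMantissaTo3Bits(0x280001));  // sticky bit set             -> rounds to 3
      return 0;
    }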
diff --git a/include/onnxruntime/core/framework/ort_value.h b/include/onnxruntime/core/framework/ort_value.h
index 48c4e4320dfd7..a071f3182faad 100644
--- a/include/onnxruntime/core/framework/ort_value.h
+++ b/include/onnxruntime/core/framework/ort_value.h
@@ -68,11 +68,7 @@ struct OrtValue {
}
bool IsSparseTensor() const {
-#if !defined(DISABLE_SPARSE_TENSORS)
return (type_ != nullptr && type_->IsSparseTensorType());
-#else
- ORT_THROW("Sparse tensor is not supported in this build.");
-#endif
}
onnxruntime::MLDataType Type() const {
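With the preprocessor guard removed above, IsSparseTensor is a plain query in every build configuration instead of throwing when DISABLE_SPARSE_TENSORS is defined, so callers can branch on it unconditionally. A minimal sketch (the surrounding function is hypothetical):

    // In a minimal build without sparse tensor support this simply takes the else branch.
    void DescribeValue(const OrtValue& value) {
      if (value.IsSparseTensor()) {
        // handle a sparse tensor input
      } else if (value.IsTensor()) {
        // handle a dense tensor input
      }
    }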
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index bc7792ba4366b..456a11603de65 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4333,8 +4333,12 @@ struct OrtApi {
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
- * \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
- * The array will be passed back to run_async_callback
+ * \param[out] output OrtValue* array of size output_names_len.
+ * On calling RunAsync, output[i] may be either null or a pointer to a preallocated OrtValue.
+ * Later, the output array will be passed to run_async_callback with every null entry filled with a valid
+ * OrtValue pointer allocated by onnxruntime.
+ * NOTE: it is the caller's responsibility to eventually release the output array and each of its members,
+ * regardless of whether a member (OrtValue*) was allocated by onnxruntime or preallocated by the caller.
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
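The ownership contract spelled out above can be exercised with an all-null output array that the callback releases. A hedged C++ sketch of such a call site; the callback signature is assumed to match the RunAsync callback type declared elsewhere in this header, and the model input/output names are hypothetical:

    #include <onnxruntime_c_api.h>

    // Sketch only: error handling omitted.
    static void OnRunDone(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status) {
      const OrtApi* api = (const OrtApi*)user_data;
      for (size_t i = 0; i < num_outputs; ++i) {
        api->ReleaseValue(outputs[i]);  // the caller releases every member, allocated or preallocated
      }
      api->ReleaseStatus(status);  // assumption: the callback owns the status it receives
    }

    // The output array must stay alive until the callback has run; a static keeps the sketch simple.
    static OrtValue* g_outputs[1] = {nullptr};  // null entries are filled by onnxruntime

    void RunOneAsync(const OrtApi* api, OrtSession* session, const OrtValue* input) {
      const char* input_names[] = {"input"};    // hypothetical model input/output names
      const char* output_names[] = {"output"};
      const OrtValue* inputs[] = {input};
      api->RunAsync(session, nullptr, input_names, inputs, 1,
                    output_names, 1, g_outputs, OnRunDone, (void*)api);
    }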
diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
index b9b6676c0072d..47356c3fe3608 100644
--- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h
@@ -1073,11 +1073,15 @@ struct SessionImpl : ConstSessionImpl<T> {
*
* \param[in] run_options
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
- * \param[in] input_values Array of ::OrtValue%s of the input values
+ * \param[in] input_values Array of Value objects of length input_count
* \param[in] input_count Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
- * \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
- * The array will be passed back to the callback
+ * \param[out] output_values Array of provided Values to be filled with outputs.
+ * On calling RunAsync, each output_values[i] may either be initialized to null or hold a preallocated OrtValue*.
+ * Later, when the callback is invoked, each null output_values[i] will have been filled with an OrtValue* allocated by onnxruntime.
+ * An OrtValue** pointer is then cast from output_values and passed to the callback.
+ * NOTE: it is the caller's responsibility to eventually release output_values and each of its members,
+ * regardless of whether a member (Ort::Value) was allocated by onnxruntime or preallocated by the caller.
* \param[in] output_count Number of elements in the output_names and outputs array
* \param[in] callback Callback function on model run completion
* \param[in] user_data User data that pass back to the callback
diff --git a/js/common/lib/version.ts b/js/common/lib/version.ts
index 8f597765ebe8a..3e303bcf64b8e 100644
--- a/js/common/lib/version.ts
+++ b/js/common/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.16.0';
+export const version = '1.16.2';
diff --git a/js/common/package-lock.json b/js/common/package-lock.json
index b9e5fd6082457..69cb6b60aaf35 100644
--- a/js/common/package-lock.json
+++ b/js/common/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.16.2",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.16.2",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/common/package.json b/js/common/package.json
index 331f17dbc44be..06616c3247c07 100644
--- a/js/common/package.json
+++ b/js/common/package.json
@@ -2,7 +2,7 @@
"license": "MIT",
"type": "module",
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.16.2",
"repository": {
"url": "https://github.com/Microsoft/onnxruntime.git",
"type": "git"
diff --git a/js/node/lib/version.ts b/js/node/lib/version.ts
index 8f597765ebe8a..3e303bcf64b8e 100644
--- a/js/node/lib/version.ts
+++ b/js/node/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.16.0';
+export const version = '1.16.2';
diff --git a/js/node/package-lock.json b/js/node/package-lock.json
index bd01302262273..6994f70a45233 100644
--- a/js/node/package-lock.json
+++ b/js/node/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-node",
- "version": "1.16.0",
+ "version": "1.16.2",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-node",
- "version": "1.16.0",
+ "version": "1.16.2",
"license": "MIT",
"os": [
"win32",
@@ -27,7 +27,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.16.2",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/node/package.json b/js/node/package.json
index c898aeb56c0f5..faa07d1149fab 100644
--- a/js/node/package.json
+++ b/js/node/package.json
@@ -13,7 +13,7 @@
3
]
},
- "version": "1.16.0",
+ "version": "1.16.2",
"dependencies": {
"onnxruntime-common": "file:../common"
},
diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc
index 78f32ec09250b..f8aeadbe27c56 100644
--- a/js/node/src/inference_session_wrap.cc
+++ b/js/node/src/inference_session_wrap.cc
@@ -68,7 +68,7 @@ Napi::Value InferenceSessionWrap::LoadModel(const Napi::CallbackInfo &info) {
 int64_t bytesOffset = info[1].As<Napi::Number>().Int64Value();
 int64_t bytesLength = info[2].As<Napi::Number>().Int64Value();
- ParseSessionOptions(info[1].As<Napi::Object>(), sessionOptions);
+ ParseSessionOptions(info[3].As<Napi::Object>(), sessionOptions);
 this->session_.reset(
 new Ort::Session(OrtEnv(), reinterpret_cast<char*>(buffer) + bytesOffset, bytesLength, sessionOptions));
} else {
diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts
index b3f0c466308a5..058531f415d61 100644
--- a/js/react_native/lib/backend.ts
+++ b/js/react_native/lib/backend.ts
@@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler {
let results: Binding.ModelLoadInfoType;
// load a model
if (typeof this.#pathOrBuffer === 'string') {
+ // load model from model path
results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options);
} else {
+ // load model from buffer
if (!this.#inferenceSession.loadModelFromBlob) {
throw new Error('Native module method "loadModelFromBlob" is not defined');
}
- const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer);
+ const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer);
results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options);
}
// resolve promise if onnxruntime session is successfully created
diff --git a/js/react_native/lib/version.ts b/js/react_native/lib/version.ts
index 8f597765ebe8a..3e303bcf64b8e 100644
--- a/js/react_native/lib/version.ts
+++ b/js/react_native/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.16.0';
+export const version = '1.16.2';
diff --git a/js/react_native/package.json b/js/react_native/package.json
index 3020a04f0af31..2c19037257051 100644
--- a/js/react_native/package.json
+++ b/js/react_native/package.json
@@ -36,7 +36,7 @@
"registry": "https://registry.npmjs.org/"
},
"source": "lib/index",
- "version": "1.16.0",
+ "version": "1.16.2",
"main": "dist/commonjs/index",
"homepage": "https://github.com/microsoft/onnxruntime/blob/main/js/react_native/README.md",
"files": [
diff --git a/js/react_native/yarn.lock b/js/react_native/yarn.lock
index 21734bc50b000..ff2cfd2c8f98f 100644
--- a/js/react_native/yarn.lock
+++ b/js/react_native/yarn.lock
@@ -5188,7 +5188,7 @@ onetime@^5.1.0, onetime@^5.1.2:
mimic-fn "^2.1.0"
"onnxruntime-common@file:../common":
- version "1.16.0"
+ version "1.16.2"
open@^6.2.0:
version "6.4.0"
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 4a1109b9ec5dc..e33854819c5db 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -38,7 +38,7 @@ Do not modify directly.*
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| Gelu | com.microsoft(1+) | |
-| Gemm | ai.onnx(7-8,9-10,11+) | |
+| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
diff --git a/js/web/lib/version.ts b/js/web/lib/version.ts
index 8f597765ebe8a..3e303bcf64b8e 100644
--- a/js/web/lib/version.ts
+++ b/js/web/lib/version.ts
@@ -4,4 +4,4 @@
// This file is generated by /js/scripts/update-version.ts
// Do not modify file content manually.
-export const version = '1.16.0';
+export const version = '1.16.2';
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
index 1d490aa9028ff..82fe3d5b6af43 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts
@@ -26,43 +26,41 @@ import {ConvTransposeAttributes} from '../conv-transpose';
const createConvTranspose2DOpProgramShaderSource =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes,
- outputShape: readonly number[], hasBias: boolean, elementsPerThread: readonly number[]): string => {
+ outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => {
const isChannelsLast = attributes.format === 'NHWC';
const rowDim = isChannelsLast ? 1 : 2;
const colDim = isChannelsLast ? 2 : 3;
const channelDim = isChannelsLast ? 3 : 1;
const outputSize = ShapeUtil.size(outputShape);
- const outChannels = outputShape[isChannelsLast ? 3 : 1];
- const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
- const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0;
const workPerThread = isVec4 ? 2 : 1;
+ const group = attributes.group;
+ const wShape = inputs[1].dims;
+ const inputChannelsPerGroup = wShape[0] / group;
+ const outputChannelsPerGroup = wShape[1];
- const innerElementSize = isVec4 ? (isChannelsLast && inChannels % 4 !== 0 ? 3 : 4) : elementsPerThread[0];
-
- const declareInputs = [
- `@group(0) @binding(0) var<storage, read> Dy: array<${
- isVec4 && innerElementSize === 4 ? 'vec4<f32>' : 'f32'}>;`,
- `@group(0) @binding(1) var<storage, read> W: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`
- ];
let declareFunctions = `
 fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4<f32>' : 'f32'}) {
 result[flatIndex] = ${isVec4 ? 'vec4<f32>' : 'f32'}(value);
}`;
if (hasBias) {
- declareInputs.push(`@group(0) @binding(2) var<storage, read> bias: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;`);
declareFunctions += `
 fn getBiasByOutputCoords(coords : vec4<u32>) -> ${isVec4 ? 'vec4<f32>' : 'f32'} {
return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}];
}`;
}
- const w = inputVariable('W', inputs[1].dataType, inputs[1].dims);
- const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims);
- const output = outputVariable('result', inputs[0].dataType, outputShape);
+ const components = isVec4 ? 4 : 1;
+ const w = inputVariable('W', inputs[1].dataType, inputs[1].dims, components);
+ const dy = inputVariable('Dy', inputs[0].dataType, inputs[0].dims, components);
+ const inputVariables = [dy, w];
+ if (hasBias) {
+ inputVariables.push(inputVariable('bias', inputs[2].dataType, [outputShape[channelDim]], components));
+ }
+ const output = outputVariable('result', inputs[0].dataType, outputShape, components);
const codeSnippet4 = `{
- let batch: u32 = global_id.z / outShape[1];
- let r = global_id.z % outShape[1];
- let c = global_id.y * ${workPerThread};
- let d1: u32 = global_id.x * 4;
+ let batch: u32 = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} / outShape[1];
+ let r = ${is1DimensionDispatch ? 'global_id.z' : 'workgroup_id.z'} % outShape[1];
+ let c = ${is1DimensionDispatch ? 'global_id.y' : 'workgroup_id.y'} * ${workPerThread};
+ let d1: u32 = ${is1DimensionDispatch ? 'global_id.x' : 'workgroup_id.x'} * 4;
 let dyCorner = vec2<i32>(i32(r), i32(c)) - vec2<i32>(pads);
@@ -73,18 +71,21 @@ const createConvTranspose2DOpProgramShaderSource =
 dotProd[i] = vec4<f32>(0.0);
}
for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) {
- var dyR = f32(dyCorner.x + wR) / f32(strides.x);
- let wRPerm: u32= filterDims[0] - 1 - wR;
+ var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x);
+ let wRPerm = filterDims[0] - 1 - wR;
if (dyR < 0.0 || dyR >= f32(outBackprop[1]) ||
- fract(dyR) > 0.0) {
+ fract(dyR) > 0.0 || wRPerm < 0) {
continue;
}
let idyR: u32 = u32(dyR);
for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) {
- let dyC = f32(dyCorner.y + wC) / f32(strides.y);
- let dyC2 = f32(dyCorner.y + 1 + wC) / f32(strides.y);
- let wCPerm: u32 = filterDims[1] - 1 - wC;
+ let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y);
+ let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y);
+ let wCPerm = filterDims[1] - 1 - wC;
+ if (wCPerm < 0) {
+ continue;
+ }
var bDyCVal = true;
var bDyCVal2 = true;
if (dyC < 0.0 || dyC >= f32(outBackprop[2]) ||
@@ -101,57 +102,53 @@ const createConvTranspose2DOpProgramShaderSource =
if (bDyCVal && bDyCVal2) {
let d2Length = outBackprop[3];
for (var d2 :u32 = 0; d2 < d2Length; d2 = d2 + 4) {
- let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')};
- let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')};
- let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')};
- let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')};
+ let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
+ let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
+ let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
+ let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
- var xValue = ${
- isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')};
- let tmpval = vec4<f32>(xValue * wValue0,
- xValue * wValue1,
- xValue * wValue2,
- xValue * wValue3);
+ var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
+ let tmpval = vec4<f32>(dot(xValue, wValue0),
+ dot(xValue, wValue1),
+ dot(xValue, wValue2),
+ dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
- xValue = ${
- isChannelsLast ? dy.get('batch', 'idyR', 'idyC2', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC2')};
+ xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
- dotProd[1] = dotProd[1] + vec4<f32>(xValue * wValue0,
- xValue * wValue1,
- xValue * wValue2,
- xValue * wValue3);
+ dotProd[1] = dotProd[1] + vec4<f32>(dot(xValue, wValue0),
+ dot(xValue, wValue1),
+ dot(xValue, wValue2),
+ dot(xValue, wValue3));
}
} else if (bDyCVal) {
- let d2Length = outBackprop[3];
+ let d2Length = outBackprop[${channelDim}];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
- let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')};
- let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')};
- let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')};
- let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')};
+ let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
+ let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
+ let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
+ let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
- var xValue = ${
- isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')};
- let tmpval = vec4<f32>(xValue * wValue0,
- xValue * wValue1,
- xValue * wValue2,
- xValue * wValue3);
+ var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')};
+ let tmpval = vec4<f32>(dot(xValue, wValue0),
+ dot(xValue, wValue1),
+ dot(xValue, wValue2),
+ dot(xValue, wValue3));
dotProd[0] = dotProd[0] + tmpval;
}
} else if (bDyCVal2) {
let d2Length = outBackprop[3];
for (var d2: u32 = 0; d2 < d2Length; d2 = d2 + 4) {
- let wValue0 = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')};
- let wValue1 = ${w.get('d2', 'd1 + 1', 'wRPerm', 'wCPerm')};
- let wValue2 = ${w.get('d2', 'd1 + 2', 'wRPerm', 'wCPerm')};
- let wValue3 = ${w.get('d2', 'd1 + 3', 'wRPerm', 'wCPerm')};
+ let wValue0 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1', 'd2')};
+ let wValue1 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 1', 'd2')};
+ let wValue2 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 2', 'd2')};
+ let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')};
- var xValue = ${
- isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')};
- let tmpval = vec4<f32>(xValue * wValue0,
- xValue * wValue1,
- xValue * wValue2,
- xValue * wValue3);
+ var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')};
+ let tmpval = vec4<f32>(dot(xValue, wValue0),
+ dot(xValue, wValue1),
+ dot(xValue, wValue2),
+ dot(xValue, wValue3));
dotProd[1] = dotProd[1] + tmpval;
}
}
@@ -159,16 +156,21 @@ const createConvTranspose2DOpProgramShaderSource =
}
for (var i: u32 = 0; i < ${workPerThread}; i = i + 1) {
- ${output.set('batch', 'r', 'c+i', 'd1', 'dotProd[i]')};
+ let value = dotProd[i] + ${hasBias ? 'bias[c+i]' : '0.0'};
+ ${output.set('batch', 'r', 'c + i', 'd1', 'value')};
}
}`;
const codeSnippet = `
let outputIndices = ${output.offsetToIndices('global_idx')};
- let batch = outputIndices[0];
- let d1 = outputIndices[${channelDim}];
- let dyCorner = vec2<i32>(i32(outputIndices[${rowDim}]), i32(outputIndices[${colDim}])) - pads;
+ let batch = ${output.indicesGet('outputIndices', 0)};
+ let d1 = ${output.indicesGet('outputIndices', channelDim)};
+ let r = ${output.indicesGet('outputIndices', rowDim)};
+ let c = ${output.indicesGet('outputIndices', colDim)};
+ let dyCorner = vec2<i32>(i32(r), i32(c)) - pads;
let dyRCorner = dyCorner.x;
let dyCCorner = dyCorner.y;
+ let groupId = d1 / ${outputChannelsPerGroup};
+ let wOutChannel = d1 - groupId * ${outputChannelsPerGroup};
// Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).
// ? = to be determined. : = across all values in that axis.
var dotProd = 0.0;
@@ -178,7 +180,7 @@ const createConvTranspose2DOpProgramShaderSource =
}
let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]);
let wRPerm = filterDims.x - 1 - wR / dilations.x;
- if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || fract(dyR) > 0.0 ||
+ if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 ||
wRPerm < 0) {
continue;
}
@@ -190,30 +192,29 @@ const createConvTranspose2DOpProgramShaderSource =
}
let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y);
let wCPerm = filterDims.y - 1 - wC / dilations.y;
- if (dyC < 0.0 || dyC >= f32(outBackprop[2]) ||
+ if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) ||
fract(dyC) > 0.0 || wCPerm < 0) {
continue;
}
let idyC: u32 = u32(dyC);
- for (var d2: u32 = 0; d2 < outBackprop[3]; d2 = d2 + 1) {
+ for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) {
+ let inputChannel = groupId * ${inputChannelsPerGroup} + d2;
let xValue = ${
- isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'd2') : dy.get('batch', 'd2', 'idyR', 'idyC')};
- let wValue = ${w.get('d2', 'd1', 'wRPerm', 'wCPerm')};
+ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') :
+ dy.get('batch', 'inputChannel', 'idyR', 'idyC')};
+ let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')};
dotProd = dotProd + xValue * wValue;
}
}
}
- ${output.setByOffset('global_idx', 'dotProd')};
+ let value = dotProd + ${hasBias ? 'bias[d1]' : '0.0'};
+ ${output.setByOffset('global_idx', 'value')};
`;
return `
- ${w.impl('indicesToOffset', 'get')}
- ${dy.impl('indicesToOffset', 'get')}
- ${output.impl('offsetToIndices')}
+ ${shaderHelper.declareVariables(...inputVariables, output)}
${declareFunctions}
- ${declareInputs.join('\n')}
- @group(0) @binding(${declareInputs.length}) var<storage, read_write> result: array<${isVec4 ? 'vec4<f32>' : 'f32'}>;
 const outShape : vec4<u32> = vec4<u32>(${outputShape.join(',')});
 const outBackprop : vec4<u32> = vec4<u32>(${inputs[0].dims.join(',')});
 const strides : vec2<u32> = vec2<u32>(${attributes.strides[0]}, ${attributes.strides[1]});
@@ -240,25 +241,18 @@ export const createConvTranspose2DProgramInfo =
(inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes,
squeezeOutputShapeFunction?: (shape: readonly number[]) => number[]): ProgramInfo => {
const hasBias = inputs.length > 2;
- const isChannelsLast = attributes.format === 'NHWC';
+ // const isChannelsLast = attributes.format === 'NHWC';
const outputShape = attributes.outputShape;
- const batchSize = outputShape[0];
- const outWidth = outputShape[isChannelsLast ? 1 : 2];
- const outHeight = outputShape[isChannelsLast ? 2 : 3];
- const outChannels = outputShape[isChannelsLast ? 3 : 1];
- const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
- const isVec4 = inChannels % 4 === 0 && outChannels % 4 === 0;
+ const outputSize = ShapeUtil.size(outputShape);
- const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight;
- const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels;
- const workGroupSize: [number, number, number] =
- isVec4 ? [8, 8, 1] : [dispatchX <= 4 ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1];
- const elementsPerThread =
- isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 2, dispatchX > 4 && dispatchY <= 4 ? 1 : 2, 1];
+ // const inChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
+ // TODO Enable isVec4 for performance
+ // Disabled due to weight matrix layout issue
+ // const isVec4 = attributes.group === 1 && isChannelsLast && inChannels % 4 === 0 && outChannels % 4 === 0;
const dispatch = [
- Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]),
- Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]),
- Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[1])
+ Math.ceil(outputSize / 64),
+ 1,
+ 1,
];
LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`);
@@ -271,6 +265,6 @@ export const createConvTranspose2DProgramInfo =
}],
dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}),
getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource(
- shaderHelper, inputs, attributes, outputShape, hasBias, elementsPerThread),
+ shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1),
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
index 5f3d1564664bf..02b978a381de5 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts
@@ -50,8 +50,6 @@ const createBinaryOpProgramShader =
};
broadcastImpl = `
- ${output.impl('offsetToIndices')}
-
fn calcOffsetA(outputIndices: ${output.type.indices}) -> u32 {
return ${calcOffsetImpl(dimsA)};
}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index e64c74972581d..7da57bcb9c647 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -16,28 +16,6 @@ import {ShapeUtil} from '../../util';
**/
export const WORKGROUP_SIZE = 64;
-interface IndicesHelperImplementations {
- /**
- * implementation of `offsetToIndices` function.
- */
- readonly offsetToIndices: string;
-
- /**
- * implementation of `indicesToOffset` function.
- */
- readonly indicesToOffset: string;
-
- /**
- * implementation of `set`, `setByIndices` and `setByOffset` function.
- */
- readonly set: string;
-
- /**
- * implementation of `get`, `getByIndices` and `getByOffset` function.
- */
- readonly get: string;
-}
-
interface IndicesHelperTypes {
/**
* WGSL type of indices expression
@@ -96,12 +74,10 @@ interface IndicesHelperTypes {
*/
export interface IndicesHelper {
/**
- * get WGSL code of function implementation for the util functions
+ * get WGSL code of function implementation for the util functions.
*
- * @param functions - a list of function names to get implementation for. If not specified, all functions will be
- * returned.
*/
- readonly impl: (...functions: ReadonlyArray<keyof IndicesHelperImplementations>) => string;
+ readonly impl: () => string;
/**
* get type info
@@ -215,9 +191,12 @@ export interface IndicesHelper {
readonly shape: readonly number[];
}
-const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, string] => {
+const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, string] => {
// return type is [ storage type, runtime type ] or a single string for both
switch (type) {
+ // TODO: enable after "shader-f16" WGSL extension release
+ // case DataType.float16:
+ // return components > 1 ? `vec${components}<f16>` : 'f16';
case DataType.float:
 return components > 1 ? `vec${components}<f32>` : 'f32';
case DataType.int32:
@@ -245,6 +224,11 @@ const getWgslValueType = (type: number, components: 1|2|3|4): string|[string, st
}
};
+export const tensorTypeToWsglStorageType = (type: DataType, components: 1|2|3|4 = 1) => {
+ const mappedType = getWgslMappedType(type, components);
+ return typeof mappedType === 'string' ? mappedType : mappedType[0];
+};
+
/**
* A helper function to get a IndicesHelper for a given input or output.
*
@@ -260,13 +244,22 @@ const createIndicesHelper =
components: 1|2|3|4): IndicesHelper => {
const rank = shape.length;
 const indicesType = rank < 2 ? 'u32' : rank <= 4 ? `vec${rank}<u32>` : `array<u32, ${rank}>`;
- const mappedType = getWgslValueType(tensorType, components);
+ const mappedType = getWgslMappedType(tensorType, components);
const valueType = typeof mappedType === 'string' ? mappedType : mappedType[1];
const storageType = typeof mappedType === 'string' ? mappedType : mappedType[0];
const type = {indices: indicesType, value: valueType, storage: storageType, tensor: tensorType};
const normalizeDim = (dim: number|string): string => typeof dim === 'string' ? dim : `${dim}u`;
+ const implementationUsed = {
+ offsetToIndices: false,
+ indicesToOffset: false,
+ set: false,
+ setByIndices: false,
+ get: false,
+ getByIndices: false,
+ };
+
const strides = ShapeUtil.computeStrides(shape);
let o2iSnippet = '';
for (let i = 0; i < rank - 1; i++) {
@@ -287,7 +280,10 @@ const createIndicesHelper =
return indices;
}`;
- const offsetToIndices = (varOffset: string) => rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
+ const offsetToIndices = (varOffset: string) => {
+ implementationUsed.offsetToIndices = true;
+ return rank < 2 ? varOffset : `o2i_${name}(${varOffset})`;
+ };
const offsets: string[] = [];
if (rank >= 2) {
@@ -301,7 +297,10 @@ const createIndicesHelper =
return ${offsets.join('+')};
}`;
- const indicesToOffset = (varIndices: string) => rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
+ const indicesToOffset = (varIndices: string) => {
+ implementationUsed.indicesToOffset = true;
+ return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`;
+ };
 const indices = (...init: ReadonlyArray<number|string>) =>
rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`;
@@ -357,17 +356,18 @@ const createIndicesHelper =
}
})();
+ const getByIndicesImplementation = rank < 2 ? '' : `
+ fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
+ return ${name}[i2o_${name}(indices)];
+ }`;
+
const getImplementation = rank < 2 ? '' : (() => {
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
const dims = shape.map((_, i) => `d${i}`).join(', ');
return `
- fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} {
- return ${name}[i2o_${name}(indices)];
- }
fn get_${name}(${params}) -> ${valueType} {
return get_${name}ByIndices(${indices(dims)});
- }
- `;
+ }`;
})();
 const get = (...indices: ReadonlyArray<string|number>) => {
@@ -376,14 +376,16 @@ const createIndicesHelper =
}
const normalizedIndices = indices.map(normalizeDim).join(',');
- const funcName = `get_${name}`;
if (rank === 0) {
return getByOffset('0u');
} else if (rank === 1) {
return getByOffset(normalizedIndices[0]);
} else {
- return `${funcName}(${normalizedIndices})`;
+ implementationUsed.get = true;
+ implementationUsed.getByIndices = true;
+ implementationUsed.indicesToOffset = true;
+ return `get_${name}(${normalizedIndices})`;
}
};
@@ -391,21 +393,24 @@ const createIndicesHelper =
if (rank < 2) {
return getByOffset(varIndices);
} else {
+ implementationUsed.getByIndices = true;
+ implementationUsed.indicesToOffset = true;
return `get_${name}ByIndices(${varIndices})`;
}
};
+ const setByIndicesImplementation = rank < 2 ? '' : `
+ fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
+ ${setByOffset(`i2o_${name}(indices)`, 'value')}
+ }`;
+
const setImplementation = rank < 2 ? '' : (() => {
const params = shape.map((_, i) => `d${i}: u32`).join(', ');
const dims = shape.map((_, i) => `d${i}`).join(', ');
return `
- fn set_${name}ByIndices(indices: ${type.indices}, value: ${valueType}) {
- ${setByOffset(`i2o_${name}(indices)`, 'value')}
- }
fn set_${name}(${params}, value: ${valueType}) {
set_${name}ByIndices(${indices(dims)}, value);
- }
- `;
+ }`;
})();
 const set = (...indicesAndValue: ReadonlyArray<string|number>) => {
@@ -424,6 +429,9 @@ const createIndicesHelper =
} else if (rank === 1) {
return setByOffset(normalizedIndices[0], value);
} else {
+ implementationUsed.set = true;
+ implementationUsed.setByIndices = true;
+ implementationUsed.indicesToOffset = true;
return `set_${name}(${normalizedIndices}, ${value})`;
}
};
@@ -432,32 +440,34 @@ const createIndicesHelper =
if (rank < 2) {
return setByOffset(varIndices, value);
} else {
+ implementationUsed.setByIndices = true;
+ implementationUsed.indicesToOffset = true;
return `set_${name}ByIndices(${varIndices}, ${value});`;
}
};
- const funcImpls = {
- offsetToIndices: offsetToIndicesImplementation,
- indicesToOffset: indicesToOffsetImplementation,
- set: setImplementation,
- get: getImplementation,
- };
- const impl = (...functions: Array<keyof IndicesHelperImplementations>) => {
+ const impl = () => {
const impls = [];
- if (functions.length === 0) {
- functions.push('offsetToIndices', 'indicesToOffset', 'set', 'get');
+ if (implementationUsed.offsetToIndices) {
+ impls.push(offsetToIndicesImplementation);
+ }
+ if (implementationUsed.indicesToOffset) {
+ impls.push(indicesToOffsetImplementation);
+ }
+ if (implementationUsed.set) {
+ impls.push(setImplementation);
+ }
+ if (implementationUsed.setByIndices) {
+ impls.push(setByIndicesImplementation);
+ }
+ if (implementationUsed.get) {
+ impls.push(getImplementation);
}
- for (const func of functions) {
- const impl = funcImpls[func];
- if (impl === undefined) {
- throw new Error(`unknown function ${func}`);
- } else {
- impls.push(impl);
- }
+ if (implementationUsed.getByIndices) {
+ impls.push(getByIndicesImplementation);
}
return impls.join('\n');
};
- impl.toString = () => impl();
return {
impl,
@@ -552,6 +562,11 @@ export interface ShaderHelper {
* @param variables - an array of IndicesHelper for the variables.
*/
declareVariables(...variables: IndicesHelper[]): string;
+
+ /**
+ * Get the additional implementations that need to be added to the shader source.
+ */
+ readonly additionalImplementations: string;
}
class ShaderHelperImpl implements ShaderHelper {
@@ -585,6 +600,7 @@ class ShaderHelperImpl implements ShaderHelper {
}
declareVariable(variable: IndicesHelper, bindingIndex: number): string {
+ this.indicesHelpers.push(variable);
const access = variable.usage === 'input' ? 'read' : 'read_write';
const storageType = variable.type.storage;
return `@group(0) @binding(${bindingIndex}) var<storage, ${access}> ${variable.name}: array<${storageType}>;`;
@@ -594,6 +610,12 @@ class ShaderHelperImpl implements ShaderHelper {
let i = 0;
return variables.filter(v => ShapeUtil.size(v.shape) > 0).map(v => this.declareVariable(v, i++)).join('\n');
}
+
+ private indicesHelpers: IndicesHelper[] = [];
+
+ get additionalImplementations(): string {
+ return this.indicesHelpers.map(i => i.impl()).join('\n');
+ }
}
export const createShaderHelper = (dispatchGroup: [number, number, number]): ShaderHelper =>
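The refactor above replaces explicit impl('offsetToIndices', ...) calls in each operator with lazy tracking: every accessor marks the WGSL helper functions it relies on, and impl() later emits only what was actually referenced, which the ShaderHelper then surfaces via additionalImplementations. The following is a minimal standalone TypeScript sketch of that usage-tracking pattern; the names (makeHelper, getByIndices, impl) and the f32 value type are illustrative stand-ins, not the actual IndicesHelper API.

// Minimal sketch of usage-tracked WGSL helper emission (illustrative names only).
type Flags = { indicesToOffset: boolean; get: boolean };

const makeHelper = (name: string, strides: readonly number[]) => {
  const used: Flags = { indicesToOffset: false, get: false };
  // WGSL implementations that may or may not be needed by the shader body.
  const i2o = `fn i2o_${name}(indices: vec${strides.length}<u32>) -> u32 {
    return ${strides.map((s, i) => `indices[${i}] * ${s}u`).join(' + ')};
  }`;
  const getFn = `fn get_${name}ByIndices(indices: vec${strides.length}<u32>) -> f32 {
    return ${name}[i2o_${name}(indices)];
  }`;
  return {
    // Calling an accessor marks its WGSL implementation as needed.
    getByIndices: (expr: string) => {
      used.get = true;
      used.indicesToOffset = true;
      return `get_${name}ByIndices(${expr})`;
    },
    // impl() emits only the functions that were actually referenced.
    impl: () => [used.indicesToOffset ? i2o : '', used.get ? getFn : ''].filter(s => s !== '').join('\n'),
  };
};

// Usage: nothing is emitted until an accessor is used in the shader body.
const a = makeHelper('a', [4, 1]);
const snippet = `let x = ${a.getByIndices('vec2<u32>(1u, 2u)')};`;
const header = a.impl();  // now contains i2o_a and get_aByIndices
console.log(`${header}\n${snippet}`);

This is why the per-operator impl(...) lines are deleted from concat, conv-grouped, expand, reduce, resize, slice, split, tile and transpose below: the helper code is now injected once, centrally, by the program manager.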
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
index 8b91b64a09200..9b294803d3787 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/concat.ts
@@ -109,9 +109,6 @@ const createConcatProgramInfo =
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(...inputVars, output)}
- ${inputVars.map(i => i.impl('indicesToOffset', 'get')).join('\n')}
- ${output.impl('offsetToIndices')}
-
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateInputIndexImpl(sizeInConcatAxis.length)}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
index 7a0e1f01c461f..8a794ce16a0b5 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-grouped.ts
@@ -47,9 +47,6 @@ const createGroupedConvProgramInfo =
${shaderHelper.declareVariables(...inputVars, output)}
${activationFunction}
- ${output.impl('offsetToIndices')}
- ${x.impl('indicesToOffset', 'get')}
- ${w.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
index b07fe3a90f3b9..2d845775f1c62 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -58,8 +58,6 @@ const createExpandProgramInfo = (metadata: ProgramMetadata, inputs: readonly Ten
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index 2ce8427bb6e7f..f62c766aa9ed0 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
-import {ShaderHelper} from './common';
+import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@@ -45,7 +45,7 @@ const createInstanceNormProgramInfo =
Got scale size of ${scaleSize} and bias size of ${biasSize}`);
}
- const dataType = tensorTypeToWsglType(inputs[0].dataType);
+ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const C: u32 = ${C};
@@ -99,7 +99,7 @@ const createInstanceNormNHWCProgramInfo =
const C = xShape[xShape.length - 1];
const H = ShapeUtil.sizeFromDimension(xShape, 1) / C;
- const dataType = tensorTypeToWsglType(inputs[0].dataType);
+ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const normCount = C * N;
const getShaderSource = (shaderHelper: ShaderHelper) => `
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
index 48627bfaec401..8a9927b25a52e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/layer-norm.ts
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types';
-import {ShaderHelper} from './common';
+import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface LayerNormAttributes extends AttributeWithCacheKey {
axis: number;
@@ -54,7 +54,7 @@ const createLayerNormProgramInfo =
}
}
- const dataType = tensorTypeToWsglType(inputs[0].dataType);
+ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const hasMeanDataOutput = outputCount > 1;
const hasInvStdOutput = outputCount > 2;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 9af8fc7b6d33d..79071d32443d6 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -128,9 +128,6 @@ const generatePoolingCode =
const pads = array<u32, ${attributes.pads.length}>(${attributes.pads.map(i => `${i}u`).join(',')});
const inputDims = array<u32, ${inputDims.length}>(${inputDims.map(i => `${i}u`).join(',')});
const kernelStrides = array<u32, ${kernelStrides.length}>(${kernelStrides.map(i => `${i}u`).join(',')});
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
index b645510d8384b..cb592c838dd97 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -85,9 +85,6 @@ export const createReduceProgramInfo =
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset')}
-
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
var inputIndices: ${input.type.indices};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index 505bae7ce2302..1d0b8229a76f7 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -484,8 +484,6 @@ const createResizeProgramInfo =
}
})()};
${shaderHelper.declareVariables(input, output)}
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
if (${noScale}) {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
index 96bf1cd9a6ef6..4b845bcf2121b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts
@@ -1,13 +1,13 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-import {DataType, tensorTypeToWsglType} from '../../../wasm-common';
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types';
-import {ShaderHelper} from './common';
+import {ShaderHelper, tensorTypeToWsglStorageType} from './common';
export interface SkipLayerNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@@ -84,7 +84,7 @@ const createSkipLayerNormProgramInfo =
const meanInvStdDevDim = isTraining ? inputShape.slice(0, -1).concat(1) : [];
const hasBetaInput = inputs.length > 3;
const hasBiasInput = inputs.length > 4;
- const dataType = tensorTypeToWsglType(inputs[0].dataType);
+ const dataType = tensorTypeToWsglStorageType(inputs[0].dataType);
const hasMeanOutput = isTraining && outputCount > 1;
const hasInvStdDevOutput = isTraining && outputCount > 2;
const hasInputSkipBiasSumOutput = outputCount > 3;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
index 1f881a75ffbde..4211e526898e6 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
@@ -153,8 +153,6 @@ const createSliceProgramInfo =
const steps = array<u32, ${steps.length}>(${steps.map(i => `${i}u`).join(',')});
const inputShape = array<u32, ${inputShape.length}>(${inputShape.map(i => `${i}u`).join(',')});
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset', 'get')}
${calculateInputIndicesImpl(input, output, inputShape, outputShape)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index 54f493422816f..9a150d21ea02e 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -23,10 +23,12 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
const createSplitAttributesFromInputs =
(inputs: readonly TensorView[], attributes: SplitAttributes): SplitAttributes => {
const splitSizes: number[] = [];
+ let numOutputs: number = attributes.numOutputs;
if (inputs[1].dims[0] > 0) {
inputs[1].getBigInt64Array().forEach(v => splitSizes.push(Number(v)));
+ numOutputs = splitSizes.length;
}
- return createAttributeWithCacheKey({numOutputs: attributes.numOutputs, axis: attributes.axis, splitSizes});
+ return createAttributeWithCacheKey({numOutputs, axis: attributes.axis, splitSizes});
};
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
@@ -85,8 +87,6 @@ const createSplitProgramInfo =
const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, ...outputs)}
- ${input.impl('indicesToOffset', 'offsetToIndices', 'get')}
- ${outputs.map(o => o.impl('indicesToOffset', 'set')).join('\n')}
const sizeInConcatAxis = array<u32, ${sizeInConcatAxis.length}>(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
${calculateOutputIndexImpl(sizeInConcatAxis.length)}
${writeBufferDataImpl(outputs)}
@@ -114,7 +114,7 @@ const createSplitProgramInfoLoader =
const updatedAttributes = inputs.length === 1 ? attributes : createSplitAttributesFromInputs(inputs, attributes);
const metadata:
ProgramMetadata = {name: 'Split', inputTypes: [GpuDataType.default], cacheHint: updatedAttributes.cacheKey};
- return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], attributes)};
+ return {...metadata, get: () => createSplitProgramInfo(metadata, [inputs[0]], updatedAttributes)};
};
export const split = (context: ComputeContext, attributes: SplitAttributes): void => {
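The Split fix above does two things: when the optional split-sizes tensor is present, the number of outputs is derived from its length rather than the num_outputs attribute, and the program info is built from updatedAttributes instead of the stale attributes. A small TypeScript sketch of that derivation follows; the names (SplitAttrs, resolveSplitAttrs) are hypothetical and only restate the logic in createSplitAttributesFromInputs.

// Sketch: why numOutputs must track the optional split-sizes input.
interface SplitAttrs { numOutputs: number; axis: number; splitSizes: number[] }

const resolveSplitAttrs = (attr: SplitAttrs, splitSizesInput?: bigint[]): SplitAttrs => {
  if (!splitSizesInput || splitSizesInput.length === 0) {
    return attr;  // no sizes tensor: fall back to the attribute
  }
  const splitSizes = splitSizesInput.map(Number);
  // The sizes tensor defines the output count, not the num_outputs attribute.
  return { ...attr, splitSizes, numOutputs: splitSizes.length };
};

// Example: attribute says 2 outputs, but the sizes tensor defines 3 splits.
console.log(resolveSplitAttrs({ numOutputs: 2, axis: 0, splitSizes: [] }, [2n, 3n, 5n]));
// -> { numOutputs: 3, axis: 0, splitSizes: [2, 3, 5] }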
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
index 2b80ce173245b..99d9668757caa 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
@@ -66,8 +66,6 @@ export const createTileProgramInfo =
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
${shaderHelper.declareVariables(input, output)}
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let outputIndices = ${output.offsetToIndices('global_idx')};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
index 0b0185fc17c9b..ebedc61712e8a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -64,8 +64,6 @@ export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: nu
${shaderHelper.declareVariables(input, output)}
${permFunctionBody(perm, rank, input, output)}
- ${output.impl('offsetToIndices')}
- ${input.impl('indicesToOffset', 'get')}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index b46b35b71412e..da710b7dc2596 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -114,7 +114,9 @@ export class ProgramManager {
build(programInfo: ProgramInfo, normalizedDispatchGroupSize: [number, number, number]): Artifact {
const device = this.backend.device;
- const code = programInfo.getShaderSource(createShaderHelper(normalizedDispatchGroupSize));
+ const shaderHelper = createShaderHelper(normalizedDispatchGroupSize);
+ const userCode = programInfo.getShaderSource(shaderHelper);
+ const code = `${shaderHelper.additionalImplementations}\n${userCode}`;
const shaderModule = device.createShaderModule({code});
LOG_DEBUG('verbose', () => `[WebGPU] shader code: ${code}`);
diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts
index a89a585906f9d..389773f3e8884 100644
--- a/js/web/lib/wasm/wasm-common.ts
+++ b/js/web/lib/wasm/wasm-common.ts
@@ -164,19 +164,3 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro
throw new Error(`unsupported logging level: ${logLevel}`);
}
};
-
-export const tensorTypeToWsglType = (type: DataType) => {
- switch (type) {
- case DataType.float:
- return 'f32';
- // TODO: enable after "shader-f16" WSGL extension release
- // case DataType.float16:
- // return 'f16';
- case DataType.int32:
- return 'i32';
- case DataType.uint32:
- return 'u32';
- default:
- throw new Error(`Unsupported type: ${type}`);
- }
-};
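The helper deleted here is replaced by tensorTypeToWsglStorageType in common.ts, which the instance-norm, layer-norm and skip-layer-norm changes above now import. Presumably it covers the same cases as the removed function; the sketch below restates that mapping in standalone TypeScript, and the DataType enum values are assumptions for illustration rather than a copy of wasm-common.ts.

// Standalone sketch of the tensor-type to WGSL storage-type mapping.
enum DataType { float = 1, int32 = 6, float16 = 10, uint32 = 12 }  // assumed values

const toWsglStorageType = (type: DataType): string => {
  switch (type) {
    case DataType.float:
      return 'f32';
    // case DataType.float16: return 'f16';  // pending the "shader-f16" WGSL extension
    case DataType.int32:
      return 'i32';
    case DataType.uint32:
      return 'u32';
    default:
      throw new Error(`Unsupported type: ${type}`);
  }
};

console.log(toWsglStorageType(DataType.float));  // -> 'f32'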
diff --git a/js/web/package-lock.json b/js/web/package-lock.json
index 4c5649d8806c9..8ad55996f7455 100644
--- a/js/web/package-lock.json
+++ b/js/web/package-lock.json
@@ -1,12 +1,12 @@
{
"name": "onnxruntime-web",
- "version": "1.16.0",
+ "version": "1.16.2",
"lockfileVersion": 2,
"requires": true,
"packages": {
"": {
"name": "onnxruntime-web",
- "version": "1.16.0",
+ "version": "1.16.2",
"license": "MIT",
"dependencies": {
"flatbuffers": "^1.12.0",
@@ -49,7 +49,7 @@
},
"../common": {
"name": "onnxruntime-common",
- "version": "1.16.0",
+ "version": "1.16.2",
"license": "MIT",
"devDependencies": {
"typedoc": "^0.23.22"
diff --git a/js/web/package.json b/js/web/package.json
index ce06475f672fd..76f793263e01a 100644
--- a/js/web/package.json
+++ b/js/web/package.json
@@ -8,7 +8,7 @@
"type": "git"
},
"author": "fs-eire",
- "version": "1.16.0",
+ "version": "1.16.2",
"jsdelivr": "dist/ort.min.js",
"dependencies": {
"flatbuffers": "^1.12.0",
diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc
new file mode 100644
index 0000000000000..a249dc807fa0b
--- /dev/null
+++ b/js/web/test/data/ops/conv-transpose.jsonc
@@ -0,0 +1,289 @@
+[
+ {
+ "name": "ConvTranspose without bias addition A",
+ "operator": "ConvTranspose",
+ "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [10, 20, 30, 40],
+ "dims": [1, 1, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 2, 2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [10, 40, 40, 60, 200, 160, 90, 240, 160],
+ "dims": [1, 1, 3, 3],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose without bias addition B",
+ "operator": "ConvTranspose",
+ "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [10, 20, 30, 40, 50, 60, 70, 80],
+ "dims": [1, 2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
+ "dims": [2, 2, 2, 2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 400, 940, 560, 1080, 2520, 1480, 760, 1740, 1000, 640, 1500, 880, 1720, 3960, 2280, 1160, 2620, 1480
+ ],
+ "dims": [1, 2, 3, 3],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose with bias addition A",
+ "operator": "ConvTranspose",
+ "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [10, 20, 30, 40],
+ "dims": [1, 4, 1, 1],
+ "type": "float32"
+ },
+ {
+ "data": [
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
+ ],
+ "dims": [4, 4, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0.1, 0.2, 0.3, 0.4],
+ "dims": [4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219,
+ 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781,
+ 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789,
+ 100.4000015258789
+ ],
+ "dims": [1, 4, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose with bias addition B",
+ "operator": "ConvTranspose",
+ "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [6, 8, 7, 9, 15, 11, 8, 12, 9],
+ "dims": [1, 1, 3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1, 1, 1, 1],
+ "dims": [1, 1, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [5],
+ "dims": [1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14],
+ "dims": [1, 1, 4, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose- group - A",
+ "operator": "ConvTranspose",
+ "attributes": [
+ { "name": "kernel_shape", "data": [1, 1], "type": "ints" },
+ { "name": "group", "data": 2, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0, 32.0, 34.0],
+ "dims": [1, 2, 3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1.0, 2.0],
+ "dims": [2, 1, 1, 1],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 36, 40, 44, 48, 52, 56, 60, 64, 68],
+ "dims": [1, 2, 3, 3],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose- group - B",
+ "operator": "ConvTranspose",
+ "attributes": [
+ { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+ { "name": "group", "data": 3, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
+ 19.0, 20.0, 21.0, 22.0, 23.0, 0, 0, 0
+ ],
+ "dims": [1, 3, 3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
+ "dims": [3, 1, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0.125, 0.25, 0.375],
+ "dims": [3],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 0.125, 1.125, 4.125, 4.125, 3.125, 13.125, 23.125, 18.125, 15.125, 43.125, 53.125, 36.125, 18.125, 45.125,
+ 52.125, 32.125, 45.25, 104.25, 115.25, 66.25, 123.25, 279.25, 305.25, 172.25, 159.25, 357.25, 383.25,
+ 214.25, 105.25, 232.25, 247.25, 136.25, 162.375, 351.375, 370.375, 200.375, 387.375, 833.375, 875.375,
+ 470.375, 231.375, 494.375, 517.375, 276.375, 0.375, 0.375, 0.375, 0.375
+ ],
+ "dims": [1, 3, 4, 4],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "ConvTranspose- group - C",
+ "operator": "ConvTranspose",
+ "attributes": [
+ { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+ { "name": "group", "data": 3, "type": "int" }
+ ],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [
+ 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0,
+ 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0
+ ],
+ "dims": [1, 3, 3, 4],
+ "type": "float32"
+ },
+ {
+ "data": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0],
+ "dims": [3, 1, 2, 2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [
+ 0, 1, 4, 7, 6, 4, 16, 26, 36, 26, 20, 56, 66, 76, 50, 24, 59, 66, 73, 44, 60, 137, 148, 159, 90, 164, 368,
+ 394, 420, 234, 212, 472, 498, 524, 290, 140, 307, 322, 337, 184, 216, 465, 484, 503, 270, 516, 1104, 1146,
+ 1188, 634, 596, 1272, 1314, 1356, 722, 352, 747, 770, 793, 420
+ ],
+ "dims": [1, 3, 4, 5],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+
+ {
+ "name": "ConvTranspose- pointwise",
+ "operator": "ConvTranspose",
+ "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
+ "dims": [1, 2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0.0, 1.0, 2.0, 3.0],
+ "dims": [2, 2, 1, 1],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2],
+ "dims": [2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [9, 11, 13, 15, 14, 18, 22, 26],
+ "dims": [1, 2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ }
+]
diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index d39d8edf0b73a..022451c885dd8 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -7,7 +7,7 @@
For more information on ONNX Runtime, please see `aka.ms/onnxruntime <https://aka.ms/onnxruntime>`_
or the `Github project <https://github.com/microsoft/onnxruntime>`_.
"""
-__version__ = "1.16.0"
+__version__ = "1.16.2"
__author__ = "Microsoft"
# we need to do device version validation (for example to check Cuda version for an onnxruntime-training package).
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index f1ab3e691b702..4c9c15d07a9b8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -37,6 +37,7 @@ enum AttentionKernelType {
AttentionKernel_TrtFlashAttention,
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
+ AttentionKernel_FlashAttention,
AttentionKernel_Default
};
@@ -98,8 +99,16 @@ constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTI
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
+// Environment variable to enable or disable flash attention. Default is 0 (enabled).
+constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
+
// Minimum sequence length to enable memory efficient attention in FP32.
-constexpr int kMinSequenceLengthForMemoryEfficientAttentionFp32 = 256;
+constexpr int kMinSeqLenForMemoryEfficientAttentionFp32 = 256;
+
+// Minimum sequence length to prefer flash attention when the input format is packed QKV for MultiHeadAttention.
+constexpr const char* kMinSeqLenForFlashAttentionPackedQKV = "ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV";
+// Default value for the above setting.
+constexpr int kDefaultMinSeqLenForFlashAttentionPackedQKV = 513;
} // namespace attention
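The two new knobs above are plain environment variables: ORT_DISABLE_FLASH_ATTENTION switches the kernel off, and ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV (default 513) keeps the TRT fused kernel preferred for short packed-QKV sequences. The TypeScript sketch below only mirrors the read-with-default semantics the C++ code gets from ParseEnvironmentVariableWithDefault; it assumes a Node-style process.env and the helper names are made up for illustration.

// Illustrative read-with-default helpers (not part of ONNX Runtime).
const envFlag = (name: string, defaultValue: boolean): boolean => {
  const v = process.env[name];
  return v === undefined ? defaultValue : v === '1' || v.toLowerCase() === 'true';
};

const envInt = (name: string, defaultValue: number): number => {
  const v = process.env[name];
  const parsed = v === undefined ? NaN : Number.parseInt(v, 10);
  return Number.isNaN(parsed) ? defaultValue : parsed;
};

const disableFlashAttention = envFlag('ORT_DISABLE_FLASH_ATTENTION', false);
const minSeqLenPackedQkv = envInt('ORT_MIN_SEQ_LEN_FLASH_ATTENTION_PACKED_QKV', 513);
console.log({ disableFlashAttention, minSeqLenPackedQkv });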
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index b8066567fc357..c911b6e76701c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -8,6 +8,7 @@
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@@ -39,20 +40,36 @@ REGISTER_KERNEL_TYPED(MLFloat16)
template <typename T>
Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionBase(info, false) {
- disable_fused_self_attention_ = sizeof(T) != 2 ||
- ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false);
+ disable_fused_self_attention_ =
+ sizeof(T) != 2 ||
+ ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
- enable_trt_flash_attention_ = sizeof(T) == 2 &&
- !ParseEnvironmentVariableWithDefault(attention::kDisableTrtFlashAttention, false);
+ enable_trt_flash_attention_ =
+ sizeof(T) == 2 &&
+ !ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
- enable_fused_causal_attention_ = sizeof(T) == 2 &&
- ParseEnvironmentVariableWithDefault(attention::kEnableFusedCausalAttention, false);
+ enable_fused_causal_attention_ =
+ sizeof(T) == 2 &&
+ ParseEnvironmentVariableWithDefault<bool>(attention::kEnableFusedCausalAttention, false);
-#if USE_FLASH_ATTENTION
- disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault(attention::kDisableMemoryEfficientAttention, false);
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ disable_memory_efficient_attention_ =
+ ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
+
+#if USE_FLASH_ATTENTION
+ disable_flash_attention_ =
+ sizeof(T) != 2 ||
+ onnxruntime::ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
+ min_seq_len_for_flash_attention_packed_qkv_ = ParseEnvironmentVariableWithDefault<int>(
+ attention::kMinSeqLenForFlashAttentionPackedQKV,
+ attention::kDefaultMinSeqLenForFlashAttentionPackedQKV);
+#else
+ disable_flash_attention_ = true;
+ min_seq_len_for_flash_attention_packed_qkv_ = 0;
+#endif
}
template <typename T>
@@ -100,71 +117,98 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
MHARunner* fused_runner = nullptr;
// Check whether we can use fused kernel
- int sm = device_prop.major * 10 + device_prop.minor;
- bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
- bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
-
- if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
- // GPT fused kernels requires left side padding. mask can be:
- // none (no padding), 1D sequence lengths or 2d mask.
- // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
- // where past state is empty.
- bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
- bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
- nullptr == relative_position_bias &&
- parameters.past_sequence_length == 0 &&
- parameters.hidden_size == parameters.v_hidden_size &&
- FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
- enable_trt_flash_attention_, true);
- if (use_causal_fused_runner) {
- // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
- if (nullptr == fused_fp16_runner_.get()) {
- fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
- enable_trt_flash_attention_, parameters.scale);
- }
+ const int sm = device_prop.major * 10 + device_prop.minor;
+ const bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
- // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
- fused_runner = fused_fp16_runner_.get();
- }
- } else { // BERT
- bool use_fused_runner = !disable_fused_self_attention_ &&
- (nullptr == mask_index || is_mask_1d_seq_len) &&
- nullptr == past &&
- nullptr == present &&
- nullptr == relative_position_bias &&
- parameters.hidden_size == parameters.v_hidden_size &&
- FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
- enable_trt_flash_attention_, false);
-
- if (use_fused_runner) {
- // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
- if (nullptr == fused_fp16_runner_.get()) {
- fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
- enable_trt_flash_attention_, parameters.scale);
+#if USE_FLASH_ATTENTION
+ bool use_flash_attention = !disable_flash_attention_ &&
+ (nullptr == relative_position_bias) &&
+ nullptr == past &&
+ nullptr == present &&
+ parameters.hidden_size == parameters.v_hidden_size &&
+ nullptr == mask_index &&
+ onnxruntime::flash::is_supported(device_prop,
+ parameters.head_size,
+ parameters.num_heads,
+ parameters.num_heads);
+ // When input is packed QKV format, TensorRT kernel might be faster when sequence length <= 512.
+ if (use_flash_attention && parameters.sequence_length < min_seq_len_for_flash_attention_packed_qkv_) {
+ use_flash_attention = false;
+ }
+#else
+ constexpr bool use_flash_attention = false;
+#endif
+
+ if (!use_flash_attention) {
+ if (is_unidirectional_) { // GPT
+ if (enable_fused_causal_attention_) {
+ // GPT fused kernels require left-side padding. The mask can be:
+ // none (no padding), 1D sequence lengths, or a 2D mask.
+ // Fused kernels don't support different sequence lengths for q and kv, so they only apply to the first token,
+ // where past state is empty.
+ bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
+ bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
+ nullptr == relative_position_bias &&
+ parameters.past_sequence_length == 0 &&
+ parameters.hidden_size == parameters.v_hidden_size &&
+ FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
+ enable_trt_flash_attention_, true);
+ if (use_causal_fused_runner) {
+ // Here we assume that num_heads, head_size and is_unidirectional do not change for an Attention node.
+ if (nullptr == fused_fp16_runner_.get()) {
+ fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
+ enable_trt_flash_attention_, parameters.scale);
+ }
+
+ // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
+ fused_runner = fused_fp16_runner_.get();
+ }
}
+ } else { // BERT
+ bool use_fused_runner = !disable_fused_self_attention_ &&
+ (nullptr == mask_index || is_mask_1d_seq_len) &&
+ nullptr == past &&
+ nullptr == present &&
+ nullptr == relative_position_bias &&
+ parameters.hidden_size == parameters.v_hidden_size &&
+ FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
+ enable_trt_flash_attention_, false);
- // In case some kernel not loaded due to shared memory limit, we need to double check here.
- const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
- if (fused_fp16_runner_->isValid(S)) {
- fused_runner = fused_fp16_runner_.get();
+ if (use_fused_runner) {
+ // Here we assume that num_heads, head_size and is_unidirectional do not change for an Attention node.
+ if (nullptr == fused_fp16_runner_.get()) {
+ fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_,
+ enable_trt_flash_attention_, parameters.scale);
+ }
+
+ // In case some kernels are not loaded due to the shared memory limit, we need to double-check here.
+ const int S = fused_fp16_runner_->getSFromMaxSeqLen(sequence_length);
+ if (fused_fp16_runner_->isValid(S)) {
+ fused_runner = fused_fp16_runner_.get();
+ }
}
}
}
-#if USE_FLASH_ATTENTION
- bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
- bool use_memory_efficient_attention = fused_runner == nullptr &&
- !disable_memory_efficient_attention_ &&
- (nullptr == mask_index || is_mask_1d_key_seq_len_start) &&
- nullptr == past &&
- nullptr == present &&
- (nullptr == relative_position_bias || is_good_for_rpb) &&
- (sizeof(T) == 2 || // sequence length threshold is 0 in FP16
- parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
- has_memory_efficient_attention(sm, sizeof(T) == 2);
+#if USE_MEMORY_EFFICIENT_ATTENTION
+ bool use_memory_efficient_attention =
+ !use_flash_attention &&
+ fused_runner == nullptr &&
+ !disable_memory_efficient_attention_ &&
+ nullptr == past &&
+ nullptr == present &&
+ (parameters.head_size & 7) == 0 &&
+ (parameters.v_head_size & 7) == 0 &&
+ (nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
+ (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
+ has_memory_efficient_attention(sm, sizeof(T) == 2);
+
+ if (use_memory_efficient_attention) {
+ bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
+ use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb);
+ }
#else
constexpr bool use_memory_efficient_attention = false;
- ORT_UNUSED_PARAMETER(is_mask_1d_key_seq_len_start);
#endif
cublasHandle_t cublas = GetCublasHandle(context);
@@ -199,6 +243,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner,
+ use_flash_attention,
use_fused_cross_attention,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer(workSpaceSize, context->GetComputeStream());
@@ -215,7 +260,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
data.past = (nullptr == past) ? nullptr : reinterpret_cast(past->Data());
data.past_key = nullptr;
data.past_value = nullptr;
- data.relative_position_bias = (nullptr == relative_position_bias) ? nullptr : reinterpret_cast(relative_position_bias->Data());
+ data.relative_position_bias = (nullptr == relative_position_bias)
+ ? nullptr
+ : reinterpret_cast<const T*>(relative_position_bias->Data());
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast(work_space.get());
data.output = reinterpret_cast(output->MutableData());
@@ -224,6 +271,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
data.present_value = nullptr;
data.fused_runner = reinterpret_cast(fused_runner);
data.fused_cross_attention_kernel = nullptr;
+ data.use_flash_attention = use_flash_attention;
data.use_memory_efficient_attention = use_memory_efficient_attention;
data.cumulated_sequence_length_q_cache = nullptr;
data.cumulated_sequence_length_kv_cache = nullptr;
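Summarising the rewritten ComputeInternal above: flash attention is evaluated first (fp16, no mask, no past/present, no relative position bias, and a minimum sequence length for packed QKV), and only when it is not used does the code consider the TRT fused runners and then memory-efficient attention, falling back to the unfused path otherwise. The TypeScript sketch below is a simplified model of that selection order; the predicate fields are placeholders for the real eligibility checks (SM version, head size alignment, mask type, and so on), not the actual parameters struct.

// Simplified model of the kernel-selection cascade (placeholder predicates).
interface AttentionParams {
  isFp16: boolean;
  hasMask: boolean;
  hasPastOrPresent: boolean;
  isPackedQkv: boolean;
  sequenceLength: number;
  minSeqLenForFlashPackedQkv: number;  // default 513 in the C++ code
}

type Kernel = 'flash' | 'trt_fused' | 'memory_efficient' | 'unfused';

const selectKernel = (p: AttentionParams, trtFusedSupported: boolean, efficientSupported: boolean): Kernel => {
  const flashOk = p.isFp16 && !p.hasMask && !p.hasPastOrPresent &&
      (!p.isPackedQkv || p.sequenceLength >= p.minSeqLenForFlashPackedQkv);
  if (flashOk) return 'flash';
  if (trtFusedSupported) return 'trt_fused';
  if (efficientSupported) return 'memory_efficient';
  return 'unfused';
};

// Short packed-QKV sequences skip flash and fall back to the TRT fused kernel.
console.log(selectKernel(
    { isFp16: true, hasMask: false, hasPastOrPresent: false, isPackedQkv: true,
      sequenceLength: 256, minSeqLenForFlashPackedQkv: 513 },
    /* trtFusedSupported */ true, /* efficientSupported */ true));  // -> 'trt_fused'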
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.h b/onnxruntime/contrib_ops/cuda/bert/attention.h
index ba7c56c04fdde..455e55ba05a66 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.h
@@ -21,10 +21,12 @@ class Attention final : public CudaKernel, public AttentionBase {
Status ComputeInternal(OpKernelContext* context) const override;
protected:
+ bool disable_flash_attention_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool enable_fused_causal_attention_;
bool disable_memory_efficient_attention_;
+ int min_seq_len_for_flash_attention_packed_qkv_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 4d478ef158503..ae7696eb9fe0f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -42,6 +42,7 @@ limitations under the License.
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
using namespace onnxruntime::cuda;
using namespace onnxruntime::contrib::attention_softmax_cuda;
@@ -64,7 +65,8 @@ size_t AlignSize(size_t bytes) {
void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) {
if (this->sequence_length != sequence_length) {
ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
- LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, this->max_batch_size, sequence_length, stream);
+ LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr,
+ this->max_batch_size, sequence_length, stream);
this->sequence_length = sequence_length;
}
}
@@ -114,6 +116,7 @@ size_t GetAttentionWorkspaceSize(
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner,
+ bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention) {
// Note that q, k and v might need alignment for fused attention kernels.
@@ -121,6 +124,14 @@ size_t GetAttentionWorkspaceSize(
((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);
#if USE_FLASH_ATTENTION
+ if (use_flash_attention) {
+ return qkv_bytes + onnxruntime::flash::get_softmax_lse_size(sequence_length, batch_size, num_heads);
+ }
+#else
+ ORT_UNUSED_PARAMETER(use_flash_attention);
+#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
if (use_memory_efficient_attention) {
size_t fmha_buffer_bytes = 0;
if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) {
@@ -276,333 +287,439 @@ template Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
half* present);
template <typename T>
-Status PrepareQkv(contrib::AttentionParameters& parameters,
- AttentionData& data,
- cudaStream_t stream,
- int max_threads_per_block,
- T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
+ AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
const int qk_head_size = parameters.head_size;
const int v_head_size = parameters.v_head_size;
const bool past_present_share_buffer = parameters.past_present_share_buffer;
void* fused_runner = data.fused_runner;
- bool use_memory_efficient_attention = data.use_memory_efficient_attention;
+ bool use_flash_or_efficient_attention = data.use_flash_attention || data.use_memory_efficient_attention;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
- // Default format for memory efficient attention.
- // When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
- DUMP_TENSOR_INIT();
- if (nullptr != data.gemm_buffer) {
- if (data.bias == nullptr) {
- assert(nullptr == fused_runner);
- // For quantized attention, bias has been added so only need transpose here.
- // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
- assert(qk_head_size == v_head_size);
- int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.gemm_buffer, qkv, 3));
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- } else {
- // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
- // For memory efficient attention, transpose to 3xBxSxNxH (format 3)
- // For unfused kernel, transpose to 3xBxNxSxH (format 1)
- // For fused causal kernel, use format 1 since we need have K and V to update present state,
- // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
- const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1));
- qkv_format = use_fused_kernel
- ? AttentionQkvFormat::QKV_BSN3H
- : (use_memory_efficient_attention
- ? AttentionQkvFormat::Q_K_V_BSNH
- : (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH));
-
- // For fused causal, we will update gemm_buffer with bias directly.
- T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
-
- int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
- // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
- // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
- LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.past_sequence_length);
- }
+ if (data.bias == nullptr) {
+ assert(nullptr == fused_runner);
+ // For quantized attention, bias has been added so only need transpose here.
+ // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
+ assert(qk_head_size == v_head_size);
+ int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.gemm_buffer, qkv, 3));
+ qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ } else {
+ // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
+ // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
+ // For unfused kernel, transpose to 3xBxNxSxH (format 1)
+ // For the fused causal kernel, use format 1 since we need K and V to update the present state,
+ // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
+ const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
+ qkv_format = use_fused_kernel
+ ? AttentionQkvFormat::QKV_BSN3H
+ : (use_flash_or_efficient_attention
+ ? AttentionQkvFormat::Q_K_V_BSNH
+ : (use_fused_causal
+ ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
+ : AttentionQkvFormat::Q_K_V_BNSH));
+
+ // For fused causal, we will update gemm_buffer with bias directly.
+ T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
+
+ int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
+ // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
+ // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
+ LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
+ 3, parameters.do_rotary, parameters.past_sequence_length);
}
- // attention with past/present state
- else if (data.past_key != nullptr || data.present_key != nullptr) {
- // Below logic does not support memory efficient attention with past (like pass_past_in_kv) but without bias
- if (data.bias == nullptr) {
- // cross attention with past state
- if (data.past_key != nullptr && data.present_key == nullptr) {
- assert(data.past_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key == nullptr);
- assert(data.value == nullptr);
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- }
- // cross attention with present state or self attention with present state
- else if (data.past_key == nullptr && data.present_key != nullptr) {
- assert(data.past_value == nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
-
- // TODO: supporting packed qkv for self attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
-
- // TODO: supporting packed kv for cross attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, data.present_key));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, data.present_value));
- }
- // self attention with past and present state
- else {
- assert(data.past_key != nullptr);
- assert(data.past_value != nullptr);
- assert(data.present_key != nullptr);
- assert(data.present_value != nullptr);
- assert(data.query != nullptr);
- assert(data.key != nullptr);
- assert(data.value != nullptr);
- // TODO: supporting packed qkv for self attention may benefit performance
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.query, q));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.key, k));
- ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.value, v));
- }
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ return Status::OK();
+}
+
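The comment block in PrepareQkv_Attention encodes a small decision table: fused TRT attention wants QKV_BSN3H (format 2), flash or memory-efficient attention wants Q_K_V_BSNH (format 3), the fused causal kernel keeps a format 1 variant with bias folded into gemm_buffer, and everything else uses Q_K_V_BNSH (format 1). The TypeScript sketch below only restates that table for readability; it is not the CUDA code and chooseQkvFormat is a made-up name.

// Restates the QKV-format selection described in the comments above.
type QkvFormat = 'QKV_BSN3H' | 'Q_K_V_BSNH' | 'Q_K_V_BNSH_QKV_BS3NH' | 'Q_K_V_BNSH';

const chooseQkvFormat = (
    useFusedKernel: boolean, useFlashOrEfficient: boolean, useFusedCausal: boolean): QkvFormat => {
  if (useFusedKernel) return 'QKV_BSN3H';             // format 2: BxSxNx3xH for TRT fused attention
  if (useFlashOrEfficient) return 'Q_K_V_BSNH';       // format 3: 3xBxSxNxH for flash / memory-efficient
  if (useFusedCausal) return 'Q_K_V_BNSH_QKV_BS3NH';  // format 1 variant used by the fused causal kernel
  return 'Q_K_V_BNSH';                                // format 1: 3xBxNxSxH for the unfused kernel
};

console.log(chooseQkvFormat(false, true, false));  // -> 'Q_K_V_BSNH'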
+// For MultiHeadAttention with past state
+template <typename T>
+Status PrepareQkv_MHA_WithPast(contrib::AttentionParameters& parameters,
+ AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ DUMP_TENSOR_INIT();
+
+ if (data.bias == nullptr) {
+ // The logic below does not support fused attention with past state but without bias.
+ // When there is past state, the format shall be BxNxSxH, so we disable fused attention when past state is present.
+
+ // cross attention with past state
+ if (data.past_key != nullptr && data.present_key == nullptr) {
+ assert(data.past_value != nullptr);
+ assert(data.query != nullptr);
+ assert(data.key == nullptr);
+ assert(data.value == nullptr);
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, q));
}
-#if USE_FLASH_ATTENTION
- // When past_key/past_value are inputted directly as key/value and there is no present_key/present_value
- else if (use_memory_efficient_attention && data.past_key != nullptr && data.past_value != nullptr && parameters.pass_past_in_kv) {
- // Transpose past_key and past_value to use memory efficient attention
-
- // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_key, data.temp_k_workspace));
- // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
- ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.past_value, data.temp_v_workspace));
-
- // query => q, temp_k_workspace => k, temp_v_workspace => v
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
-
- data.past_key = nullptr;
- data.past_value = nullptr;
+ // cross attention with present state or self attention with present state
+ else if (data.past_key == nullptr && data.present_key != nullptr) {
+ assert(data.past_value == nullptr);
+ assert(data.present_value != nullptr);
+ assert(data.query != nullptr);
+ assert(data.key != nullptr);
+ assert(data.value != nullptr);
+
+ // TODO: supporting packed qkv for self attention may benefit performance
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, q));
+
+ // TODO: supporting packed kv for cross attention may benefit performance
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.key, data.present_key));
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.value, data.present_value));
}
- // When there is no past_key/past_value and there is present_key/present_value (e.g. get initial kv to use as past_kv in the next iteration)
- else if (use_memory_efficient_attention && data.present_key != nullptr && data.present_value != nullptr) {
- // Use memory efficient attention kernel
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
-
- // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
+ // self attention with past and present state
+ else {
+ assert(data.past_key != nullptr);
+ assert(data.past_value != nullptr);
+ assert(data.present_key != nullptr);
+ assert(data.present_value != nullptr);
+ assert(data.query != nullptr);
+ assert(data.key != nullptr);
+ assert(data.value != nullptr);
+ // TODO: supporting packed qkv for self attention may benefit performance
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.query, q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, false, data.temp_k_workspace, data.present_key));
-
- // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
+ max_threads_per_block, false, data.key, k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
- max_threads_per_block, false, data.temp_v_workspace, data.present_value));
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size * kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ max_threads_per_block, false, data.value, v));
}
+ qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ // When past_key/past_value are passed in directly as key/value and there is no present_key/present_value
+ else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
+ data.past_key != nullptr &&
+ data.past_value != nullptr &&
+ parameters.pass_past_in_kv) {
+ // Transpose past_key and past_value to use memory efficient attention
+
+ // past_key (BxNxSxH) => temp_k_workspace (BxSxNxH)
+ ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.past_key, data.temp_k_workspace));
+ // past_value (BxNxSxH_v) => temp_v_workspace (BxSxNxH_v)
+ ORT_RETURN_IF_ERROR(LaunchTransCtx(stream, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.past_value, data.temp_v_workspace));
+
+ // query => q, temp_k_workspace => k, temp_v_workspace => v
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);
+
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+
+ data.past_key = nullptr;
+ data.past_value = nullptr;
+ }
+ // When there is no past_key/past_value and there is present_key/present_value
+ // (e.g. get initial kv to use as past_kv in the next iteration)
+ else if ((data.use_memory_efficient_attention || data.use_flash_attention) &&
+ data.present_key != nullptr &&
+ data.present_value != nullptr) {
+ // Use memory efficient attention kernel
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, q, data.temp_k_workspace, data.temp_v_workspace);
+
+ // temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, false, data.temp_k_workspace, data.present_key));
+
+ // temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
+ ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
+ max_threads_per_block, false, data.temp_v_workspace, data.present_value));
+
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", data.temp_v_workspace, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ }
#endif
- else {
- // Use unfused kernel for Q, use unfused kernel for K and V if needed
- constexpr int format = 0;
- // Query (BxSxNxH) => Q (BxNxSxH)
+ else {
+ // Use unfused kernel for Q, use unfused kernel for K and V if needed
+ constexpr int format = 0;
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, q,
+ true, -1);
+
+ if (!parameters.pass_past_in_kv) {
+ T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
+ T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, data.bias + num_heads * qk_head_size, k_dest,
+ true, -1);
+
+ // Value (BxLxNxH_v) => V (BxNxLxH_v)
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
true, -1);
- if (!parameters.pass_past_in_kv) {
- T* k_dest = (data.past_key == nullptr && data.present_key != nullptr) ? data.present_key : k;
- T* v_dest = (data.past_value == nullptr && data.present_value != nullptr) ? data.present_value : v;
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, data.bias + num_heads * qk_head_size, k_dest,
- true, -1);
-
- // Value (BxLxNxH_v) => V (BxNxLxH_v)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, data.bias + 2 * num_heads * qk_head_size, v_dest,
- true, -1);
-
- DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size * num_heads, kv_sequence_length, qk_head_size);
- DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size * num_heads, kv_sequence_length, v_head_size);
- }
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k(BNSH)", k_dest, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v(BNSH)", v_dest, batch_size, num_heads, kv_sequence_length, v_head_size);
}
- } else if (data.key == nullptr) { // gemm_buffer == nullptr and packed qkv
- assert(data.bias == nullptr);
- assert(qk_head_size == v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+ return Status::OK();
+}
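
The PrepareQkv_* helpers above and below move Q/K/V between the BSNH and BNSH element orders named by AttentionQkvFormat. As a hedged reference (the helper names here are illustrative, not part of the kernels), the flat offset of element (b, s, n, h) in each layout is:

```cpp
#include <cstddef>

// Illustrative only: B = batch, S = sequence length, N = heads, H = head size.
// BSNH -- the order consumed by the memory-efficient and flash attention paths.
inline size_t OffsetBSNH(size_t b, size_t s, size_t n, size_t h,
                         size_t S, size_t N, size_t H) {
  return ((b * S + s) * N + n) * H + h;
}

// BNSH -- the order produced for the unfused path, where each (b, n) slice is a
// contiguous S x H matrix suitable for the batched GEMMs later in this file.
inline size_t OffsetBNSH(size_t b, size_t s, size_t n, size_t h,
                         size_t S, size_t N, size_t H) {
  return ((b * N + n) * S + s) * H + h;
}
```
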
- DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
-
- if (use_memory_efficient_attention) {
- // unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
- constexpr int format = 4;
- T* qkv_add_bias = nullptr;
- LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, qkv,
- true, v_head_size, qkv_add_bias, 3);
- DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else {
- if (!use_fused_kernel) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
- }
+// For MultiHeadAttention without past state, with packed QKV inputs
+template <typename T>
+Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
+                                AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+ void* fused_runner = data.fused_runner;
+
+ T* qkv = data.workspace;
+
+ bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
- qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ assert(data.bias == nullptr);
+ assert(qk_head_size == v_head_size);
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("packed_qkv", data.query, batch_size * sequence_length, num_heads, 3, qk_head_size);
+
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // unpack qkv to BSNH. Note that there is no bias so we need not output query to q.
+ constexpr int format = 4;
+ T* qkv_add_bias = nullptr;
+ LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, qkv,
+ true, v_head_size, qkv_add_bias, 3);
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else {
+ if (!use_fused_kernel) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, NOT_IMPLEMENTED,
+ "packed QKV format is not implemented for current GPU. Please disable it in fusion options.");
}
- } else if (data.value == nullptr) { // gemm_buffer == nullptr and packed kv
- // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
- // CheckInputs verified this constraint.
- assert(data.bias == nullptr);
- assert(qk_head_size == v_head_size);
- DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
-
- if (use_memory_efficient_attention) {
- // unpack kv to BSNH. Note that there is no bias so we need not output query to q.
- constexpr int format = 4;
- T* qkv_add_bias = nullptr;
- const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
- LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, kv_bias, k,
- true, v_head_size, qkv_add_bias, 2);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- } else {
- if (data.fused_cross_attention_kernel == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
- }
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ }
+ return Status::OK();
+}
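
PrepareQkv_MHA_PackedQKV above consumes the QKV_BSN3H packing, where the three matrices are interleaved per head; the assert that qk_head_size == v_head_size is what makes a fixed 3 * H stride per head possible. A minimal sketch of the element offset in that packing (illustrative helper, not part of the operator):

```cpp
#include <cstddef>

// Illustrative only: offset of element (b, s, n, m, h) in packed QKV (BSN3H),
// where m selects the matrix (0 = Q, 1 = K, 2 = V) and all head sizes are equal.
inline size_t OffsetBSN3H(size_t b, size_t s, size_t n, size_t m, size_t h,
                          size_t S, size_t N, size_t H) {
  return (((b * S + s) * N + n) * 3 + m) * H + h;
}
```

The packed-KV format handled next (Q_KV_BSNH_BSN2H) is the same idea with a factor of 2 and a separate, unpacked Q buffer.
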
- qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+// For MultiHeadAttention without past state, with packed KV inputs
+template <typename T>
+Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
+                               AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+ const int batch_size = parameters.batch_size;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+
+ // TODO: unpack kv to BNSH for unfused kernel so that we can remove the following constraint.
+ // CheckInputs verified this constraint.
+ assert(data.bias == nullptr);
+ assert(qk_head_size == v_head_size);
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("packed_kv", data.key, batch_size * kv_sequence_length, num_heads, 2, qk_head_size);
+
+ if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ // unpack kv to BSNH. Note that there is no bias so we need not output query to q.
+ constexpr int format = 4;
+ T* qkv_add_bias = nullptr;
+ const T* kv_bias = (data.bias == nullptr ? data.bias : data.bias + parameters.hidden_size);
+ LaunchAddBiasTranspose(stream, 2, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, kv_bias, k,
+ true, v_head_size, qkv_add_bias, 2);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ } else {
+ if (data.fused_cross_attention_kernel == nullptr) {
+ return ORT_MAKE_STATUS(
+ ONNXRUNTIME, NOT_IMPLEMENTED,
+ "packed KV format is not implemented for current GPU. Please disable packed kv in fusion options.");
}
- } else { // gemm_buffer == nullptr and not packed
- assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
- DUMP_TENSOR_D("query", data.query, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("key", data.key, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ }
+ return Status::OK();
+}
+
+// For MultiHeadAttention without past state, with Q, K and V inputs
+template <typename T>
+Status PrepareQkv_MHA_NotPacked(contrib::AttentionParameters& parameters,
+                                AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = parameters.kv_sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int qk_head_size = parameters.head_size;
+ const int v_head_size = parameters.v_head_size;
+ void* fused_runner = data.fused_runner;
+
+ T* qkv = data.workspace;
+
+ bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
+ bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
+
+ // gemm_buffer == nullptr and not packed
+ assert(data.query != nullptr && data.key != nullptr && data.value != nullptr);
+
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("query", data.query, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("key", data.key, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("value", data.value, batch_size, kv_sequence_length, num_heads, v_head_size);
#if DUMP_TENSOR_LEVEL > 1
- if (data.bias != nullptr) {
- DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
- DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
- DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
- }
+ if (data.bias != nullptr) {
+ DUMP_TENSOR_D("query_bias", data.bias, num_heads, qk_head_size);
+ DUMP_TENSOR_D("key_bias", data.bias + num_heads * qk_head_size, num_heads, qk_head_size);
+ DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
+ }
#endif
- if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
- DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, sequence_length, kv_sequence_length);
- }
+ if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
+ DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
+ num_heads, sequence_length, kv_sequence_length);
+ }
- if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
- DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
- }
+ if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
+ DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
+ }
- if (data.fused_cross_attention_kernel != nullptr) {
- assert(qk_head_size == v_head_size);
+ if (data.fused_cross_attention_kernel != nullptr) {
+ assert(qk_head_size == v_head_size);
- // For fused cross attention, besides adding bias, K and V needed to be packed:
- // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
+ // For fused cross attention, besides adding bias, K and V needed to be packed:
+ // K (BxSxNxH), V (BxSxNxH) => BxSxNx2xH
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
- qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
- }
-#if USE_FLASH_ATTENTION
- else if (use_memory_efficient_attention) {
- LaunchAddBias(stream, max_threads_per_block,
- batch_size, sequence_length, kv_sequence_length,
- num_heads, qk_head_size, v_head_size,
- data.bias, data.query, data.key, data.value, q, k, v);
-
- DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
- DUMP_TENSOR_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
- }
+ qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
+ }
+#if USE_MEMORY_EFFICIENT_ATTENTION || USE_FLASH_ATTENTION
+ else if (data.use_memory_efficient_attention || data.use_flash_attention) {
+ LaunchAddBias(stream, max_threads_per_block,
+ batch_size, sequence_length, kv_sequence_length,
+ num_heads, qk_head_size, v_head_size,
+ data.bias, data.query, data.key, data.value, q, k, v);
+
+ DUMP_TENSOR_D("q(BSNH)", q, batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, kv_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, kv_sequence_length, num_heads, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
+ }
#endif
- else if (use_fused_kernel) {
- assert(qk_head_size == v_head_size);
-
- // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
- LaunchAddBiasTransposeTrt(
- stream, max_threads_per_block,
- batch_size, sequence_length,
- num_heads, qk_head_size,
- data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
- DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
-
- qkv_format = AttentionQkvFormat::QKV_BSN3H;
- } else { // unfused kernel
- ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
-
- // Query (BxSxNxH) => Q (BxNxSxH)
- constexpr int format = 0;
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, sequence_length, num_heads, qk_head_size,
- data.query, data.bias, q,
- true, -1);
-
- // Key (BxLxNxH) => K (BxNxLxH)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, qk_head_size,
- data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
- true, -1);
+ else if (use_fused_kernel) {
+ assert(qk_head_size == v_head_size);
- // Value (BxLxNxH_v) => K (BxNxLxH_v)
- LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
- batch_size, kv_sequence_length, num_heads, v_head_size,
- data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
- true, -1);
+ // Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
+ LaunchAddBiasTransposeTrt(
+ stream, max_threads_per_block,
+ batch_size, sequence_length,
+ num_heads, qk_head_size,
+ data.bias, data.query, data.key, data.value, qkv, false, kv_sequence_length);
+ DUMP_TENSOR_D("qkv(BSN3H)", qkv, batch_size, sequence_length, num_heads, 2 * qk_head_size + v_head_size);
+
+ qkv_format = AttentionQkvFormat::QKV_BSN3H;
+ } else { // unfused kernel
+ ORT_ENFORCE(!use_fused_causal, "MultiHeadAttention has not enabled fused causal");
+
+ // Query (BxSxNxH) => Q (BxNxSxH)
+ constexpr int format = 0;
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, sequence_length, num_heads, qk_head_size,
+ data.query, data.bias, q,
+ true, -1);
+
+ // Key (BxLxNxH) => K (BxNxLxH)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, qk_head_size,
+ data.key, nullptr == data.bias ? nullptr : data.bias + num_heads * qk_head_size, k,
+ true, -1);
+
+    // Value (BxLxNxH_v) => V (BxNxLxH_v)
+ LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block,
+ batch_size, kv_sequence_length, num_heads, v_head_size,
+ data.value, nullptr == data.bias ? nullptr : data.bias + 2 * num_heads * qk_head_size, v,
+ true, -1);
+
+ DUMP_TENSOR_D("q(BNSH)", q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k(BNSH)", k, batch_size, num_heads, kv_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v(BNSH)", v, batch_size, num_heads, kv_sequence_length, v_head_size);
+ qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
+ }
+ return Status::OK();
+}
- DUMP_TENSOR_D("q(BNSH)", q, batch_size * num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("k(BNSH)", k, batch_size * num_heads, kv_sequence_length, qk_head_size);
- DUMP_TENSOR_D("v(BNSH)", v, batch_size * num_heads, kv_sequence_length, v_head_size);
- qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
- }
+template <typename T>
+Status PrepareQkv(contrib::AttentionParameters& parameters,
+                  AttentionData<T>& data,
+ cudaStream_t stream,
+ int max_threads_per_block,
+ T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
+ if (nullptr != data.gemm_buffer) { // Attention operator
+ ORT_RETURN_IF_ERROR(PrepareQkv_Attention(parameters, data, stream, max_threads_per_block, qkv_format));
+ } else if (data.past_key != nullptr || data.present_key != nullptr) { // mha operator with past/present state
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_WithPast(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
+ } else if (data.key == nullptr) { // multihead attention operator, no past, packed qkv
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedQKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
+ } else if (data.value == nullptr) { // multihead attention operator, no past, packed kv
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_PackedKV(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
+ } else { // multihead attention operator, no past, separated Q/K/V inputs
+ ORT_RETURN_IF_ERROR(PrepareQkv_MHA_NotPacked(parameters, data, stream, max_threads_per_block, q, k, v, qkv_format));
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
@@ -631,7 +748,10 @@ Status QkvToContext(
void* fused_runner = data.fused_runner;
// At most one fused kernel is enabled.
- assert(int(data.use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
+ assert((int(data.use_flash_attention) +
+ int(data.use_memory_efficient_attention) +
+ int(fused_runner != nullptr) +
+ int(data.fused_cross_attention_kernel != nullptr)) <= 1);
const int batches = batch_size * num_heads;
@@ -673,8 +793,9 @@ Status QkvToContext(
if (nullptr != data.present) {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH || qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
ORT_RETURN_IF_ERROR(
- LaunchConcatPastToPresent(stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
- max_threads_per_block, data.past, k, data.present));
+ LaunchConcatPastToPresent(
+ stream, total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
+ max_threads_per_block, data.past, k, data.present));
// Update pointers to present_k and present_v.
k = data.present;
@@ -708,22 +829,25 @@ Status QkvToContext(
cudaMemcpyAsync(data.present_value, data.past_value, v_size * sizeof(T), cudaMemcpyDeviceToDevice, stream);
} else {
ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
+ LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
+ batch_size, qk_head_size, num_heads,
max_threads_per_block, 1, data.past_key, k, data.present_key));
ORT_RETURN_IF_ERROR(
- LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length, batch_size, v_head_size, num_heads,
+ LaunchConcatTensorToTensor(stream, parameters.total_sequence_length, sequence_length,
+ batch_size, v_head_size, num_heads,
max_threads_per_block, 1, data.past_value, v, data.present_value));
// Update pointers to present_k and present_v.
k = data.present_key;
v = data.present_value;
}
}
- } else {
+ } else { // past_present_share_buffer
assert(qk_head_size == v_head_size);
assert(data.fused_cross_attention_kernel == nullptr);
assert(!use_fused_kernel);
assert(data.gemm_buffer != nullptr);
assert(!data.use_memory_efficient_attention);
+ assert(!data.use_flash_attention);
assert(data.has_qkv_workspace);
if (nullptr != data.past_key || nullptr != data.present_key) {
@@ -799,7 +923,7 @@ Status QkvToContext(
kv_sequence_length, // sequence length of KV
stream);
- DUMP_TENSOR("trt cross output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("trt cross output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return Status::OK();
}
@@ -836,11 +960,11 @@ Status QkvToContext(
}
fused_fp16_runner->run(packed_qkv, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("fused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH);
fused_fp16_runner->run(data.gemm_buffer, sequence_offset, data.output, stream);
- DUMP_TENSOR("fused causal output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("fused causal output", data.output, batch_size, sequence_length, num_heads, v_head_size);
}
return Status::OK();
}
@@ -850,6 +974,37 @@ Status QkvToContext(
: parameters.scale;
#if USE_FLASH_ATTENTION
+ if (data.use_flash_attention) {
+ assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+ assert(nullptr == data.mask_index);
+ assert(nullptr == data.relative_position_bias);
+ assert(parameters.head_size == parameters.v_head_size);
+
+    void* query = reinterpret_cast<void*>(q);
+    void* key = reinterpret_cast<void*>(k);
+    void* value = reinterpret_cast<void*>(v);
+    // For packed KV, we can use query input directly.
+    if (data.gemm_buffer == nullptr && data.key != nullptr && data.value == nullptr && data.bias == nullptr) {
+      query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+ }
+
+    DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
+
+ constexpr bool is_causal = false;
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd(
+        device_prop, stream, query, key, value, data.output, reinterpret_cast<void*>(scratch1),
+ parameters.batch_size, parameters.num_heads, parameters.num_heads, parameters.head_size,
+ parameters.sequence_length, parameters.total_sequence_length, scale, is_causal));
+
+ DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, v_head_size);
+
+ return Status::OK();
+ }
+#endif
+
+#if USE_MEMORY_EFFICIENT_ATTENTION
if (data.use_memory_efficient_attention) {
// We only enable fused cross attention when there is no key padding mask.
     // Otherwise, key has an effective batch size of 2 * batch_size, which is different from the batch_size of query.
@@ -864,9 +1019,9 @@ Status QkvToContext(
query = data.query;
}
- DUMP_TENSOR_D("attention q(BSNH)", q, batch_size * sequence_length, num_heads * qk_head_size);
- DUMP_TENSOR_D("attention k(BSNH)", k, batch_size * sequence_length, num_heads * qk_head_size);
- DUMP_TENSOR_D("attention v(BSNH)", v, batch_size * sequence_length, num_heads * v_head_size);
+    DUMP_TENSOR_D("q(BSNH)", reinterpret_cast<const T*>(query), batch_size, sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("k(BSNH)", k, batch_size, parameters.total_sequence_length, num_heads, qk_head_size);
+ DUMP_TENSOR_D("v(BSNH)", v, batch_size, parameters.total_sequence_length, num_heads, v_head_size);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
@@ -879,19 +1034,28 @@ Status QkvToContext(
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
-    p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
-    p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + batch_size));
-    p.seqstart_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + 2 * batch_size + 1));
+    p.seqlen_k_ptr = nullptr == data.mask_index
+                         ? nullptr
+                         : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
+    p.seqstart_q_ptr = nullptr == data.mask_index
+                           ? nullptr
+                           : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + batch_size));
+    p.seqstart_k_ptr = nullptr == data.mask_index
+                           ? nullptr
+                           : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + 2 * batch_size + 1));
p.query = query;
p.key = key;
p.value = value;
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
- p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
+ p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float))
+ ? scratch1
+ : nullptr;
p.stream = stream;
run_memory_efficient_attention(p);
- DUMP_TENSOR("attention cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, v_head_size);
+
return Status::OK();
}
#endif
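
The three pointers carved out of data.mask_index above assume the MASK_1D_KEY_SEQ_LEN_START packing, whose length of 3 * batch_size + 2 was also dumped in PrepareQkv_MHA_NotPacked. A small sketch of the same pointer arithmetic with the sub-ranges named (illustrative helper under that assumption):

```cpp
// Illustrative only: views into a MASK_1D_KEY_SEQ_LEN_START buffer of
// length 3 * batch_size + 2, matching the offsets used for the
// memory-efficient attention parameters above.
struct MaskIndexViews {
  const int* seqlen_k;    // batch_size entries: key length per batch item
  const int* seqstart_q;  // batch_size + 1 entries: query start offsets
  const int* seqstart_k;  // batch_size + 1 entries: key start offsets
};

inline MaskIndexViews SplitMaskIndex(const int* mask_index, int batch_size) {
  return MaskIndexViews{mask_index,
                        mask_index + batch_size,
                        mask_index + 2 * batch_size + 1};
}
```
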
@@ -922,9 +1086,9 @@ Status QkvToContext(
q, qk_head_size, sequence_length * qk_head_size,
&zero, scratch1, total_sequence_length, sequence_length * total_sequence_length, batches, device_prop));
- DUMP_TENSOR_D("Q", q, batch_size * num_heads, sequence_length, qk_head_size);
- DUMP_TENSOR_D("K", k, batch_size * num_heads, qk_head_size, sequence_length);
- DUMP_TENSOR_D("QK", scratch1, batch_size * num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("Q", q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("K", k, batch_size, num_heads, qk_head_size, sequence_length);
+ DUMP_TENSOR_D("QK", scratch1, batch_size, num_heads, sequence_length, total_sequence_length);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads,
sequence_length, total_sequence_length);
@@ -940,11 +1104,12 @@ Status QkvToContext(
T* persistent_softmax_workspace = scratch1; // replace Q*K' in place with masked score for persistent softmax.
ORT_RETURN_IF_ERROR(
- ComputeSoftmaxWithRawMask(ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias,
- scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension,
- parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
- mask_filter_value));
+ ComputeSoftmaxWithRawMask(
+ ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
+ mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias,
+ scratch1, scratch2, parameters.is_unidirectional, scale, mask_dimension,
+ parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
+ mask_filter_value));
} else if (nullptr != mask_index) { // 1d mask index
assert(mask_index_dims.size() == 1);
     // mask_index has 1D shape: either (batch_size) or (2*batch_size). Only the latter one has start positions.
@@ -955,12 +1120,13 @@ Status QkvToContext(
scratch1, scratch2, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
- ComputeSoftmax(stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias,
- parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional));
+ ComputeSoftmax(
+ stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias,
+ parameters.broadcast_res_pos_bias, scratch1, scratch2, parameters.is_unidirectional));
}
- DUMP_TENSOR_D("Softmax", scratch2, batch_size * num_heads, sequence_length, total_sequence_length);
- DUMP_TENSOR_D("V", v, batch_size * num_heads, sequence_length, v_head_size);
+ DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
+ DUMP_TENSOR_D("V", v, batch_size, num_heads, sequence_length, v_head_size);
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = qkv;
@@ -974,7 +1140,7 @@ Status QkvToContext(
// Temp_output is BxNxSxH_v, transpose to output BxSxNxH_v
Status result = LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, temp_output, data.output);
- DUMP_TENSOR("unfused output", data.output, batch_size * sequence_length, num_heads, v_head_size);
+ DUMP_TENSOR("unfused output", data.output, batch_size, sequence_length, num_heads, v_head_size);
return result;
}
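
The unfused fallback above is the textbook formulation: a batched GEMM for Q*K^T scaled by `scale`, an optionally masked softmax, a second GEMM against V, then a BNSH-to-BSNH transpose of the result. For orientation, here is a minimal single-head CPU sketch of the same computation, without masking or relative position bias (illustrative only, not the kernel code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative only: one (batch, head) slice of the unfused path.
// q: S x H, k: T x H, v: T x H_v, out: S x H_v; scale defaults to 1/sqrt(H).
void UnfusedAttentionRef(const float* q, const float* k, const float* v, float* out,
                         int S, int T, int H, int H_v, float scale) {
  std::vector<float> probs(T);
  for (int s = 0; s < S; ++s) {
    // scores = scale * (q_s . k_t), then a numerically stable softmax over t.
    float max_score = -INFINITY;
    for (int t = 0; t < T; ++t) {
      float score = 0.f;
      for (int h = 0; h < H; ++h) score += q[s * H + h] * k[t * H + h];
      probs[t] = scale * score;
      max_score = std::max(max_score, probs[t]);
    }
    float sum = 0.f;
    for (int t = 0; t < T; ++t) {
      probs[t] = std::exp(probs[t] - max_score);
      sum += probs[t];
    }
    // out_s = softmax(scores) * V
    for (int h = 0; h < H_v; ++h) {
      float acc = 0.f;
      for (int t = 0; t < T; ++t) acc += (probs[t] / sum) * v[t * H_v + h];
      out[s * H_v + h] = acc;
    }
  }
}
```
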
@@ -1109,15 +1275,17 @@ Status DecoderQkvToContext(
if (has_key_padding_mask) {
constexpr int mask_dimension = 2;
constexpr int max_sequence_length = 0;
- ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(ort_stream, kv_sequence_length, sequence_length, batch_size,
- num_heads, nullptr, key_padding_mask, add_before_softmax,
- false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
- 1.0f, mask_dimension, max_sequence_length, false, nullptr,
- mask_filter_value));
+ ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask(
+ ort_stream, kv_sequence_length, sequence_length, batch_size,
+ num_heads, nullptr, key_padding_mask, add_before_softmax,
+ false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
+ 1.0f, mask_dimension, max_sequence_length, false, nullptr,
+ mask_filter_value));
} else {
- ORT_RETURN_IF_ERROR(ComputeSoftmax(stream, kv_sequence_length, sequence_length, batch_size, num_heads,
- add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
- is_unidirectional));
+ ORT_RETURN_IF_ERROR(ComputeSoftmax(
+ stream, kv_sequence_length, sequence_length, batch_size, num_heads,
+ add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
+ is_unidirectional));
}
// compute P*V (as V*P), and store in scratch3: BxNxSxH
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index 5c63a8d8a80b6..af7373dd9fa1b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -43,6 +43,7 @@ size_t GetAttentionWorkspaceSize(
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner,
+ bool use_flash_attention,
bool use_fused_cross_attention,
bool use_memory_efficient_attention);
@@ -74,6 +75,7 @@ struct AttentionData {
void* fused_runner;
const void* fused_cross_attention_kernel;
+ bool use_flash_attention;
bool use_memory_efficient_attention;
mutable CumulatedSequenceLengthCache* cumulated_sequence_length_q_cache;
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
index 00fa265e117bc..ed330b0fca332 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#if defined(__GNUC__)
#pragma GCC diagnostic push
@@ -124,4 +124,4 @@ void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
#pragma GCC diagnostic pop
#endif
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu
index 237f7ea8c9c42..540a2699587eb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
@@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu
index 941ea87baa398..005425c56e0ae 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
@@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& p
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu
index 5a0e7c9ed5b7a..955423b6c6762 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
@@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& p
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu
index d0775a29c4cf1..0b54d90c4da30 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
@@ -21,4 +21,4 @@ void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& p
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu
index 284211f96514d..750cace39ae39 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.cu
@@ -1,6 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
@@ -27,4 +27,4 @@ void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
index 326ff451e600a..f725be8d7cf89 100644
--- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h
@@ -2,7 +2,7 @@
// Licensed under the MIT License.
#pragma once
-#if USE_FLASH_ATTENTION
+#if USE_MEMORY_EFFICIENT_ATTENTION
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
@@ -58,4 +58,4 @@ void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& p
} // namespace contrib
} // namespace onnxruntime
-#endif // USE_FLASH_ATTENTION
+#endif // USE_MEMORY_EFFICIENT_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
new file mode 100644
index 0000000000000..9db98061bbd66
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
@@ -0,0 +1,40 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+namespace onnxruntime {
+namespace flash {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <bool Varlen = true>
+struct BlockInfo {
+  template <typename Params>
+ __device__ BlockInfo(const Params& params, const int bidb)
+ : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
+ sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
+ actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q),
+ actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k) {
+ }
+
+  template <typename index_t>
+ inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
+ }
+
+  template <typename index_t>
+ inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
+ return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
+ }
+
+ const int sum_s_q;
+ const int sum_s_k;
+ const int actual_seqlen_q;
+ const int actual_seqlen_k;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+} // namespace flash
+} // namespace onnxruntime
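
BlockInfo resolves per-batch extents either from the fixed seqlen_q/seqlen_k fields or, when cu_seqlens_* are provided, from cumulative start offsets of length batch_size + 1 (see the comment in flash.h below). A hedged example of what those arrays encode:

```cpp
#include <cassert>

int main() {
  // Illustrative only: cumulative sequence starts for 3 packed sequences of
  // lengths {3, 5, 2}. actual_seqlen for batch b is cu[b + 1] - cu[b].
  const int cu_seqlens[] = {0, 3, 8, 10};  // batch_size + 1 entries
  assert(cu_seqlens[1] - cu_seqlens[0] == 3);
  assert(cu_seqlens[2] - cu_seqlens[1] == 5);
  assert(cu_seqlens[3] - cu_seqlens[2] == 2);
  // In the varlen case, BlockInfo::q_offset addresses rows by
  // cu_seqlens[b] * row_stride instead of b * batch_stride, because the
  // sequences are packed back to back with no per-batch padding.
  return 0;
}
```
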
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
new file mode 100644
index 0000000000000..9394a19c9897a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -0,0 +1,85 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include
+#include
+
+namespace onnxruntime {
+namespace flash {
+
+constexpr int TOTAL_DIM = 0;
+constexpr int H_DIM = 1;
+constexpr int D_DIM = 2;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Qkv_params {
+ using index_t = uint32_t;
+ // The QKV matrices.
+ void* __restrict__ q_ptr;
+ void* __restrict__ k_ptr;
+ void* __restrict__ v_ptr;
+
+ // The stride between rows of the Q, K and V matrices.
+ index_t q_batch_stride;
+ index_t k_batch_stride;
+ index_t v_batch_stride;
+ index_t q_row_stride;
+ index_t k_row_stride;
+ index_t v_row_stride;
+ index_t q_head_stride;
+ index_t k_head_stride;
+ index_t v_head_stride;
+
+ // The number of heads.
+ int h, h_k;
+ // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
+ // different from nheads (query).
+ int h_h_k_ratio; // precompute h / h_k,
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+struct Flash_fwd_params : public Qkv_params {
+ // The O matrix (output).
+ void* __restrict__ o_ptr;
+
+ // The stride between rows of O.
+ index_t o_batch_stride;
+ index_t o_row_stride;
+ index_t o_head_stride;
+
+ // The pointer to the P matrix.
+ void* __restrict__ p_ptr;
+
+ // The pointer to the softmax sum.
+ void* __restrict__ softmax_lse_ptr;
+
+ // The dimensions.
+ int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
+
+ // The scaling factors for the kernel.
+ float scale_softmax;
+ float scale_softmax_log2;
+
+ // array of length b+1 holding starting offset of each sequence.
+ int* __restrict__ cu_seqlens_q;
+ int* __restrict__ cu_seqlens_k;
+
+ int* __restrict__ blockmask;
+
+ bool is_bf16 = false;
+ bool is_causal;
+
+ const cudaDeviceProp* dprops;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int Headdim>
+void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
+
+} // namespace flash
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
new file mode 100644
index 0000000000000..87831d1eddfe9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -0,0 +1,198 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"
+#include
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash.h"
+#include "contrib_ops/cuda/bert/flash_attention/static_switch.h"
+
+namespace onnxruntime {
+namespace flash {
+
+void set_params_fprop(Flash_fwd_params& params,
+ // sizes
+ size_t batch_size,
+ size_t seqlen_q,
+ size_t seqlen_k,
+ size_t seqlen_q_rounded,
+ size_t seqlen_k_rounded,
+ size_t num_heads,
+ size_t num_heads_k,
+ size_t head_size,
+ size_t head_size_rounded,
+ // device pointers
+ void* q,
+ void* k,
+ void* v,
+ void* out,
+ void* cu_seqlens_q_d,
+ void* cu_seqlens_k_d,
+ void* p_d,
+ void* softmax_lse_d,
+ float softmax_scale,
+ bool is_causal) {
+ // Set the pointers and strides.
+ params.q_ptr = q;
+ params.k_ptr = k;
+ params.v_ptr = v;
+ params.o_ptr = out;
+
+ // All stride are in elements, not bytes.
+ params.q_row_stride = num_heads * head_size;
+ params.k_row_stride = num_heads_k * head_size;
+ params.v_row_stride = num_heads * head_size;
+ params.q_head_stride = head_size;
+ params.k_head_stride = head_size;
+ params.v_head_stride = head_size;
+ params.o_row_stride = num_heads * head_size;
+ params.o_head_stride = head_size;
+ params.is_bf16 = false;
+
+ if (cu_seqlens_q_d == nullptr) {
+ params.q_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
+ params.k_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
+ params.v_batch_stride = seqlen_k * num_heads_k * head_size; // stride(0)
+ params.o_batch_stride = seqlen_q * num_heads * head_size; // stride(0)
+ } else {
+ params.q_batch_stride = 0;
+ params.k_batch_stride = 0;
+ params.v_batch_stride = 0;
+ params.o_batch_stride = 0;
+ }
+
+  params.cu_seqlens_q = static_cast<int*>(cu_seqlens_q_d);
+  params.cu_seqlens_k = static_cast<int*>(cu_seqlens_k_d);
+
+ // P = softmax(QK^T)
+ params.p_ptr = p_d;
+
+ // Softmax sum
+ params.softmax_lse_ptr = softmax_lse_d;
+
+ // Set the dimensions.
+ params.b = batch_size;
+ params.h = num_heads;
+ params.h_k = num_heads_k;
+ params.h_h_k_ratio = num_heads / num_heads_k;
+ params.seqlen_q = seqlen_q;
+ params.seqlen_k = seqlen_k;
+ params.seqlen_q_rounded = seqlen_q_rounded;
+ params.seqlen_k_rounded = seqlen_k_rounded;
+ params.d = head_size;
+ params.d_rounded = head_size_rounded;
+
+ // Set the different scale values.
+ params.scale_softmax = softmax_scale;
+ params.scale_softmax_log2 = softmax_scale * M_LOG2E;
+
+ params.is_causal = is_causal;
+}
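
set_params_fprop assumes contiguous batch x seqlen x num_heads x head_size tensors, so every stride is a product of the trailing dimensions. A small worked example under that assumption (illustrative values only):

```cpp
#include <cstddef>

// Illustrative only: q of shape batch=2, seqlen_q=128, num_heads=12, head_size=64.
constexpr size_t kHeadSize = 64, kNumHeads = 12, kSeqLenQ = 128;
constexpr size_t kHeadStride = kHeadSize;               // 64
constexpr size_t kRowStride = kNumHeads * kHeadSize;    // 768
constexpr size_t kBatchStride = kSeqLenQ * kRowStride;  // 98304

// Element (b, s, n, h) of a contiguous BSNH tensor:
constexpr size_t Offset(size_t b, size_t s, size_t n, size_t h) {
  return b * kBatchStride + s * kRowStride + n * kHeadStride + h;
}
static_assert(Offset(1, 0, 0, 0) == 98304, "one full batch entry");
static_assert(Offset(0, 1, 0, 0) == 768, "one token row");
```

In the varlen path (cu_seqlens_* non-null) the batch strides are zeroed because the sequences are packed back to back and addressed through the cumulative offsets instead.
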
+
+size_t get_softmax_lse_size(int seqlen, int batch_size, int num_heads) {
+ size_t bytes = sizeof(float) * batch_size * num_heads * seqlen;
+ return bytes;
+}
+
+void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream) {
+ FP16_SWITCH(!params.is_bf16, [&] {
+ FWD_HEADDIM_SWITCH(params.d, [&] {
+      run_mha_fwd_<elem_type, kHeadDim>(params, stream);
+ });
+ });
+}
+
+Status mha_fwd(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* k, // batch_size x seqlen_k x num_heads_k x head_size
+ void* v, // batch_size x seqlen_k x num_heads_k x head_size
+ void* out, // batch_size x seqlen_q x num_heads x head_size
+ void* softmax_lse, // batch_size x num_heads x seqlen_q
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int seqlen_q,
+ int seqlen_k,
+ float softmax_scale,
+ bool is_causal) {
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = round_multiple(head_size, 32);
+ const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+
+ Flash_fwd_params params;
+ params.dprops = &dprops;
+ set_params_fprop(params,
+ batch_size,
+ seqlen_q, seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q, k, v, out,
+ /*cu_seqlens_q*/ nullptr,
+ /*cu_seqlens_k*/ nullptr,
+ nullptr,
+ softmax_lse,
+ softmax_scale,
+ is_causal);
+
+ run_mha_fwd(params, stream);
+ return Status::OK();
+}
+
+Status mha_varlen_fwd(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // half (total_q, num_heads, head_size)
+ void* k, // half (total_k, num_heads, head_size)
+ void* v, // half (total_k, num_heads, head_size)
+ void* out, // half (total_q, num_heads, head_size)
+ int* cu_seqlens_q, // int (batch_size + 1)
+ int* cu_seqlens_k, // int (batch_size + 1)
+ void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q)
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int max_seqlen_q,
+ int max_seqlen_k,
+ float softmax_scale,
+ bool is_causal) {
+ auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
+ const int head_size_rounded = round_multiple(head_size, 32);
+ const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
+ const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
+
+ Flash_fwd_params params;
+ params.dprops = &dprops;
+ set_params_fprop(params,
+ batch_size,
+ max_seqlen_q, max_seqlen_k,
+ seqlen_q_rounded, seqlen_k_rounded,
+ num_heads, num_heads_k,
+ head_size, head_size_rounded,
+ q, k, v, out,
+ cu_seqlens_q,
+ cu_seqlens_k,
+ nullptr,
+ softmax_lse,
+ softmax_scale,
+ is_causal);
+ run_mha_fwd(params, stream);
+ return Status::OK();
+}
+
+bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k) {
+ bool is_sm8x = dprops.major == 8 && dprops.minor >= 0;
+ bool is_sm90 = dprops.major == 9 && dprops.minor == 0;
+ return (is_sm8x || is_sm90) && (head_size % 8 == 0) && (head_size <= 256) && (num_heads % num_heads_k == 0);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+
+#endif // USE_FLASH_ATTENTION
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
new file mode 100644
index 0000000000000..2ae46d34c373a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -0,0 +1,78 @@
+/******************************************************************************
+ * Copyright (c) 2022, Tri Dao.
+ * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of the NVIDIA CORPORATION nor the
+ * names of its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ ******************************************************************************/
+
+#pragma once
+
+#if USE_FLASH_ATTENTION
+
+#include "core/providers/cuda/cuda_common.h"
+
+namespace onnxruntime {
+namespace flash {
+Status mha_fwd(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // batch_size x seqlen_q x num_heads x head_size
+ void* k, // batch_size x seqlen_k x num_heads_k x head_size
+ void* v, // batch_size x seqlen_k x num_heads_k x head_size
+ void* out, // batch_size x seqlen_q x num_heads x head_size
+ void* softmax_lse, // batch_size x num_heads x seqlen_q
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int seqlen_q,
+ int seqlen_k,
+ float softmax_scale,
+ bool is_causal);
+
+Status mha_varlen_fwd(const cudaDeviceProp& dprops,
+ cudaStream_t stream,
+ void* q, // half (total_q, num_heads, head_size)
+ void* k, // half (total_k, num_heads, head_size)
+ void* v, // half (total_k, num_heads, v_head_size)
+ void* out, // half (total_q, num_heads, v_head_size)
+ int* cu_seqlens_q, // int (batch_size + 1)
+ int* cu_seqlens_k, // int (batch_size + 1)
+ void* softmax_lse, // float (batch_size, num_heads, max_seqlen_q)
+ int batch_size,
+ int num_heads,
+ int num_heads_k,
+ int head_size,
+ int max_seqlen_q,
+ int max_seqlen_k,
+ float softmax_scale,
+ bool is_causal);
+
+size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
+
+bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, int num_heads_k);
+
+} // namespace flash
+} // namespace onnxruntime
+
+#endif // USE_FLASH_ATTENTION
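
Putting these declarations together, a plausible call pattern for the fixed-length entry point looks like the sketch below. This is a hedged illustration only: it assumes contiguous BSNH fp16 device buffers, a caller-allocated softmax_lse scratch of get_softmax_lse_size(seqlen_q, batch_size, num_heads) bytes, and no MQA/GQA (num_heads_k == num_heads); it is not the operator's actual call site.

```cpp
#include <cmath>
#include "contrib_ops/cuda/bert/flash_attention/flash_api.h"

onnxruntime::common::Status RunFlashAttentionSketch(
    const cudaDeviceProp& dprops, cudaStream_t stream,
    void* q, void* k, void* v, void* out, void* softmax_lse,
    int batch_size, int num_heads, int head_size, int seqlen_q, int seqlen_k) {
  // Gate on device generation and head size first.
  if (!onnxruntime::flash::is_supported(dprops, head_size, num_heads, /*num_heads_k*/ num_heads)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "flash attention is not supported on this device");
  }
  const float scale = 1.0f / std::sqrt(static_cast<float>(head_size));
  return onnxruntime::flash::mha_fwd(dprops, stream, q, k, v, out, softmax_lse,
                                     batch_size, num_heads, /*num_heads_k*/ num_heads,
                                     head_size, seqlen_q, seqlen_k, scale, /*is_causal*/ false);
}
```
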
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu
new file mode 100644
index 0000000000000..44ea92e58c86e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
new file mode 100644
index 0000000000000..a2bf16bc74e72
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim160_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
new file mode 100644
index 0000000000000..56fc04126ab12
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim192_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu
new file mode 100644
index 0000000000000..6fb24640710a3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim224_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu
new file mode 100644
index 0000000000000..94d51e922d7cb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu
new file mode 100644
index 0000000000000..d32eec27634ce
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim32_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu
new file mode 100644
index 0000000000000..65a2e42192532
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim64_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu
new file mode 100644
index 0000000000000..f37ee5005855a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_hdim96_fp16_sm80.cu
@@ -0,0 +1,18 @@
+// Copyright (c) 2023, Tri Dao.
+
+// Splitting the different head dimensions to different files to speed up compilation.
+#if USE_FLASH_ATTENTION
+
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <>
+void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params& params, cudaStream_t stream) {
+  run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
+}
+
+} // namespace flash
+} // namespace onnxruntime
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
new file mode 100644
index 0000000000000..b5af31e432d42
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h
@@ -0,0 +1,532 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wunused-variable"
+#pragma GCC diagnostic ignored "-Wunused-but-set-variable"
+#endif
+
+#include <cmath>
+#include <cute/algorithm/copy.hpp>
+#include <cute/algorithm/gemm.hpp>
+
+#include <cutlass/cutlass.h>
+#include <cutlass/array.h>
+#include <cutlass/numeric_types.h>
+#include <cutlass/numeric_conversion.h>
+
+#include "contrib_ops/cuda/bert/flash_attention/block_info.h"
+#include "contrib_ops/cuda/bert/flash_attention/kernel_traits.h"
+#include "contrib_ops/cuda/bert/flash_attention/utils.h"
+#include "contrib_ops/cuda/bert/flash_attention/softmax.h"
+
+namespace onnxruntime {
+namespace flash {
+using namespace cute;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int MMA_M, class... Args, class TiledMMA>
+CUTE_HOST_DEVICE auto
+make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
+                                  TiledMMA const& tiled_mma) {
+  using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+  using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
+  constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value;
+  constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M;
+  constexpr int MMAStride_M = MMA_M * AtomShape_M;
+  auto t = make_tile(cute::Layout<cute::Shape<cute::Int<AtomShape_M>, cute::Int<kNWarps>>,
+                                  cute::Stride<_1, cute::Int<MMAStride_M>>>{},
+                     make_layout(cute::size<2>(TileShape_MNK{})));
+
+  return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <int MMA_M, class... Args, class TiledMMA>
+CUTE_HOST_DEVICE auto
+make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
+                                  TiledMMA const& tiled_mma) {
+  using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
+  using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
+  constexpr int AtomShape_M = decltype(cute::size<0>(AtomShape_MNK{}))::value;
+  constexpr int kNWarps = decltype(cute::size<0>(TileShape_MNK{}))::value / AtomShape_M;
+  constexpr int MMAStride_M = MMA_M * AtomShape_M;
+  auto t = make_tile(cute::Layout<cute::Shape<cute::Int<AtomShape_M>, cute::Int<kNWarps>>,
+                                  cute::Stride<_1, cute::Int<MMAStride_M>>>{},
+                     // TODO: Shouldn't this be size<1>?
+                     make_layout(cute::size<2>(TileShape_MNK{})));
+  // if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
+  return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
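+// Note on the online-softmax rescaling performed by softmax_rescale_o below: a running row-max
+// (scores_max) and row-sum (scores_sum) are maintained across key blocks. When a new block of
+// scores arrives, the max is updated, the existing accumulator acc_o and the running sum are
+// scaled by exp2((old_max - new_max) * softmax_scale_log2), and the newly exponentiated scores
+// are added to the sum. The result matches a single full-row softmax while keeping every
+// exponent argument bounded.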
+template <bool Is_first, bool Check_inf = false, typename Tensor0, typename Tensor1, typename Tensor2>
+inline __device__ void softmax_rescale_o(Tensor0& scores, Tensor1& scores_max, Tensor1& scores_sum,
+ Tensor2& acc_o, float softmax_scale_log2) {
+ if (Is_first) {
+    flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
+ flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+ flash::reduce_sum(scores, scores_sum);
+ } else {
+ cute::Tensor scores_max_prev = make_fragment_like(scores_max);
+ copy(scores_max, scores_max_prev);
+    flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
+ // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+ cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+#pragma unroll
+ for (int mi = 0; mi < cute::size(scores_max); ++mi) {
+ float scores_max_cur = !Check_inf
+ ? scores_max(mi)
+ : (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
+ float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
+ scores_sum(mi) *= scores_scale;
+#pragma unroll
+ for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) {
+ acc_o_rowcol(mi, ni) *= scores_scale;
+ }
+ }
+ flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
+ cute::Tensor scores_sum_cur = make_fragment_like(scores_sum);
+ flash::reduce_sum(scores, scores_sum_cur);
+#pragma unroll
+ for (int mi = 0; mi < cute::size(scores_sum); ++mi) {
+ scores_sum(mi) += scores_sum_cur(mi);
+ }
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
+inline __device__ void write_softmax_to_gmem(
+    cute::Tensor<Engine0, Layout0> const& tOrP, cute::Tensor<Engine1, Layout1>& tPgP, TiledCopy gmem_thr_copy_P) {
+ // Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
+ cute::Layout l = tOrP.layout();
+ cute::Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
+ CUTE_STATIC_ASSERT_V(cute::size<2>(tPgP) == _1{});
+ CUTE_STATIC_ASSERT_V(cute::size<1>(tPrP) == cute::size<1>(tPgP));
+#pragma unroll
+ for (int mi = 0; mi < cute::size<1>(tPrP); ++mi) {
+ copy(gmem_thr_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
+ }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+inline __device__ void compute_attn_1rowblock(const Params& params, const int bidb, const int bidh, const int m_block) {
+ using Element = typename Kernel_traits::Element;
+ using ElementAccum = typename Kernel_traits::ElementAccum;
+ using index_t = typename Kernel_traits::index_t;
+
+ // Shared memory.
+ extern __shared__ char smem_[];
+
+ // The thread index.
+ const int tidx = threadIdx.x;
+
+ constexpr int kBlockM = Kernel_traits::kBlockM;
+ constexpr int kBlockN = Kernel_traits::kBlockN;
+ constexpr int kHeadDim = Kernel_traits::kHeadDim;
+ constexpr int kNWarps = Kernel_traits::kNWarps;
+ constexpr int MMA_M = kBlockM / decltype(cute::size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
+
+  const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
+ if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
+
+ int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
+ if (Is_causal) {
+ n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q, kBlockN));
+ }
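+  // Illustrative example (values not fixed by this kernel): with kBlockM = 128, kBlockN = 64 and
+  // seqlen_q == seqlen_k, query block m only attends to key positions < (m + 1) * 128, so
+  // n_block_max is capped at ceil_div((m + 1) * 128, 64) = 2 * (m + 1).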
+
+ // We iterate over the blocks in reverse order. This is because the last block is the only one
+ // that needs masking when we read K and V from global memory. Moreover, iterating in reverse
+ // might save us 1 register (we just need n_block instead of both n_block and n_block_max).
+
+ const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb) + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
+ // We move K and V to the last block.
+ const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
+ const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb) + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
+ const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
+
+  cute::Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.q_ptr) + row_offset_q),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{},
+                                make_stride(params.q_row_stride, _1{}));
+  cute::Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.k_ptr) + row_offset_k),
+                                cute::Shape<cute::Int<kBlockN>, cute::Int<kHeadDim>>{},
+                                make_stride(params.k_row_stride, _1{}));
+  cute::Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.v_ptr) + row_offset_v),
+                                cute::Shape<cute::Int<kBlockN>, cute::Int<kHeadDim>>{},
+                                make_stride(params.v_row_stride, _1{}));
+  cute::Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.p_ptr) + row_offset_p),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{},
+                                make_stride(params.seqlen_k_rounded, _1{}));
+
+  cute::Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element*>(smem_)),
+ typename Kernel_traits::SmemLayoutQ{});
+ // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
+ cute::Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : cute::size(sQ)),
+ typename Kernel_traits::SmemLayoutKV{});
+ cute::Tensor sV = make_tensor(sK.data() + cute::size(sK), typename Kernel_traits::SmemLayoutKV{});
+ cute::Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
+ cute::Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
+
+ typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
+ auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
+ typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
+ auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
+
+ cute::Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
+ cute::Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
+ cute::Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
+ cute::Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
+ cute::Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
+ cute::Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
+ cute::Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
+
+ typename Kernel_traits::TiledMma tiled_mma;
+ auto thr_mma = tiled_mma.get_thread_slice(tidx);
+ cute::Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
+ cute::Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
+ cute::Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
+
+  cute::Tensor acc_o = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
+
+ //
+ // Copy Atom retiling
+ //
+
+ auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
+ cute::Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
+
+ auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
+ auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
+ cute::Tensor tSsK = smem_thr_copy_K.partition_S(sK);
+
+ auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
+ auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
+ cute::Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
+
+ // TODO: this might need to change if we change the mma instruction in SM70
+  cute::Tensor scores_max = make_tensor<ElementAccum>(cute::Shape<cute::Int<2 * cute::size<1>(acc_o)>>{});
+ cute::Tensor scores_sum = make_fragment_like(scores_max);
+
+ //
+ // PREDICATES
+ //
+
+ // Construct identity layout for sQ and sK
+ cute::Tensor cQ = make_identity_tensor(make_shape(cute::size<0>(sQ), cute::size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ cute::Tensor cKV = make_identity_tensor(make_shape(cute::size<0>(sK), cute::size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
+
+ // Repeat the partitioning with identity layouts
+ cute::Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+ cute::Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+
+ // Allocate predicate tensors for k
+  cute::Tensor tQpQ = make_tensor<bool>(make_shape(cute::size<2>(tQsQ)));
+  cute::Tensor tKVpKV = make_tensor<bool>(make_shape(cute::size<2>(tKsK)));
+
+ // Set predicates for k bounds
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < cute::size(tQpQ); ++k) {
+ tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d;
+ }
+#pragma unroll
+ for (int k = 0; k < cute::size(tKVpKV); ++k) {
+ tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d;
+ }
+ }
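+  // Note: when the head size d is not a multiple of the copy tile width (Is_even_K == false),
+  // these predicates mask off the copy slices whose starting column is >= d, so loads/stores for
+  // the head-dimension tail are predicated off (and, where requested, zero-filled in smem instead).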
+
+ // Prologue
+
+ cute::Tensor tQrQ = make_fragment_like(tQgQ);
+ // We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
+ binfo.actual_seqlen_q - m_block * kBlockM);
+ if (Kernel_traits::Is_Q_in_regs) {
+ cute::cp_async_fence();
+ }
+
+ if (Kernel_traits::Share_Q_K_smem) {
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+ CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+ __syncthreads();
+ }
+
+ int n_block = n_block_max - 1;
+ // We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
+  flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
+ binfo.actual_seqlen_k - n_block * kBlockN);
+ cute::cp_async_fence();
+
+ if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
+ flash::cp_async_wait<1>();
+ __syncthreads();
+ cute::Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
+ CUTE_STATIC_ASSERT_V(cute::size<1>(tSsQ) == cute::size<1>(tSrQ_copy_view)); // M
+ cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
+ }
+
+ clear(acc_o);
+
+ // For performance reason, we separate out two kinds of iterations:
+ // those that need masking on S, and those that don't.
+ // We need masking on S for the very last block when K and V has length not multiple of kBlockN.
+ // We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
+ // We will have at least 1 "masking" iteration.
+
+ // If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
+ // mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
+ constexpr int n_masking_steps = !Is_causal
+ ? 1
+ : (Is_even_MN ? cute::ceil_div(kBlockM, kBlockN) : cute::ceil_div(kBlockM, kBlockN) + 1);
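+  // Illustrative example: a causal kernel with kBlockM = 128 and kBlockN = 64 uses
+  // ceil_div(128, 64) = 2 masking steps, plus one more when Is_even_MN is false
+  // (e.g. seqlen_k not a multiple of kBlockN).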
+#pragma unroll
+ for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
+    cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+
+ // Advance gV
+ if (masking_step > 0) {
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ } else {
+ // Clear the smem tiles to account for predicated off loads
+      flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
+ gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN);
+ }
+ cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K);
+ // if (cute::thread0()) { print(acc_s); }
+
+ // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+
+ // We don't put the masking before the matmul S = Q K^T because we don't clear sK
+ // for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
+ // can produce Inf / NaN.
+ if (!Is_causal) {
+ if (!Is_even_MN) {
+ flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN);
+ }
+ } else {
+ // I can't get the stride from idx_row
+ flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
+ // m_block * kBlockM + get<0>(idx_row(0)),
+ m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
+ binfo.actual_seqlen_q,
+ kNWarps * 16);
+ }
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ if (n_block > 0) {
+ // Advance gK
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ // TODO: when we have key_padding_mask we'll need to Check_inf
+ masking_step == 0
+        ? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
+        : softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+ // Convert scores from fp32 to fp16/bf16
+    cute::Tensor rP = flash::convert_type<Element>(scores);
+ // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ // if (Return_softmax) {
+ // cute::Tensor tOrP_copy = make_fragment_like(tOrP);
+ // copy(tOrP, tOrP_copy);
+ // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ // tPgP.data() = tPgP.data() + (-kBlockN);
+ // }
+
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+
+ // This check is at the end of the loop since we always have at least 1 iteration
+ if (n_masking_steps > 1 && n_block <= 0) {
+ --n_block;
+ break;
+ }
+ }
+
+ // These are the iterations where we don't need masking on S
+ for (; n_block >= 0; --n_block) {
+    cute::Tensor acc_s = partition_fragment_C(tiled_mma, cute::Shape<cute::Int<kBlockM>, cute::Int<kBlockN>>{});  // (MMA=4, MMA_M, MMA_N)
+ clear(acc_s);
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ // Advance gV
+ tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
+    flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
+ cute::cp_async_fence();
+
+    flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
+ acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
+ smem_thr_copy_Q, smem_thr_copy_K);
+
+ flash::cp_async_wait<0>();
+ __syncthreads();
+ if (n_block > 0) {
+ // Advance gK
+ tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
+      flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
+ // This cp_async_fence needs to be in the if block, otherwise the synchronization
+ // isn't right and we get race conditions.
+ cute::cp_async_fence();
+ }
+
+ // Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
+ cute::Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
+    softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
+
+    cute::Tensor rP = flash::convert_type<Element>(scores);
+ // Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
+ // if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
+    cute::Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
+ // if (Return_softmax) {
+ // cute::Tensor tOrP_copy = make_fragment_like(tOrP);
+ // copy(tOrP, tOrP_copy);
+ // flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_thr_copy_P);
+ // tPgP.data() = tPgP.data() + (-kBlockN);
+ // }
+
+ flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
+ }
+
+ // Epilogue
+
+ // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
+ cute::Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
+ cute::Tensor lse = make_fragment_like(scores_sum);
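+  // Note: the loop below divides acc_o by the softmax denominator and records the per-row
+  // log-sum-exp. Since scores were exponentiated as exp2((x - max) * scale_softmax_log2),
+  // lse = max * scale_softmax + ln(sum). Rows whose sum is 0 or NaN keep scale 1 and get lse = inf.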
+#pragma unroll
+ for (int mi = 0; mi < cute::size<0>(acc_o_rowcol); ++mi) {
+ float sum = scores_sum(mi);
+ float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
+ lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
+ float scale = inv_sum;
+#pragma unroll
+ for (int ni = 0; ni < cute::size<1>(acc_o_rowcol); ++ni) {
+ acc_o_rowcol(mi, ni) *= scale;
+ }
+ }
+
+ // Convert acc_o from fp32 to fp16/bf16
+  cute::Tensor rO = flash::convert_type<Element>(acc_o);
+ cute::Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
+ // Partition sO to match the accumulator partitioning
+ auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
+ auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx); // auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
+ cute::Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
+ cute::Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
+
+ // sO has the same size as sQ, so we don't need to sync here.
+ if (Kernel_traits::Share_Q_K_smem) {
+ __syncthreads();
+ }
+
+ cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
+
+ const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
+ const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
+  cute::Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element*>(params.o_ptr) + row_offset_o),
+                                cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{},
+                                make_stride(params.o_row_stride, _1{}));
+  cute::Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum*>(params.softmax_lse_ptr) + row_offset_lse),
+                                  cute::Shape<cute::Int<kBlockM>>{}, cute::Stride<_1>{});
+
+ typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
+ auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
+ cute::Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
+ cute::Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
+
+ __syncthreads();
+
+  cute::Tensor tOrO = make_tensor<Element>(cute::shape(tOgO));
+ cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
+
+  cute::Tensor caccO = make_identity_tensor(cute::Shape<cute::Int<kBlockM>, cute::Int<kHeadDim>>{});  // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ cute::Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
+ static_assert(decltype(cute::size<0>(taccOcO))::value == 4);
+ // Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
+ cute::Tensor taccOcO_row = logical_divide(taccOcO, cute::Shape<_2>{})(make_coord(0, _), _, 0);
+ CUTE_STATIC_ASSERT_V(cute::size(lse) == cute::size(taccOcO_row)); // MMA_M
+ if (get<1>(taccOcO_row(0)) == 0) {
+#pragma unroll
+ for (int mi = 0; mi < cute::size(lse); ++mi) {
+ const int row = get<0>(taccOcO_row(mi));
+ if (row < binfo.actual_seqlen_q - m_block * kBlockM) {
+ gLSE(row) = lse(mi);
+ }
+ }
+ }
+
+ // Construct identity layout for sO
+ cute::Tensor cO = make_identity_tensor(make_shape(cute::size<0>(sO), cute::size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
+ // Repeat the partitioning with identity layouts
+ cute::Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
+  cute::Tensor tOpO = make_tensor<bool>(make_shape(cute::size<2>(tOgO)));
+ if (!Is_even_K) {
+#pragma unroll
+ for (int k = 0; k < cute::size(tOpO); ++k) {
+ tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d;
+ }
+ }
+ // Clear_OOB_K must be false since we don't want to write zeros to gmem
+  flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+ gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
+inline __device__ void compute_attn(const Params& params) {
+ const int m_block = blockIdx.x;
+ // The block index for the batch.
+ const int bidb = blockIdx.y;
+ // The block index for the head.
+ const int bidh = blockIdx.z;
+
+ // We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
+ // them to have the same number of threads or have to traverse the attention matrix
+ // in the same order.
+ // In the Philox RNG, we use the offset to store the batch, head, and the lane id
+ // (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
+ // the attention matrix. This way, as long as we have the batch, head, and the location of
+ // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
+
+  flash::compute_attn_1rowblock<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+} // namespace flash
+} // namespace onnxruntime
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic pop
+#endif
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
new file mode 100644
index 0000000000000..e633ef4d45fbb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -0,0 +1,210 @@
+/******************************************************************************
+ * Copyright (c) 2023, Tri Dao.
+ ******************************************************************************/
+#pragma once
+
+#include "contrib_ops/cuda/bert/flash_attention/static_switch.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash.h"
+#include "contrib_ops/cuda/bert/flash_attention/flash_fwd_kernel.h"
+
+namespace onnxruntime {
+namespace flash {
+
+template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
+__global__ void flash_fwd_kernel(Flash_fwd_params params) {
+  flash::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
+}
+
+template